// pytorch
1#include <torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h>
2#include <mutex>
3
4namespace torch {
5namespace jit {
6namespace mobile {
7CustomClassTracer::CustomClassTracer() {
8auto recorder_cb =
9[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
10std::string name = fn.name();
11getLoadedClasses().withLock(
12[&name](CustomClassTracer::custom_classes_type& custom_classes) {
13custom_classes.insert(name);
14});
15return nullptr;
16};
17
18handle_ = at::addGlobalCallback(at::RecordFunctionCallback(recorder_cb)
19.scopes({at::RecordScope::CUSTOM_CLASS}));
20}
21
22c10::Synchronized<CustomClassTracer::custom_classes_type>& CustomClassTracer::
23getLoadedClasses() {
24static c10::Synchronized<custom_classes_type> loaded_classes;
25return loaded_classes;
26}
27
28} // namespace mobile
29} // namespace jit
30} // namespace torch
31