pytorch

Форк
0
/
OperatorCallTracer.cpp 
29 строк · 903.0 Байт
1
#include <torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h>
2

3
namespace torch {
4
namespace jit {
5
namespace mobile {
6
OperatorCallTracer::OperatorCallTracer() {
7
  getCalledOperators().withLock([](std::set<std::string>& called_operators) {
8
    called_operators.clear();
9
  });
10

11
  auto recorder_cb =
12
      [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
13
    c10::optional<c10::OperatorName> op_name = fn.operator_name();
14
    if (op_name.has_value()) {
15
      getCalledOperators().withLock(
16
          [op_name](std::set<std::string>& called_operators) {
17
            called_operators.insert(c10::toString(*op_name));
18
          });
19
    }
20
    return nullptr;
21
  };
22

23
  handle_ = at::addGlobalCallback(at::RecordFunctionCallback(recorder_cb)
24
                                      .scopes({at::RecordScope::FUNCTION}));
25
}
26

27
} // namespace mobile
28
} // namespace jit
29
} // namespace torch
30

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.