pytorch

Форк
0
/
backend_resolver.cpp 
71 строка · 2.1 Кб
1
#include <torch/csrc/jit/backends/backend_resolver.h>
2
#include <torch/csrc/jit/frontend/sugared_value.h>
3
#include <torch/custom_class.h>
4

5
namespace torch {
6
namespace jit {
7
namespace {
8
// Essentially ClassNamespaceValue from import_source.cpp without the
9
// SourceImporterImpl reference. This helps resolve the
10
// __torch__.torch.classes.backends.{backend_name} symbols in the generated code
11
// for the LoweredModule.
12
struct ClassNamespaceValue : public SugaredValue {
13
  explicit ClassNamespaceValue(c10::QualifiedName name)
14
      : basename_(std::move(name)) {}
15

16
  std::shared_ptr<SugaredValue> attr(
17
      const SourceRange& loc,
18
      GraphFunction& m,
19
      const std::string& name) override {
20
    auto fullName = c10::QualifiedName(basename_, name);
21

22
    // Check to see if it is a custom class.
23
    if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
24
      return std::make_shared<ClassValue>(custom_class);
25
    }
26

27
    // If it's not a custom class, assume it's another namespace
28
    return std::make_shared<ClassNamespaceValue>(std::move(fullName));
29
  }
30

31
  std::string kind() const override {
32
    return "Class Namespace";
33
  }
34

35
 private:
36
  c10::QualifiedName basename_;
37
};
38

39
// A resolver just for resolving custom backend class lookups in the
40
// LoweredModule classes generated by the rest of the cdoe in this file.
41
struct LoweredModuleResolver : public Resolver {
42
  std::shared_ptr<SugaredValue> resolveValue(
43
      const std::string& name,
44
      GraphFunction& m,
45
      const SourceRange& loc) override {
46
    if (name == "torch") {
47
      return std::make_shared<BuiltinModule>("aten");
48
    } else if (name == "__torch__") {
49
      return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
50
    } else if (name == "Exception") {
51
      return std::make_shared<ExceptionValue>(name);
52
    }
53

54
    return nullptr;
55
  }
56

57
  TypePtr resolveType(const std::string& name, const SourceRange& loc)
58
      override {
59
    return nullptr;
60
  }
61
};
62
} // namespace
63

64
std::shared_ptr<Resolver> loweredModuleResolver() {
65
  std::shared_ptr<Resolver> resolver =
66
      std::make_shared<LoweredModuleResolver>();
67
  return resolver;
68
}
69

70
} // namespace jit
71
} // namespace torch
72

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.