pytorch

Форк
0
132 строки · 4.8 Кб
1
#include <torch/csrc/jit/operator_upgraders/version_map.h>
2

3
#include <algorithm>
4
#include <string>
5
#include <unordered_map>
6
#include <vector>
7

8
namespace torch::jit {
9

10
// this flag is used to make sure the elements in the version map
11
// are sorted according to when the upgraders are introduced.
12
static bool isVersionMapSorted = false;
13

14
// Main entry point for all operators that have valid upgraders.
15
// Note for developers: The list of upgraders should be SORTED
16
// by the version number where the upgrader is registered.
17
static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
18
    {{"aten::logspace",
19
      {{9,
20
        "logspace_0_8",
21
        "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
22
     {"aten::logspace.out",
23
      {{9,
24
        "logspace_out_0_8",
25
        "aten::logspace.out(Scalar start, Scalar end, int? steps=None, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)"}}},
26
     {"aten::linspace",
27
      {{8,
28
        "linspace_0_7",
29
        "aten::linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
30
     {"aten::linspace.out",
31
      {{8,
32
        "linspace_out_0_7",
33
        "aten::linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!)"}}},
34
     {"aten::div.Tensor",
35
      {{4,
36
        "div_Tensor_0_3",
37
        "aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
38
     {"aten::div.Tensor_mode",
39
      {{4,
40
        "div_Tensor_mode_0_3",
41
        "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"}}},
42
     {"aten::div.Scalar",
43
      {{4,
44
        "div_Scalar_0_3",
45
        "aten::div.Scalar(Tensor self, Scalar other) -> Tensor"}}},
46
     {"aten::div.Scalar_mode",
47
      {{4,
48
        "div_Scalar_mode_0_3",
49
        "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor"}}},
50
     {"aten::div.out",
51
      {{4,
52
        "div_out_0_3",
53
        "aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"}}},
54
     {"aten::div.out_mode",
55
      {{4,
56
        "div_out_mode_0_3",
57
        "aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)"}}},
58
     {"aten::div_.Tensor",
59
      {{4,
60
        "div__Tensor_0_3",
61
        "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
62
     {"aten::div_.Tensor_mode",
63
      {{4,
64
        "div__Tensor_mode_0_3",
65
        "aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)"}}},
66
     {"aten::div_.Scalar",
67
      {{4,
68
        "div__Scalar_0_3",
69
        "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
70
     {"aten::div_.Scalar_mode",
71
      {{4,
72
        "div__Scalar_mode_0_3",
73
        "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)"}}},
74
     {"aten::full",
75
      {{5,
76
        "full_0_4",
77
        "aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
78
     {"aten::full.names",
79
      {{5,
80
        "full_names_0_4",
81
        "aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
82
     {"aten::full.out",
83
      {{5,
84
        "full_out_0_4",
85
        "aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}},
86
     {"aten::gelu", {{10, "gelu_0_9", "aten::gelu(Tensor self) -> Tensor"}}},
87
     {"aten::gelu.out",
88
      {{10,
89
        "gelu_out_0_9",
90
        "aten::gelu.out(Tensor self, *, Tensor(a!) out) -> Tensor"}}}});
91

92
const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
93
get_operator_version_map() {
94
  if (!isVersionMapSorted) {
95
    for (auto entry : operatorVersionMap) {
96
      std::sort(
97
          entry.second.begin(),
98
          entry.second.end(),
99
          [](const auto& a, const auto& b) {
100
            return a.bumped_at_version > b.bumped_at_version;
101
          });
102
    }
103
    isVersionMapSorted = true;
104
  }
105
  return operatorVersionMap;
106
}
107

108
void test_only_add_entry(const std::string& op_name, UpgraderEntry entry) {
109
  test_only_reset_flag();
110
  operatorVersionMap[op_name].emplace_back(std::move(entry));
111
}
112

113
void test_only_remove_entry(const std::string& op_name) {
114
  test_only_reset_flag();
115
  operatorVersionMap.erase(op_name);
116
}
117

118
void test_only_reset_flag() {
119
  isVersionMapSorted = false;
120
}
121

122
static bool calculatePackageVersionBasedOnUpgraders = false;
123

124
void calculate_package_version_based_on_upgraders(bool val) {
125
  calculatePackageVersionBasedOnUpgraders = val;
126
}
127

128
bool get_version_calculator_flag() {
129
  return calculatePackageVersionBasedOnUpgraders;
130
}
131

132
} // namespace torch::jit
133

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.