#include <torch/csrc/jit/operator_upgraders/version_map.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

namespace torch::jit {

// This flag tracks whether the entries in the version map have been
// sorted by the version at which each upgrader was introduced.
static bool isVersionMapSorted = false;

// Main entry point for all operators that have valid upgraders.
// Note for developers: The list of upgraders should be SORTED
// by the version number where the upgrader is registered.
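// Each entry below is {version at which the schema was bumped, upgrader name,
// old operator schema string}.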
static std::unordered_map<std::string, std::vector<UpgraderEntry>> operatorVersionMap(
    {{"aten::logspace",
      {{9,
        "logspace_0_8",
        "aten::logspace(Scalar start, Scalar end, int? steps=None, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
     {"aten::logspace.out",
      {{9,
        "logspace_out_0_8",
        "aten::logspace.out(Scalar start, Scalar end, int? steps=None, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::linspace",
      {{8,
        "linspace_0_7",
        "aten::linspace(Scalar start, Scalar end, int? steps=None, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
     {"aten::linspace.out",
      {{8,
        "linspace_out_0_7",
        "aten::linspace.out(Scalar start, Scalar end, int? steps=None, *, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::div.Tensor",
      {{4,
        "div_Tensor_0_3",
        "aten::div.Tensor(Tensor self, Tensor other) -> Tensor"}}},
     {"aten::div.Tensor_mode",
      {{4,
        "div_Tensor_mode_0_3",
        "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor"}}},
     {"aten::div.Scalar",
      {{4,
        "div_Scalar_0_3",
        "aten::div.Scalar(Tensor self, Scalar other) -> Tensor"}}},
     {"aten::div.Scalar_mode",
      {{4,
        "div_Scalar_mode_0_3",
        "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor"}}},
     {"aten::div.out",
      {{4,
        "div_out_0_3",
        "aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::div.out_mode",
      {{4,
        "div_out_mode_0_3",
        "aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::div_.Tensor",
      {{4,
        "div__Tensor_0_3",
        "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"}}},
     {"aten::div_.Tensor_mode",
      {{4,
        "div__Tensor_mode_0_3",
        "aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)"}}},
     {"aten::div_.Scalar",
      {{4,
        "div__Scalar_0_3",
        "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"}}},
     {"aten::div_.Scalar_mode",
      {{4,
        "div__Scalar_mode_0_3",
        "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)"}}},
     {"aten::full",
      {{5,
        "full_0_4",
        "aten::full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
     {"aten::full.names",
      {{5,
        "full_names_0_4",
        "aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"}}},
     {"aten::full.out",
      {{5,
        "full_out_0_4",
        "aten::full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)"}}},
     {"aten::gelu", {{10, "gelu_0_9", "aten::gelu(Tensor self) -> Tensor"}}},
     {"aten::gelu.out",
      {{10,
        "gelu_out_0_9",
        "aten::gelu.out(Tensor self, *, Tensor(a!) out) -> Tensor"}}}});
const std::unordered_map<std::string, std::vector<UpgraderEntry>>&
get_operator_version_map() {
  if (!isVersionMapSorted) {
    // Iterate by reference so the sort reorders the stored vectors rather
    // than temporary copies.
    for (auto& entry : operatorVersionMap) {
      std::sort(
          entry.second.begin(),
          entry.second.end(),
          [](const auto& a, const auto& b) {
            return a.bumped_at_version > b.bumped_at_version;
          });
    }
    isVersionMapSorted = true;
  }
  return operatorVersionMap;
}
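
// Illustrative lookup sketch (not part of this file's interface): callers can
// query the map for an operator and walk its entries, for example:
//
//   const auto& version_map = get_operator_version_map();
//   auto it = version_map.find("aten::div.Tensor");
//   if (it != version_map.end()) {
//     for (const auto& upgrader : it->second) {
//       // upgrader.bumped_at_version is 4 for the aten::div entries above.
//     }
//   }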

void test_only_add_entry(const std::string& op_name, UpgraderEntry entry) {
  test_only_reset_flag();
  operatorVersionMap[op_name].emplace_back(std::move(entry));
}

void test_only_remove_entry(const std::string& op_name) {
  test_only_reset_flag();
  operatorVersionMap.erase(op_name);
}

void test_only_reset_flag() {
  isVersionMapSorted = false;
}
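
// Flag controlling whether the package version should be calculated based on
// the registered upgraders; set and read via the two functions below.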
static bool calculatePackageVersionBasedOnUpgraders = false;

void calculate_package_version_based_on_upgraders(bool val) {
  calculatePackageVersionBasedOnUpgraders = val;
}

bool get_version_calculator_flag() {
  return calculatePackageVersionBasedOnUpgraders;
}

} // namespace torch::jit