// torch/csrc/jit/operator_upgraders/upgraders_entry.cpp
#include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>

#include <ATen/core/stack.h>
#include <c10/macros/Export.h>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>

#include <string>
#include <unordered_map>

namespace torch::jit {

// Maps an upgrader name to the TorchScript source that implements it.
//
// Each entry preserves the old behavior of an operator whose semantics
// changed at some operator version; the trailing numbers in each name
// presumably encode the operator-version range the upgrader applies to
// (e.g. `_0_8`) — confirm against the operator_upgraders README. These
// sources are compiled into graphs by create_upgrader_graph(), so every
// snippet must be valid TorchScript — in particular, `if` bodies must be
// indented. (The previous text of this table had viewer line numbers
// fused into the scripts and no indentation, which TorchScript cannot
// parse.)
static std::unordered_map<std::string, std::string> kUpgradersEntryMap({
    {"logspace_0_8", R"SCRIPT(
def logspace_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, dtype: Optional[int], layout: Optional[int],
                 device: Optional[Device], pin_memory: Optional[bool]):
  if (steps is None):
    return torch.logspace(start=start, end=end, steps=100, base=base, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
  return torch.logspace(start=start, end=end, steps=steps, base=base, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
)SCRIPT"},
    {"logspace_out_0_8", R"SCRIPT(
def logspace_out_0_8(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], base: float, *, out: Tensor):
  if (steps is None):
    return torch.logspace(start=start, end=end, steps=100, base=base, out=out)
  return torch.logspace(start=start, end=end, steps=steps, base=base, out=out)
)SCRIPT"},
    {"linspace_0_7", R"SCRIPT(
def linspace_0_7(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], *, dtype: Optional[int], layout: Optional[int],
                 device: Optional[Device], pin_memory: Optional[bool]):
  if (steps is None):
    return torch.linspace(start=start, end=end, steps=100, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
  return torch.linspace(start=start, end=end, steps=steps, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
)SCRIPT"},
    {"linspace_out_0_7", R"SCRIPT(
def linspace_out_0_7(start: Union[int, float, complex], end: Union[int, float, complex], steps: Optional[int], *, out: Tensor):
  if (steps is None):
    return torch.linspace(start=start, end=end, steps=100, out=out)
  return torch.linspace(start=start, end=end, steps=steps, out=out)
)SCRIPT"},
    {"div_Tensor_0_3", R"SCRIPT(
def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
  if (self.is_floating_point() or other.is_floating_point()):
    return self.true_divide(other)
  return self.divide(other, rounding_mode='trunc')
)SCRIPT"},
    {"div_Tensor_mode_0_3", R"SCRIPT(
def div_Tensor_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None) -> Tensor:
  return self.divide(other, rounding_mode=rounding_mode)
)SCRIPT"},
    {"div_Scalar_0_3", R"SCRIPT(
def div_Scalar_0_3(self: Tensor, other: number) -> Tensor:
  if (self.is_floating_point() or isinstance(other, float)):
    return self.true_divide(other)
  return self.divide(other, rounding_mode='trunc')
)SCRIPT"},
    {"div_Scalar_mode_0_3", R"SCRIPT(
def div_Scalar_mode_0_3(self: Tensor, other: number, *, rounding_mode: Optional[str]=None) -> Tensor:
  return self.divide(other, rounding_mode=rounding_mode)
)SCRIPT"},
    {"div_out_0_3", R"SCRIPT(
def div_out_0_3(self: Tensor, other: Tensor, *, out: Tensor) -> Tensor:
  if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
    return self.true_divide(other, out=out)
  return self.divide(other, rounding_mode='trunc', out=out)
)SCRIPT"},
    {"div_out_mode_0_3", R"SCRIPT(
def div_out_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None, out: Tensor) -> Tensor:
  return self.divide(other, rounding_mode=rounding_mode, out=out)
)SCRIPT"},
    {"div__Tensor_0_3", R"SCRIPT(
def div__Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
  if (self.is_floating_point() or other.is_floating_point()):
    return self.true_divide_(other)
  return self.divide_(other, rounding_mode='trunc')
)SCRIPT"},
    {"div__Tensor_mode_0_3", R"SCRIPT(
def div__Tensor_mode_0_3(self: Tensor, other: Tensor, *, rounding_mode: Optional[str]=None) -> Tensor:
  return self.divide_(other, rounding_mode=rounding_mode)
)SCRIPT"},
    {"div__Scalar_0_3", R"SCRIPT(
def div__Scalar_0_3(self: Tensor, other: number) -> Tensor:
  if (self.is_floating_point() or isinstance(other, float)):
    return self.true_divide_(other)
  return self.divide_(other, rounding_mode='trunc')
)SCRIPT"},
    {"div__Scalar_mode_0_3", R"SCRIPT(
def div__Scalar_mode_0_3(self: Tensor, other: number, *, rounding_mode: Optional[str]=None) -> Tensor:
  return self.divide_(other, rounding_mode=rounding_mode)
)SCRIPT"},
    {"full_names_0_4", R"SCRIPT(
def full_names_0_4(size:List[int], fill_value:number, *, names:Optional[List[str]]=None,
                   dtype:Optional[int]=None, layout:Optional[int]=None, device:Optional[Device]=None,
                   pin_memory:Optional[bool]=None) -> Tensor:
  return torch.full(size, fill_value, names=names, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
)SCRIPT"},
    {"full_0_4", R"SCRIPT(
def full_0_4(size:List[int], fill_value:number, *, dtype:Optional[int]=None,
             layout:Optional[int]=None, device:Optional[Device]=None,
             pin_memory:Optional[bool]=None) -> Tensor:
  if dtype is None:
    fill_value = float(fill_value)
  return torch.full(size, fill_value, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
)SCRIPT"},
    {"full_out_0_4", R"SCRIPT(
def full_out_0_4(size:List[int], fill_value:number, *, out:Tensor) -> Tensor:
  return torch.full(size, fill_value, out=out)
)SCRIPT"},
    {"gelu_0_9", R"SCRIPT(
def gelu_0_9(self: Tensor) -> Tensor:
  return torch.gelu(self, approximate='none')
)SCRIPT"},
    {"gelu_out_0_9", R"SCRIPT(
def gelu_out_0_9(self: Tensor, *, out: Tensor) -> Tensor:
  return torch.gelu(self, approximate='none', out=out)
)SCRIPT"},
});

121std::shared_ptr<Graph> create_upgrader_graph(
122const std::string& upgrader_name,
123const std::string& upgrader_body) {
124auto cu = std::make_shared<CompilationUnit>();
125cu->define(c10::nullopt, upgrader_body, nativeResolver(), nullptr);
126Function& jitFunc = cu->get_function(upgrader_name);
127GraphFunction& graphFunction = toGraphFunction(jitFunc);
128return graphFunction.graph();
129}
130
131std::unordered_map<std::string, std::shared_ptr<Graph>>
132generate_upgraders_graph() {
133std::unordered_map<std::string, std::shared_ptr<Graph>> populate_content;
134for (const auto& entry : kUpgradersEntryMap) {
135auto upgrader_graph = create_upgrader_graph(entry.first, entry.second);
136populate_content.insert(std::make_pair(entry.first, upgrader_graph));
137}
138return populate_content;
139}
140
141void populate_upgraders_graph_map() {
142if (!is_upgraders_map_populated()) {
143auto graphs = generate_upgraders_graph();
144populate_upgraders_map(std::move(graphs));
145}
146}
147
148std::unordered_map<std::string, std::string> get_upgraders_entry_map() {
149return kUpgradersEntryMap;
150}

} // namespace torch::jit