pytorch
86 строк · 2.2 Кб
1#include <torch/csrc/jit/operator_upgraders/upgraders.h>
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/ir/irparser.h>
5#include <mutex>
6#include <string>
7#include <unordered_map>
8
9namespace torch::jit {
10
11static UpgradersMap upgradersMap;
12
13void UpgradersMap::set_content(
14std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
15// make sure we populate the map only once
16std::lock_guard<std::mutex> _(lock);
17if (isPopulated) {
18return;
19}
20
21content_ = std::move(content);
22isPopulated = true;
23}
24
25int UpgradersMap::count() {
26std::lock_guard<std::mutex> _(lock);
27return content_.size();
28}
29
30bool UpgradersMap::is_populated() {
31std::lock_guard<std::mutex> _(lock);
32return isPopulated;
33}
34
35const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
36get_content() {
37std::lock_guard<std::mutex> _(lock);
38return content_;
39}
40
41void UpgradersMap::test_only_set_content(
42const std::unordered_map<std::string, std::string>& content) {
43std::lock_guard<std::mutex> _(lock);
44for (const auto& entry : content) {
45auto graph = std::make_shared<Graph>();
46torch::jit::parseIR(entry.second, graph.get());
47content_.insert(std::make_pair(entry.first, graph));
48}
49}
50void UpgradersMap::test_only_remove_content(
51const std::unordered_map<std::string, std::string>& content) {
52std::lock_guard<std::mutex> _(lock);
53for (const auto& entry : content) {
54content_.erase(entry.first);
55}
56}
57
58void populate_upgraders_map(
59std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
60upgradersMap.set_content(std::move(content));
61}
62
63int get_upgraders_map_size() {
64return upgradersMap.count();
65}
66
67bool is_upgraders_map_populated() {
68return upgradersMap.is_populated();
69}
70
71const std::unordered_map<std::string, std::shared_ptr<Graph>>&
72dump_upgraders_map() {
73return upgradersMap.get_content();
74}
75
76void test_only_populate_upgraders(
77const std::unordered_map<std::string, std::string>& content) {
78upgradersMap.test_only_set_content(content);
79}
80
81void test_only_remove_upgraders(
82const std::unordered_map<std::string, std::string>& content) {
83upgradersMap.test_only_remove_content(content);
84}
85
86} // namespace torch::jit
87