pytorch

Форк
0
86 строк · 2.2 Кб
1
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
2

3
#include <torch/csrc/jit/ir/ir.h>
4
#include <torch/csrc/jit/ir/irparser.h>
5
#include <mutex>
6
#include <string>
7
#include <unordered_map>
8

9
namespace torch::jit {
10

11
static UpgradersMap upgradersMap;
12

13
void UpgradersMap::set_content(
14
    std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
15
  // make sure we populate the map only once
16
  std::lock_guard<std::mutex> _(lock);
17
  if (isPopulated) {
18
    return;
19
  }
20

21
  content_ = std::move(content);
22
  isPopulated = true;
23
}
24

25
int UpgradersMap::count() {
26
  std::lock_guard<std::mutex> _(lock);
27
  return content_.size();
28
}
29

30
bool UpgradersMap::is_populated() {
31
  std::lock_guard<std::mutex> _(lock);
32
  return isPopulated;
33
}
34

35
const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
36
    get_content() {
37
  std::lock_guard<std::mutex> _(lock);
38
  return content_;
39
}
40

41
void UpgradersMap::test_only_set_content(
42
    const std::unordered_map<std::string, std::string>& content) {
43
  std::lock_guard<std::mutex> _(lock);
44
  for (const auto& entry : content) {
45
    auto graph = std::make_shared<Graph>();
46
    torch::jit::parseIR(entry.second, graph.get());
47
    content_.insert(std::make_pair(entry.first, graph));
48
  }
49
}
50
void UpgradersMap::test_only_remove_content(
51
    const std::unordered_map<std::string, std::string>& content) {
52
  std::lock_guard<std::mutex> _(lock);
53
  for (const auto& entry : content) {
54
    content_.erase(entry.first);
55
  }
56
}
57

58
void populate_upgraders_map(
59
    std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
60
  upgradersMap.set_content(std::move(content));
61
}
62

63
int get_upgraders_map_size() {
64
  return upgradersMap.count();
65
}
66

67
bool is_upgraders_map_populated() {
68
  return upgradersMap.is_populated();
69
}
70

71
const std::unordered_map<std::string, std::shared_ptr<Graph>>&
72
dump_upgraders_map() {
73
  return upgradersMap.get_content();
74
}
75

76
void test_only_populate_upgraders(
77
    const std::unordered_map<std::string, std::string>& content) {
78
  upgradersMap.test_only_set_content(content);
79
}
80

81
void test_only_remove_upgraders(
82
    const std::unordered_map<std::string, std::string>& content) {
83
  upgradersMap.test_only_remove_content(content);
84
}
85

86
} // namespace torch::jit
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.