pytorch

Форк
0
/
versioned_symbols.cpp 
108 строк · 3.9 Кб
1
#include <torch/csrc/jit/frontend/versioned_symbols.h>
2

3
#include <caffe2/serialize/versions.h>
4
#include <torch/csrc/api/include/torch/jit.h>
5

6
#include <unordered_map>
7

8
namespace torch::jit {
9
// Note [Versioned Symbols]
10
// When the schema or behavior of a symbol changes, serialized Torchscript
11
// programs using that symbol are likely to break. To prevent those breaks,
12
// the symbol's historic behavior can be implemented as a Torchscript builtin
13
// and when an older Torchscript program is loaded the program's uses of the
14
// symbol can be replaced with the builtin.
15
//
16
// For example, a function _test_serialization_subcmul(a, b, alpha) might have
17
// been improperly implemented as (b - alpha * a).
18
// Some users may have written and serialized programs using that function,
19
// however, and fixing it to perform (a - alpha * b) would break their programs.
20
// Using the "Versioned Symbol" pattern lets you replace
21
// _test_serialization_subcmul in older programs with a builtin
22
// _test_serialization_subcmul<version_range> that implements the historic
23
// behavior. That way old programs preserve their semantics while new programs
24
// can take advantage of the fix.
25
//
26
// To do this:
27
//
28
// 1) Identify the file version range where the symbol should be replaced,
29
//    e.g. versions 0 to 2, inclusive.
30
// 2) Create one or more builtins implementing the symbol's historic behavior.
31
//    These should be named <function>_<start_version>_<end_version> and
32
//    go into the "upgraders" namespace.
33
//    For example, the test-only aten::_test_serialization_subcmul has a builtin
34
//    for its "historic" behavior called
35
//    upgraders::_test_serialization_subcmul_0_2.
36
// 3) Add a mapping from the symbol to the corresponding SymbolRange
37
//    in the symbol_range_map (below).
38
//
39
// To test your versioning:
40
//
41
// 1) Serialize a module demonstrating the historic behavior.
42
// 2) Save it to test/jit/fixtures.
43
// 3) Implement your new behavior and bump the version counter.
44
// 4) Write the builtins and extend the symbol_range_map per the above
45
//    instructions.
46
// 5) Create a test in jit/test_save_load.py that loads the old module
47
//    and verifies it exhibits the historic behavior, then saves and
48
//    loads the same module and verifies it exhibits the current behavior.
49
//    See test_versioned_symbols for an example.
50

51
// Helper to hold the version range (inclusive on both ends) and the symbol
52
// to map to for that range.
53
struct SymbolRange {
54
  SymbolRange(
55
      const uint64_t _start_version,
56
      const uint64_t _end_version,
57
      const Symbol _sym)
58
      : start_version_{_start_version},
59
        end_version_{_end_version},
60
        sym_{_sym} {}
61
  const uint64_t start_version_;
62
  const uint64_t end_version_;
63
  const Symbol sym_;
64
};
65

66
static std::unordered_map<Symbol, SymbolRange> symbol_range_map({
67
    {Symbol::fromQualString("aten::_test_serialization_subcmul"),
68
     {0,
69
      2,
70
      Symbol::fromQualString("upgraders::_test_serialization_subcmul_0_2")}},
71
    {Symbol::fromQualString("aten::div"),
72
     {0, 3, Symbol::fromQualString("upgraders::div_0_3")}},
73
    {Symbol::fromQualString("aten::div_"),
74
     {0, 3, Symbol::fromQualString("upgraders::div__0_3")}},
75
    {Symbol::fromQualString("aten::full"),
76
     {0, 4, Symbol::fromQualString("upgraders::full_0_4")}},
77
});
78

79
static std::unordered_map<NodeKind, uint64_t> kind_min_version_map({
80
    {aten::div, 4},
81
    {aten::div_, 4},
82
    {aten::full, 5}, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
83
});
84

85
Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
86
  auto it = symbol_range_map.find(name);
87
  if (it == symbol_range_map.end()) {
88
    return name;
89
  }
90

91
  auto& entry = it->second;
92
  if (entry.start_version_ <= version && entry.end_version_ >= version) {
93
    return entry.sym_;
94
  }
95

96
  return name;
97
}
98

99
uint64_t get_min_version_for_kind(const NodeKind& kind) {
100
  auto it = kind_min_version_map.find(kind);
101
  if (it == kind_min_version_map.end()) {
102
    return 0;
103
  }
104

105
  return it->second;
106
}
107

108
} // namespace torch::jit
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.