1
#include <torch/csrc/jit/frontend/versioned_symbols.h>
3
#include <caffe2/serialize/versions.h>
4
#include <torch/csrc/api/include/torch/jit.h>
6
#include <unordered_map>
9
// Note [Versioned Symbols]
10
// When the schema or behavior of a symbol changes, serialized TorchScript
11
// programs using that symbol are likely to break. To prevent those breaks,
12
// the symbol's historic behavior can be implemented as a TorchScript builtin
13
// and when an older TorchScript program is loaded, the program's uses of the
14
// symbol can be replaced with the builtin.
16
// For example, a function _test_serialization_subcmul(a, b, alpha) might have
17
// been improperly implemented as (b - alpha * a).
18
// Some users may have written and serialized programs using that function,
19
// however, and fixing it to perform (a - alpha * b) would break their programs.
20
// Using the "Versioned Symbol" pattern lets you replace
21
// _test_serialization_subcmul in older programs with a builtin
22
// _test_serialization_subcmul<version_range> that implements the historic
23
// behavior. That way old programs preserve their semantics while new programs
24
// can take advantage of the fix.
28
// To version a symbol:
// 1) Identify the file version range where the symbol should be replaced,
29
// e.g. versions 0 to 2, inclusive.
30
// 2) Create one or more builtins implementing the symbol's historic behavior.
31
// These should be named <function>_<start_version>_<end_version> and
32
// go into the "upgraders" namespace.
33
// For example, the test-only aten::_test_serialization_subcmul has a builtin
34
// for its "historic" behavior called
35
// upgraders::_test_serialization_subcmul_0_2.
36
// 3) Add a mapping from the symbol to the corresponding SymbolRange
37
// in the symbol_range_map (below).
39
// To test your versioning:
41
// 1) Serialize a module demonstrating the historic behavior.
42
// 2) Save it to test/jit/fixtures.
43
// 3) Implement your new behavior and bump the version counter.
44
// 4) Write the builtins and extend the symbol_range_map per the steps above.
46
// 5) Create a test in jit/test_save_load.py that loads the old module
47
// and verifies it exhibits the historic behavior, then saves and
48
// loads the same module and verifies it exhibits the current behavior.
49
// See test_versioned_symbols for an example.
51
// Helper to hold the version range (inclusive on both ends) and the symbol
52
// to map to for that range.
55
const uint64_t _start_version,
56
const uint64_t _end_version,
58
: start_version_{_start_version},
59
end_version_{_end_version},
61
const uint64_t start_version_;
62
const uint64_t end_version_;
66
static std::unordered_map<Symbol, SymbolRange> symbol_range_map({
67
{Symbol::fromQualString("aten::_test_serialization_subcmul"),
70
Symbol::fromQualString("upgraders::_test_serialization_subcmul_0_2")}},
71
{Symbol::fromQualString("aten::div"),
72
{0, 3, Symbol::fromQualString("upgraders::div_0_3")}},
73
{Symbol::fromQualString("aten::div_"),
74
{0, 3, Symbol::fromQualString("upgraders::div__0_3")}},
75
{Symbol::fromQualString("aten::full"),
76
{0, 4, Symbol::fromQualString("upgraders::full_0_4")}},
79
static std::unordered_map<NodeKind, uint64_t> kind_min_version_map({
82
{aten::full, 5}, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
85
Symbol get_symbol_for_version(const Symbol name, const uint64_t version) {
86
auto it = symbol_range_map.find(name);
87
if (it == symbol_range_map.end()) {
91
auto& entry = it->second;
92
if (entry.start_version_ <= version && entry.end_version_ >= version) {
99
uint64_t get_min_version_for_kind(const NodeKind& kind) {
100
auto it = kind_min_version_map.find(kind);
101
if (it == kind_min_version_map.end()) {
108
} // namespace torch::jit