pytorch

Форк
0
/
debug_info.cpp 
233 строки · 9.1 Кб
1
#include <torch/csrc/jit/frontend/source_range.h>
2
#include <torch/csrc/jit/mobile/debug_info.h>
3
#include <torch/csrc/jit/mobile/type_parser.h>
4
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
5
#include <torch/csrc/jit/serialization/source_range_serialization.h>
6

7
#include <ATen/core/ivalue.h>
8
#include <torch/csrc/jit/serialization/pickle.h>
9

10
#include <c10/util/string_view.h>
11

12
namespace torch {
13
namespace jit {
14

15
namespace {
16

17
// Builds the message returned when no debug info is registered for the
// given handle(s). `debug_handles_string` is embedded verbatim.
C10_ALWAYS_INLINE std::string debugHandlesNotFoundMessage(
    const std::string& debug_handles_string) {
  std::string message("Debug info for handle(s): ");
  message += debug_handles_string;
  message += ", was not found.";
  return message;
}
22

23
std::pair<std::vector<StackEntry>, std::string> getStackTraceWithModuleHierarchy(
24
    const DebugInfoTuple& source_callstack,
25
    const std::string& caller_name) {
26
  std::vector<StackEntry> entries;
27

28
  const SourceRange& range =
29
      std::get<kDebugInfoTupleSourceRangeIndex>(source_callstack);
30
  InlinedCallStackPtr callstack_ptr =
31
      std::get<kDebugInfoTupleInlinedCSIndex>(source_callstack);
32
  std::string prev_function_name = caller_name;
33
  std::string module_info;
34
  if (!callstack_ptr) {
35
    // If not cs then top level node
36
    entries.emplace_back(StackEntry{prev_function_name, range});
37
    return {std::move(entries), std::move(module_info)};
38
  } else {
39
    while (callstack_ptr) {
40
      const auto& opt_module_instance_info = callstack_ptr->module_instance();
41
      if (opt_module_instance_info.has_value()) {
42
        const auto& module_instance_info = opt_module_instance_info.value();
43
        // Sometimes (e.g., in lowered backends) we augment instance name with
44
        // type name instead of losing type name. In those cases instance_name
45
        // includes both instance name and type name. See
46
        // callstack_debug_info_serialization.cpp
47
        if (module_instance_info.class_type()) {
48
          module_info.append(".").append(
49
              utils::get_module_info(module_instance_info));
50
        } else {
51
          module_info.append(".").append(module_instance_info.instance_name());
52
        }
53
      } else {
54
        module_info.append(".UNKNOWN_INSTANCE(UNKNOWN_TYPE)");
55
      }
56
      // Now add source range info to stack
57
      entries.emplace_back(
58
          StackEntry{prev_function_name, callstack_ptr->source_range()});
59
      prev_function_name = callstack_ptr->function_name();
60
      // Function name appended here
61
      // It is renamed to prev_function_name because for StackEntry
62
      // it will be appended in the next iteration. This is the format
63
      // in which format_stack_trace expects function names.
64
      module_info.append("::").append(prev_function_name);
65

66
      if (callstack_ptr->callee()) {
67
        callstack_ptr = callstack_ptr->callee().value();
68
      } else {
69
        callstack_ptr = c10::intrusive_ptr<InlinedCallStack>();
70
      }
71
    }
72
    entries.emplace_back(StackEntry{prev_function_name, range});
73
    return {std::move(entries), std::move(module_info)};
74
  }
75
}
76

77
// This function construct stacktrace with module hierarchy
78
// Module hierarchy will contain information about where in the
79
// module hierarchy this source is. For example if conv2d op
80
// exist in hierarcy A->B->C->Conv2d with type annotations of
81
// A -> TopM, B->MyModule, C->SomeModule, then module hierarchy
82
// will be TopM(A).MyModule(B).SomeModule(C).Conv2d(conv)
83
// Source level stack information will be from model source code.
84
std::pair<std::string, std::string> getStackTraceWithModuleHierarchy(
85
    const std::vector<DebugInfoTuple>& source_callstacks,
86
    const std::string& root_scope_string,
87
    const std::string& top_module_type_name) {
88
  std::vector<StackEntry> stack_entries;
89
  std::string module_info =
90
      root_scope_string + "(" + top_module_type_name + ")";
91
  std::string caller_fn_name = "<unknown>";
92
  module_info.append("::").append(caller_fn_name);
93
  for (const auto& debug_info : source_callstacks) {
94
    auto debug_info_pair =
95
        getStackTraceWithModuleHierarchy(debug_info, caller_fn_name);
96
    auto entries = std::move(debug_info_pair.first);
97
    stack_entries.insert(stack_entries.end(), entries.begin(), entries.end());
98
    module_info.append(debug_info_pair.second);
99
  }
100
  // Only last entry in the callstack will have a node name of interest.
101
  // Rest are likely CallMethod/CallFunction nodes
102
  auto last_entry = source_callstacks.back();
103
  const std::string& node_name =
104
      std::get<kDebugInfoTupleNodeNameIndex>(last_entry);
105
  module_info.append(".").append(node_name);
106
  std::ostringstream ss;
107
  ss << "Module hierarchy:" << module_info << "\n";
108
  format_stack_trace(ss, stack_entries);
109
  return {ss.str(), std::move(module_info)};
110
}
111

112
} // namespace
113

114
// Populates the table from the records exposed by `reader`.
//
// Step 1: every record whose name ends in ".debug_pkl" is unpickled and its
//         (debug handle -> source range) entries are gathered.
// Step 2: if a "callstack_debug_map.pkl" record exists, it is unpickled with
//         the gathered source ranges and `cu`, and the result is stored in
//         callstack_ptr_map_.
MobileDebugTable::MobileDebugTable(
    std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader,
    const std::shared_ptr<CompilationUnit>& cu) {
  ska::flat_hash_map<int64_t, SourceRange> source_range_map;
  const c10::string_view debug_pkl_suffix(".debug_pkl");
  const std::vector<std::string>& all_records = reader->getAllRecords();

  for (const auto& name : all_records) {
    if (!c10::string_view(name).ends_with(debug_pkl_suffix)) {
      continue;
    }
    auto [raw_bytes, raw_size] = reader->getRecord(name);
    auto unpickled = jit::unpickle(
        reinterpret_cast<const char*>(raw_bytes.get()),
        raw_size,
        nullptr,
        {},
        c10::parseType);
    const auto& top_level = unpickled.toTuple()->elements();

    // New format: a 3-tuple whose first element is the
    // kFormatWithStringTable marker. Its second element is handed to the
    // deserializer and its third element holds the lines.
    const bool has_string_table = top_level.size() == 3 &&
        top_level[0].isString() &&
        kFormatWithStringTable == top_level[0].toStringRef();

    IValue lines;
    std::unique_ptr<SourceRangeDeserializer> deserializer;
    if (has_string_table) {
      deserializer = std::make_unique<SourceRangeDeserializer>(top_level[1]);
      lines = top_level[2];
    } else {
      // Otherwise the whole unpickled value is treated as the lines.
      deserializer = std::make_unique<SourceRangeDeserializer>();
      lines = unpickled;
    }

    for (auto& val : lines.toTuple()->elements()) {
      auto fields = std::move(*std::move(val).toTuple()).elements();
      // For backward compatibility only tuples with exactly 3 elements are
      // decoded, assuming they contain
      // byte_offset, debug_handle (=source range tag), source range
      if (fields.size() != 3) {
        continue;
      }
      int64_t debug_handle = fields[kSourceRangeTagIndex].toInt();
      auto source_range = deserializer->deserialize(fields[kSourceRangeIndex]);
      source_range_map.emplace(debug_handle, std::move(source_range));
    }
  }

  const std::string callstack_debug_file("callstack_debug_map.pkl");
  if (!reader->hasRecord(callstack_debug_file)) {
    return;
  }
  auto [callstack_data, callstack_data_size] =
      reader->getRecord(callstack_debug_file);
  CallStackDebugInfoUnpickler unpickler;
  callstack_ptr_map_ = unpickler.unpickle(
      std::move(callstack_data), callstack_data_size, source_range_map, cu);
}
165

166
std::string MobileDebugTable::getModuleHierarchyInfo(
167
    const int64_t debug_handle,
168
    const std::string& top_module_type_name) const {
169
  const auto it = callstack_ptr_map_.find(debug_handle);
170
  if (it == callstack_ptr_map_.end()) {
171
    return debugHandlesNotFoundMessage(std::to_string(debug_handle));
172
  }
173
  return (getStackTraceWithModuleHierarchy(
174
              {it->second}, "top", top_module_type_name))
175
      .second;
176
}
177

178
std::string MobileDebugTable::getModuleHierarchyInfo(
179
    const std::vector<int64_t>& debug_handles,
180
    const std::string& top_module_type_name) const {
181
  return getSourceDebugModuleHierarchyInfo(debug_handles, top_module_type_name)
182
      .second;
183
}
184

185
std::string MobileDebugTable::getSourceDebugString(
186
    const int64_t debug_handle,
187
    const std::string& top_module_type_name) const {
188
  const auto it = callstack_ptr_map_.find(debug_handle);
189
  if (it == callstack_ptr_map_.end()) {
190
    return debugHandlesNotFoundMessage(std::to_string(debug_handle));
191
  }
192
  return (getStackTraceWithModuleHierarchy(
193
              {it->second}, "top", top_module_type_name))
194
      .first;
195
}
196

197
std::string MobileDebugTable::getSourceDebugString(
198
    const std::vector<int64_t>& debug_handles,
199
    const std::string& top_module_type_name) const {
200
  return getSourceDebugModuleHierarchyInfo(debug_handles, top_module_type_name)
201
      .first;
202
}
203

204
std::pair<std::string, std::string> MobileDebugTable::
205
    getSourceDebugModuleHierarchyInfo(
206
        const std::vector<int64_t>& debug_handles,
207
        const std::string& top_module_type_name) const {
208
  std::vector<DebugInfoTuple> debug_infos;
209
  bool debug_handle_not_found{false};
210
  for (auto it = debug_handles.rbegin(); it != debug_handles.rend(); ++it) {
211
    auto debug_handle = *it;
212
    const auto cs_it = callstack_ptr_map_.find(debug_handle);
213
    if (cs_it == callstack_ptr_map_.end()) {
214
      debug_handle_not_found = true;
215
      break;
216
    }
217
    debug_infos.emplace_back(cs_it->second);
218
  }
219
  if (debug_handle_not_found) {
220
    std::string debug_handles_string = "debug_handles:{";
221
    for (const auto debug_handle : debug_handles) {
222
      debug_handles_string += std::to_string(debug_handle);
223
    }
224
    debug_handles_string += "}";
225
    debug_handles_string = debugHandlesNotFoundMessage(debug_handles_string);
226
    return {debug_handles_string, debug_handles_string};
227
  }
228
  return (getStackTraceWithModuleHierarchy(
229
      debug_infos, "top", top_module_type_name));
230
}
231

232
} // namespace jit
233
} // namespace torch
234

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.