pytorch

Форк
0
/
jit_log.cpp 
189 строк · 5.1 Кб
1
#include <cstdlib>
2
#include <iomanip>
3
#include <iostream>
4
#include <sstream>
5
#include <string>
6
#include <unordered_map>
7
#include <vector>
8

9
#include <ATen/core/function.h>
10
#include <c10/util/Exception.h>
11
#include <c10/util/StringUtil.h>
12
#include <torch/csrc/jit/api/function_impl.h>
13
#include <torch/csrc/jit/frontend/error_report.h>
14
#include <torch/csrc/jit/ir/ir.h>
15
#include <torch/csrc/jit/jit_log.h>
16
#include <torch/csrc/jit/serialization/python_print.h>
17

18
namespace torch {
19
namespace jit {
20

21
// Process-wide JIT logging configuration (singleton).
//
// Holds the raw logging-level specification (initialized from the
// PYTORCH_JIT_LOG_LEVEL environment variable), the file-name -> level table
// parsed from it, and the stream that log output is written to (std::cerr
// unless redirected).
//
// NOTE(review): none of the members below are synchronized; callers are
// presumably expected not to reconfigure logging concurrently with logging
// sites reading it — confirm.
class JitLoggingConfig {
 public:
  // Returns the singleton instance, constructing it on first use.
  static JitLoggingConfig& getInstance() {
    static JitLoggingConfig instance;
    return instance;
  }
  JitLoggingConfig(JitLoggingConfig const&) = delete;
  void operator=(JitLoggingConfig const&) = delete;

 private:
  // Raw ':'-separated level specification as supplied by the user.
  std::string logging_levels;
  // Parsed table: file name (extension stripped) -> logging level.
  std::unordered_map<std::string, size_t> files_to_levels;
  // Destination for log output. Never null; not owned by this object.
  std::ostream* out{&std::cerr};

  JitLoggingConfig() {
    const char* jit_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
    logging_levels.assign(jit_log_level == nullptr ? "" : jit_log_level);
    parse();
  }
  // Rebuilds `files_to_levels` from `logging_levels`.
  void parse();

 public:
  // Returns a copy of the raw level specification.
  std::string getLoggingLevels() const {
    return this->logging_levels;
  }
  // Replaces the level specification and re-parses it immediately.
  void setLoggingLevels(std::string levels) {
    this->logging_levels = std::move(levels);
    parse();
  }

  // Read-only view of the parsed file-name -> level table.
  const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
    return this->files_to_levels;
  }

  // Stores the address of `out_stream`; it must outlive any logging to it.
  void setOutputStream(std::ostream& out_stream) {
    this->out = &out_stream;
  }

  // Returns the stream log output is currently written to.
  std::ostream& getOutputStream() {
    return *(this->out);
  }
};
65

66
std::string get_jit_logging_levels() {
67
  return JitLoggingConfig::getInstance().getLoggingLevels();
68
}
69

70
void set_jit_logging_levels(std::string level) {
71
  JitLoggingConfig::getInstance().setLoggingLevels(std::move(level));
72
}
73

74
void set_jit_logging_output_stream(std::ostream& stream) {
75
  JitLoggingConfig::getInstance().setOutputStream(stream);
76
}
77

78
std::ostream& get_jit_logging_output_stream() {
79
  return JitLoggingConfig::getInstance().getOutputStream();
80
}
81

82
// gets a string representation of a node header
83
// (e.g. outputs, a node kind and outputs)
84
std::string getHeader(const Node* node) {
85
  std::stringstream ss;
86
  node->print(ss, 0, {}, false, false, false, false);
87
  return ss.str();
88
}
89

90
void JitLoggingConfig::parse() {
91
  std::stringstream in_ss;
92
  in_ss << "function:" << this->logging_levels;
93

94
  files_to_levels.clear();
95
  std::string line;
96
  while (std::getline(in_ss, line, ':')) {
97
    if (line.empty()) {
98
      continue;
99
    }
100

101
    auto index_at = line.find_last_of('>');
102
    auto begin_index = index_at == std::string::npos ? 0 : index_at + 1;
103
    size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1;
104
    auto end_index = line.find_last_of('.') == std::string::npos
105
        ? line.size()
106
        : line.find_last_of('.');
107
    auto filename = line.substr(begin_index, end_index - begin_index);
108
    files_to_levels.insert({filename, logging_level});
109
  }
110
}
111

112
bool is_enabled(const char* cfname, JitLoggingLevels level) {
113
  const auto& files_to_levels =
114
      JitLoggingConfig::getInstance().getFilesToLevels();
115
  std::string fname{cfname};
116
  fname = c10::detail::StripBasename(fname);
117
  const auto end_index = fname.find_last_of('.') == std::string::npos
118
      ? fname.size()
119
      : fname.find_last_of('.');
120
  const auto fname_no_ext = fname.substr(0, end_index);
121

122
  const auto it = files_to_levels.find(fname_no_ext);
123
  if (it == files_to_levels.end()) {
124
    return false;
125
  }
126

127
  return level <= static_cast<JitLoggingLevels>(it->second);
128
}
129

130
// Unfortunately, in `GraphExecutor` where `log_function` is invoked
131
// we won't have access to an original function, so we have to construct
132
// a dummy function to give to PythonPrint
133
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
134
  torch::jit::GraphFunction func("source_dump", graph, nullptr);
135
  std::vector<at::IValue> constants;
136
  PrintDepsTable deps;
137
  PythonPrint pp(constants, deps);
138
  pp.printFunction(func);
139
  return pp.str();
140
}
141

142
// Prepends `prefix` to every line of `in_str`.
//
// Each output line ends with '\n', including the last one even when `in_str`
// has no trailing newline. An empty `in_str` yields an empty string.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream in_ss(in_str);
  std::ostringstream out_ss;
  std::string line;
  while (std::getline(in_ss, line)) {
    // '\n' rather than std::endl: flushing a string stream on every line is
    // wasted work and does not change the result.
    out_ss << prefix << line << '\n';
  }

  return out_ss.str();
}
154

155
std::string jit_log_prefix(
156
    JitLoggingLevels level,
157
    const char* fn,
158
    int l,
159
    const std::string& in_str) {
160
  std::stringstream prefix_ss;
161
  prefix_ss << "[";
162
  prefix_ss << level << " ";
163
  prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":";
164
  prefix_ss << std::setfill('0') << std::setw(3) << l;
165
  prefix_ss << "] ";
166

167
  return jit_log_prefix(prefix_ss.str(), in_str);
168
}
169

170
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
171
  switch (level) {
172
    case JitLoggingLevels::GRAPH_DUMP:
173
      out << "DUMP";
174
      break;
175
    case JitLoggingLevels::GRAPH_UPDATE:
176
      out << "UPDATE";
177
      break;
178
    case JitLoggingLevels::GRAPH_DEBUG:
179
      out << "DEBUG";
180
      break;
181
    default:
182
      TORCH_INTERNAL_ASSERT(false, "Invalid level");
183
  }
184

185
  return out;
186
}
187

188
} // namespace jit
189
} // namespace torch
190

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.