6
#include <unordered_map>
9
#include <ATen/core/function.h>
10
#include <c10/util/Exception.h>
11
#include <c10/util/StringUtil.h>
12
#include <torch/csrc/jit/api/function_impl.h>
13
#include <torch/csrc/jit/frontend/error_report.h>
14
#include <torch/csrc/jit/ir/ir.h>
15
#include <torch/csrc/jit/jit_log.h>
16
#include <torch/csrc/jit/serialization/python_print.h>
// Process-wide JIT logging configuration. It holds the raw level
// specification (`logging_levels`), the filename -> level map parsed from it
// (`files_to_levels`), and a pointer to the output stream (`out`).
// NOTE(review): this extract is missing several lines of the class (access
// specifiers, the constructor header, the `out` member declaration, return
// statements and closing braces) -- consult the full file before editing.
21
class JitLoggingConfig {
// Accessor for the shared instance, which lives in the function-local static
// below and is therefore constructed on first use.
23
static JitLoggingConfig& getInstance() {
24
static JitLoggingConfig instance;
// Copy construction and copy assignment are deleted so the shared instance
// cannot be duplicated.
27
JitLoggingConfig(JitLoggingConfig const&) = delete;
28
void operator=(JitLoggingConfig const&) = delete;
// Raw level specification (':'-separated entries; see parse()).
31
std::string logging_levels;
// Filename key (text after the last '>' and before the last '.') -> logging
// level, rebuilt by parse().
32
std::unordered_map<std::string, size_t> files_to_levels;
// The statements below (presumably the constructor body; its header is not
// visible here) take the initial specification from PYTORCH_JIT_LOG_LEVEL.
// std::getenv returns nullptr when the variable is unset, which is treated
// as an empty specification.
35
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
37
const char* jit_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
38
logging_levels.assign(jit_log_level == nullptr ? "" : jit_log_level);
// Returns a copy of the raw level specification.
45
std::string getLoggingLevels() const {
46
return this->logging_levels;
// Replaces the raw level specification.
// NOTE(review): files_to_levels is only rebuilt by parse(); confirm parse()
// is called after this assignment (not visible in this extract).
48
void setLoggingLevels(std::string levels) {
49
this->logging_levels = std::move(levels);
// Read-only access to the parsed filename -> level map.
53
const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
54
return this->files_to_levels;
// Stores the address of `out_stream`, so the caller must keep the stream
// alive for as long as it remains the configured output.
57
void setOutputStream(std::ostream& out_stream) {
58
this->out = &out_stream;
// Returns the configured output stream (body not visible in this extract).
61
std::ostream& getOutputStream() {
66
std::string get_jit_logging_levels() {
67
return JitLoggingConfig::getInstance().getLoggingLevels();
70
void set_jit_logging_levels(std::string level) {
71
JitLoggingConfig::getInstance().setLoggingLevels(std::move(level));
74
void set_jit_logging_output_stream(std::ostream& stream) {
75
JitLoggingConfig::getInstance().setOutputStream(stream);
78
std::ostream& get_jit_logging_output_stream() {
79
return JitLoggingConfig::getInstance().getOutputStream();
82
// gets a string representation of a node header
83
// (e.g. outputs, a node kind and inputs -- TODO confirm against Node::print)
84
std::string getHeader(const Node* node) {
// NOTE(review): `ss` and the return statement are not visible in this
// extract.
86
node->print(ss, 0, {}, false, false, false, false);
// Rebuilds files_to_levels from logging_levels.
//
// The specification is a ':'-separated list of entries. In each entry the
// position of the last '>' sets both the level (its index + 1, i.e. the
// number of '>' characters when they all lead the entry, as in ">>file") and
// where the filename starts. When the entry contains a '.', the filename
// ends at the last one, so an extension such as ".cpp" is dropped.
90
void JitLoggingConfig::parse() {
91
std::stringstream in_ss;
// A "function" entry is always placed in front of the user specification;
// having no '>', it parses to level 0.
// NOTE(review): the reason for this implicit entry is not evident here.
92
in_ss << "function:" << this->logging_levels;
94
files_to_levels.clear();
// NOTE(review): `line` is not declared in the visible lines.
96
while (std::getline(in_ss, line, ':')) {
101
auto index_at = line.find_last_of('>');
102
auto begin_index = index_at == std::string::npos ? 0 : index_at + 1;
// NOTE(review): index_at + 1 equals the '>' count only when every '>' leads
// the entry; e.g. "a>b" yields filename "b" with level 2.
103
size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1;
// NOTE(review): the true-branch operand of this conditional is missing from
// this extract.
104
auto end_index = line.find_last_of('.') == std::string::npos
106
: line.find_last_of('.');
107
auto filename = line.substr(begin_index, end_index - begin_index);
// unordered_map::insert does not overwrite an existing key, so the first
// entry for a filename wins (including the implicit "function" entry).
108
files_to_levels.insert({filename, logging_level});
// Returns whether logging at `level` is enabled for source file `cfname`.
// The file name (after c10::detail::StripBasename and with its extension
// removed) is looked up in the configured filename -> level map. A configured
// file enables every level that compares <= its configured level.
// NOTE(review): no synchronization is visible between this lookup and
// JitLoggingConfig::parse(), which clears and refills the same map.
112
bool is_enabled(const char* cfname, JitLoggingLevels level) {
113
const auto& files_to_levels =
114
JitLoggingConfig::getInstance().getFilesToLevels();
115
std::string fname{cfname};
// StripBasename presumably drops the directory part of the path -- confirm
// against c10/util/StringUtil.h.
116
fname = c10::detail::StripBasename(fname);
// Drop the extension (everything from the last '.').
// NOTE(review): the true-branch operand of this conditional is missing from
// this extract.
117
const auto end_index = fname.find_last_of('.') == std::string::npos
119
: fname.find_last_of('.');
120
const auto fname_no_ext = fname.substr(0, end_index);
122
const auto it = files_to_levels.find(fname_no_ext);
// NOTE(review): the body of the not-found branch is not visible in this
// extract.
123
if (it == files_to_levels.end()) {
127
return level <= static_cast<JitLoggingLevels>(it->second);
130
// Unfortunately, in `GraphExecutor` where `log_function` is invoked
131
// we won't have access to an original function, so we have to construct
132
// a dummy function to give to PythonPrint.
// Wraps `graph` in a GraphFunction named "source_dump" and prints it through
// PythonPrint, starting from a fresh, empty `constants` vector.
// NOTE(review): the declaration of `deps` and the return statement are not
// visible in this extract.
133
std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
134
torch::jit::GraphFunction func("source_dump", graph, nullptr);
135
std::vector<at::IValue> constants;
137
PythonPrint pp(constants, deps);
138
pp.printFunction(func);
// Prepends `prefix` to every line of `in_str`, terminating each output line
// with a newline.
// NOTE(review): std::endl also flushes; on a std::stringstream '\n' would be
// sufficient. The declaration of `line` and the return statement are not
// visible in this extract.
142
std::string jit_log_prefix(
143
const std::string& prefix,
144
const std::string& in_str) {
145
std::stringstream in_ss(in_str);
146
std::stringstream out_ss;
148
while (std::getline(in_ss, line)) {
149
out_ss << prefix << line << std::endl;
// Builds a prefix from the level (via operator<<), the name `fn` after
// c10::detail::StripBasename, and `l` zero-padded to at least three digits.
// It then applies that prefix to every line of `in_str` via the
// string-prefix overload.
// NOTE(review): `fn` and `l` (presumably the source file and line) are used
// below, but their parameter declarations are not visible in this extract.
155
std::string jit_log_prefix(
156
JitLoggingLevels level,
159
const std::string& in_str) {
160
std::stringstream prefix_ss;
162
prefix_ss << level << " ";
163
prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":";
164
prefix_ss << std::setfill('0') << std::setw(3) << l;
167
return jit_log_prefix(prefix_ss.str(), in_str);
170
std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
172
case JitLoggingLevels::GRAPH_DUMP:
175
case JitLoggingLevels::GRAPH_UPDATE:
178
case JitLoggingLevels::GRAPH_DEBUG:
182
TORCH_INTERNAL_ASSERT(false, "Invalid level");