pytorch

Форк
0
/
Exception.cpp 
261 строка · 6.3 Кб
1
#include <c10/util/Exception.h>
2
#include <c10/util/Logging.h>
3
#include <c10/util/Type.h>
4

5
#include <sstream>
6
#include <string>
7
#include <utility>
8

9
namespace c10 {
10

11
Error::Error(std::string msg, Backtrace backtrace, const void* caller)
12
    : msg_(std::move(msg)), backtrace_(std::move(backtrace)), caller_(caller) {
13
  refresh_what();
14
}
15

16
// PyTorch-style error message
17
// Error::Error(SourceLocation source_location, const std::string& msg)
18
// NB: This is defined in Logging.cpp for access to GetFetchStackTrace
19

20
// Caffe2-style error message
21
Error::Error(
22
    const char* file,
23
    const uint32_t line,
24
    const char* condition,
25
    const std::string& msg,
26
    Backtrace backtrace,
27
    const void* caller)
28
    : Error(
29
          str("[enforce fail at ",
30
              detail::StripBasename(file),
31
              ":",
32
              line,
33
              "] ",
34
              condition,
35
              ". ",
36
              msg),
37
          std::move(backtrace),
38
          caller) {}
39

40
std::string Error::compute_what(bool include_backtrace) const {
41
  std::ostringstream oss;
42

43
  oss << msg_;
44

45
  if (context_.size() == 1) {
46
    // Fold error and context in one line
47
    oss << " (" << context_[0] << ")";
48
  } else {
49
    for (const auto& c : context_) {
50
      oss << "\n  " << c;
51
    }
52
  }
53

54
  if (include_backtrace && backtrace_) {
55
    oss << "\n" << backtrace_->get();
56
  }
57

58
  return oss.str();
59
}
60

61
// Returns (by const reference) the backtrace this Error was constructed with.
const Backtrace& Error::backtrace() const {
  return this->backtrace_;
}
64

65
const char* Error::what() const noexcept {
66
  return what_
67
      .ensure([this] {
68
        try {
69
          return compute_what(/*include_backtrace*/ true);
70
        } catch (...) {
71
          // what() is noexcept, we need to return something here.
72
          return std::string{"<Error computing Error::what()>"};
73
        }
74
      })
75
      .c_str();
76
}
77

78
void Error::refresh_what() {
79
  // Do not compute what_ eagerly, as it would trigger the computation of the
80
  // backtrace. Instead, invalidate it, it will be computed on first access.
81
  // refresh_what() is only called by non-const public methods which are not
82
  // supposed to be called concurrently with any other method, so it is safe to
83
  // invalidate here.
84
  what_.reset();
85
  what_without_backtrace_ = compute_what(/*include_backtrace*/ false);
86
}
87

88
void Error::add_context(std::string new_msg) {
89
  context_.push_back(std::move(new_msg));
90
  // TODO: Calling add_context O(n) times has O(n^2) cost.  We can fix
91
  // this perf problem by populating the fields lazily... if this ever
92
  // actually is a problem.
93
  // NB: If you do fix this, make sure you do it in a thread safe way!
94
  // what() is almost certainly expected to be thread safe even when
95
  // accessed across multiple threads
96
  refresh_what();
97
}
98

99
namespace detail {
100

101
void torchCheckFail(
102
    const char* func,
103
    const char* file,
104
    uint32_t line,
105
    const std::string& msg) {
106
  throw ::c10::Error({func, file, line}, msg);
107
}
108

109
void torchCheckFail(
110
    const char* func,
111
    const char* file,
112
    uint32_t line,
113
    const char* msg) {
114
  throw ::c10::Error({func, file, line}, msg);
115
}
116

117
void torchInternalAssertFail(
118
    const char* func,
119
    const char* file,
120
    uint32_t line,
121
    const char* condMsg,
122
    const char* userMsg) {
123
  torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
124
}
125

126
// This should never be called. It is provided in case of compilers
127
// that don't do any dead code stripping in debug builds.
128
void torchInternalAssertFail(
129
    const char* func,
130
    const char* file,
131
    uint32_t line,
132
    const char* condMsg,
133
    const std::string& userMsg) {
134
  torchCheckFail(func, file, line, c10::str(condMsg, userMsg));
135
}
136

137
} // namespace detail
138

139
namespace WarningUtils {
140

141
namespace {
142
WarningHandler* getBaseHandler() {
143
  static WarningHandler base_warning_handler_ = WarningHandler();
144
  return &base_warning_handler_;
145
}
146

147
class ThreadWarningHandler {
148
 public:
149
  ThreadWarningHandler() = delete;
150

151
  static WarningHandler* get_handler() {
152
    if (!warning_handler_) {
153
      warning_handler_ = getBaseHandler();
154
    }
155
    return warning_handler_;
156
  }
157

158
  static void set_handler(WarningHandler* handler) {
159
    warning_handler_ = handler;
160
  }
161

162
 private:
163
  static thread_local WarningHandler* warning_handler_;
164
};
165

166
thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr;
167

168
} // namespace
169

170
void set_warning_handler(WarningHandler* handler) noexcept(true) {
171
  ThreadWarningHandler::set_handler(handler);
172
}
173

174
WarningHandler* get_warning_handler() noexcept(true) {
175
  return ThreadWarningHandler::get_handler();
176
}
177

178
// Process-wide "warn always" flag; starts out disabled.
// NOTE(review): plain bool with no synchronization -- concurrent set/get from
// different threads would be a data race; confirm callers only toggle it from
// one thread.
bool warn_always{false};

// Overwrites the process-wide flag.
void set_warnAlways(bool setting) noexcept(true) {
  warn_always = setting;
}

// Reads the process-wide flag.
bool get_warnAlways() noexcept(true) {
  const bool current = warn_always;
  return current;
}
187

188
WarnAlways::WarnAlways(bool setting /*=true*/)
189
    : prev_setting(get_warnAlways()) {
190
  set_warnAlways(setting);
191
}
192

193
WarnAlways::~WarnAlways() {
194
  set_warnAlways(prev_setting);
195
}
196

197
} // namespace WarningUtils
198

199
void warn(const Warning& warning) {
200
  WarningUtils::ThreadWarningHandler::get_handler()->process(warning);
201
}
202

203
Warning::Warning(
204
    warning_variant_t type,
205
    const SourceLocation& source_location,
206
    std::string msg,
207
    const bool verbatim)
208
    : type_(type),
209
      source_location_(source_location),
210
      msg_(std::move(msg)),
211
      verbatim_(verbatim) {}
212

213
Warning::Warning(
214
    warning_variant_t type,
215
    SourceLocation source_location,
216
    detail::CompileTimeEmptyString msg,
217
    const bool verbatim)
218
    : Warning(type, source_location, "", verbatim) {}
219

220
// C-string overload of the Warning constructor.
//
// Delegates to the (type, const SourceLocation&, std::string, bool) overload,
// so all members are initialised in exactly one place. A null `msg` is
// treated as an empty message: constructing std::string from a null pointer
// is undefined behavior.
Warning::Warning(
    warning_variant_t type,
    SourceLocation source_location,
    const char* msg,
    const bool verbatim)
    : Warning(
          type,
          source_location,
          msg != nullptr ? std::string(msg) : std::string(),
          verbatim) {}
229

230
Warning::warning_variant_t Warning::type() const {
231
  return type_;
232
}
233

234
const SourceLocation& Warning::source_location() const {
235
  return source_location_;
236
}
237

238
const std::string& Warning::msg() const {
239
  return msg_;
240
}
241

242
bool Warning::verbatim() const {
243
  return verbatim_;
244
}
245

246
void WarningHandler::process(const Warning& warning) {
247
  LOG_AT_FILE_LINE(
248
      WARNING, warning.source_location().file, warning.source_location().line)
249
      << "Warning: " << warning.msg() << " (function "
250
      << warning.source_location().function << ")";
251
}
252

253
std::string GetExceptionString(const std::exception& e) {
254
#ifdef __GXX_RTTI
255
  return demangle(typeid(e).name()) + ": " + e.what();
256
#else
257
  return std::string("Exception (no RTTI available): ") + e.what();
258
#endif // __GXX_RTTI
259
}
260

261
} // namespace c10
262

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.