pytorch

Форк
0
/
python_tree_views.cpp 
409 строк · 15.5 Кб
1
#include <torch/csrc/jit/python/python_tree_views.h>
2

3
#include <torch/csrc/jit/frontend/tree_views.h>
4

5
#include <pybind11/pybind11.h>
6
#include <pybind11/stl.h>
7
#include <torch/csrc/utils/pybind.h>
8

9
#include <sstream>
10

11
namespace py = pybind11;
12

13
namespace torch::jit {
14

15
// Converts an optional Python object (typically a filename) to a C++ string.
//
// Returns c10::nullopt when `obj` is None; otherwise returns the text of
// Python's str(obj).
c10::optional<std::string> maybeConvertToString(const py::object& obj) {
  if (obj.is_none()) {
    return c10::nullopt;
  }
  // py::str converts to std::string directly; streaming it through a
  // std::stringstream only added an extra buffer and copy.
  return py::str(obj).cast<std::string>();
}
23

24
struct SourceRangeFactory {
25
  SourceRangeFactory(
26
      std::string text,
27
      const py::object& filename,
28
      size_t file_lineno,
29
      size_t leading_whitespace_chars)
30
      : source_(std::make_shared<Source>(
31
            std::move(text),
32
            maybeConvertToString(filename),
33
            file_lineno)),
34
        leading_whitespace_chars_(leading_whitespace_chars) {}
35

36
  SourceRange create(int line, int start_col, int end_col) {
37
    auto [start_byte_offset, end_byte_offset] = line_col_to_byte_offs(
38
        line,
39
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
40
        start_col + leading_whitespace_chars_,
41
        // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
42
        end_col + leading_whitespace_chars_);
43
    return SourceRange(source_, start_byte_offset, end_byte_offset);
44
  }
45

46
  std::tuple<size_t, size_t> line_col_to_byte_offs(
47
      int line,
48
      int start_col,
49
      int end_col) {
50
    // lines are counted from 1.
51
    line--;
52
    auto line_start = source_->offset_for_line(line);
53
    return std::make_tuple<size_t, size_t>(
54
        line_start + start_col, line_start + end_col);
55
  }
56

57
  std::shared_ptr<Source> source_;
58
  std::vector<size_t> line_len_prefix_sum_;
59
  size_t leading_whitespace_chars_;
60
};
61

62
template <typename T>
63
List<T> wrap_list(const SourceRange& fallback_pos, std::vector<T>&& vec) {
64
  if (vec.empty())
65
    return List<T>::create(fallback_pos, std::move(vec));
66
  return List<T>::create(vec.front().range(), std::move(vec));
67
}
68

69
template <typename T>
70
Maybe<T> wrap_maybe(const SourceRange& fallback_pos, T* val) {
71
  return val ? Maybe<T>::create(val->range(), *val)
72
             : Maybe<T>::create(fallback_pos);
73
}
74

75
void initTreeViewBindings(PyObject* module) {
76
  auto _C = py::handle(module).cast<py::module>();
77
  auto m = _C.def_submodule("_jit_tree_views");
78

79
  py::class_<SourceRange>(m, "SourceRange")
80
      .def(
81
          "highlight",
82
          [](const SourceRange& self) {
83
            std::ostringstream stream;
84
            self.highlight(stream);
85
            return stream.str();
86
          })
87
      .def("__repr__", [](const SourceRange& self) { return self.str(); })
88
      .def(
89
          "__str__",
90
          [](const SourceRange& self) {
91
            return "SourceRange at:\n" + self.str();
92
          })
93
      .def_property_readonly("start", &SourceRange::start)
94
      .def_property_readonly("end", &SourceRange::end);
95
  py::class_<SourceRangeFactory>(m, "SourceRangeFactory")
96
      .def(py::init<std::string&&, py::object, size_t, size_t>())
97
      .def("make_range", &SourceRangeFactory::create)
98
      .def(
99
          "make_raw_range",
100
          [](const SourceRangeFactory& self, size_t start, size_t end) {
101
            return SourceRange(self.source_, start, end);
102
          })
103
      .def_property_readonly("source", [](const SourceRangeFactory& self) {
104
        auto text_view = self.source_->text_str().str();
105
        return text_view;
106
      });
107

108
  py::class_<TreeView>(m, "TreeView")
109
      .def("range", &TreeView::range)
110
      .def(
111
          "__str__",
112
          [](const TreeView& tree) {
113
            std::ostringstream stream;
114
            stream << tree.get();
115
            return stream.str();
116
          })
117
      .def("dump", [](const TreeView& tree) { tree.dump(); });
118

119
  py::class_<Ident, TreeView>(m, "Ident")
120
      .def(py::init(&Ident::create))
121
      .def_property_readonly(
122
          "name", [](const Ident& self) { return self.name(); });
123

124
  py::class_<Param, TreeView>(m, "Param")
125
      .def(py::init([](const Expr& type, const Ident& name, bool kwarg_only) {
126
        return Param::create(
127
            name.range(),
128
            name,
129
            Maybe<Expr>::create(type.range(), type),
130
            Maybe<Expr>::create(name.range()),
131
            kwarg_only);
132
      }))
133
      .def(py::init(
134
          [](const Maybe<Expr>& type, const Ident& name, bool kwarg_only) {
135
            return Param::create(
136
                name.range(),
137
                name,
138
                type,
139
                Maybe<Expr>::create(name.range()),
140
                kwarg_only);
141
          }));
142
  py::class_<Attribute, TreeView>(m, "Attribute")
143
      .def(py::init([](const Ident& name, const Expr& value) {
144
        return Attribute::create(name.range(), name, value);
145
      }));
146
  m.def("TrueLiteral", [](const SourceRange& range) {
147
    return Expr(Compound::create(TK_TRUE, range, {}));
148
  });
149
  m.def("FalseLiteral", [](const SourceRange& range) {
150
    return Expr(Compound::create(TK_FALSE, range, {}));
151
  });
152
  m.def("NoneLiteral", [](const SourceRange& range) {
153
    return Expr(Compound::create(TK_NONE, range, {}));
154
  });
155

156
  py::class_<Stmt, TreeView>(m, "Stmt") // NOLINT(bugprone-unused-raii)
157
      .def(py::init([](const TreeView& thing) { return Stmt(thing.get()); }));
158
  py::class_<Expr, TreeView>(m, "Expr"); // NOLINT(bugprone-unused-raii)
159
  py::class_<Def, TreeView>(m, "Def")
160
      .def(py::init(
161
          [](const Ident& name, const Decl& decl, std::vector<Stmt> body) {
162
            const auto& r = name.range();
163
            return Def::create(r, name, decl, wrap_list(r, std::move(body)));
164
          }))
165
      .def("decl", [](const Def& def) { return def.decl(); })
166
      .def("name", [](const Def& def) { return def.name(); });
167
  py::class_<Property, TreeView>(m, "Property")
168
      .def(py::init([](const SourceRange& r,
169
                       const Ident& name,
170
                       const Def& getter,
171
                       Def* setter) {
172
        return Property::create(r, name, getter, wrap_maybe(r, setter));
173
      }))
174
      .def("name", [](const Property& property) { return property.name(); })
175
      .def(
176
          "getter_name",
177
          [](const Property& property) { return property.getter().name(); })
178
      .def("setter_name", [](const Property& property) {
179
        if (property.setter().present()) {
180
          return c10::optional<Ident>(property.setter().get().name());
181
        }
182

183
        return c10::optional<Ident>(c10::nullopt);
184
      });
185

186
  py::class_<ClassDef, TreeView>(m, "ClassDef")
187
      .def(py::init([](const Ident& name,
188
                       std::vector<Stmt> body,
189
                       std::vector<Property> props,
190
                       std::vector<Assign> assigns) {
191
        const auto& r = name.range();
192
        return ClassDef::create(
193
            r,
194
            name,
195
            Maybe<Expr>::create(r),
196
            wrap_list(r, std::move(body)),
197
            wrap_list(r, std::move(props)),
198
            wrap_list(r, std::move(assigns)));
199
      }));
200

201
  py::class_<Decl, TreeView>(m, "Decl").def(py::init(
202
      [](const SourceRange& r, std::vector<Param> params, Expr* return_type) {
203
        return Decl::create(
204
            r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type));
205
      }));
206

207
  py::class_<Delete, Stmt>(m, "Delete")
208
      .def(py::init([](const SourceRange& range, std::vector<Expr> targets) {
209
        return Delete::create(range, wrap_list(range, std::move(targets)));
210
      }));
211

212
  py::class_<WithItem, Expr>(m, "WithItem")
213
      .def(py::init([](const SourceRange& range, const Expr& target, Var* var) {
214
        return WithItem::create(range, target, wrap_maybe(range, var));
215
      }));
216

217
  py::class_<Assign, Stmt>(m, "Assign")
218
      .def(py::init([](std::vector<Expr> lhs, const Expr& rhs) {
219
        auto li = wrap_list(rhs.range(), std::move(lhs));
220
        return Assign::create(
221
            li.range(),
222
            li,
223
            Maybe<Expr>::create(rhs.range(), rhs),
224
            Maybe<Expr>::create(li.range()));
225
      }))
226
      .def(py::init([](std::vector<Expr> lhs, const Expr& rhs, Expr* type) {
227
        auto li = wrap_list(rhs.range(), std::move(lhs));
228
        return Assign::create(
229
            li.range(),
230
            li,
231
            Maybe<Expr>::create(rhs.range(), rhs),
232
            wrap_maybe(li.range(), type));
233
      }));
234
  py::class_<AugAssign, Stmt>(m, "AugAssign")
235
      .def(py::init(
236
          [](const Expr& lhs, const std::string& kind_str, const Expr& rhs) {
237
            const auto& r = lhs.range();
238
            auto kind =
239
                AugAssignKind(Compound::create(stringToKind(kind_str), r, {}));
240
            return AugAssign::create(r, lhs, kind, rhs);
241
          }));
242
  py::class_<Return, Stmt>(m, "Return")
243
      .def(py::init([](const SourceRange& range, Expr* value) {
244
        return Return::create(
245
            range, value ? *value : Expr(Compound::create(TK_NONE, range, {})));
246
      }));
247
  py::class_<Raise, Stmt>(m, "Raise")
248
      .def(py::init([](const SourceRange& range, const Expr& expr) {
249
        return Raise::create(range, expr);
250
      }));
251
  py::class_<Assert, Stmt>(m, "Assert")
252
      .def(py::init([](const SourceRange& range, const Expr& test, Expr* msg) {
253
        return Assert::create(range, test, wrap_maybe(range, msg));
254
      }));
255
  py::class_<Pass, Stmt>(m, "Pass").def(
256
      py::init([](const SourceRange& range) { return Pass::create(range); }));
257
  py::class_<Break, Stmt>(m, "Break")
258
      .def(py::init(
259
          [](const SourceRange& range) { return Break::create(range); }));
260
  py::class_<Continue, Stmt>(m, "Continue")
261
      .def(py::init(
262
          [](const SourceRange& range) { return Continue::create(range); }));
263
  py::class_<Dots, Expr>(m, "Dots").def(
264
      py::init([](const SourceRange& range) { return Dots::create(range); }));
265
  py::class_<If, Stmt>(m, "If").def(
266
      py::init([](const SourceRange& range,
267
                  const Expr& cond,
268
                  std::vector<Stmt> true_branch,
269
                  std::vector<Stmt> false_branch) {
270
        return If::create(
271
            range,
272
            cond,
273
            wrap_list(range, std::move(true_branch)),
274
            wrap_list(range, std::move(false_branch)));
275
      }));
276
  py::class_<While, Stmt>(m, "While")
277
      .def(py::init([](const SourceRange& range,
278
                       const Expr& cond,
279
                       std::vector<Stmt> body) {
280
        return While::create(range, cond, wrap_list(range, std::move(body)));
281
      }));
282
  py::class_<With, Stmt>(m, "With").def(
283
      py::init([](const SourceRange& range,
284
                  std::vector<WithItem> targets,
285
                  std::vector<Stmt> body) {
286
        return With::create(
287
            range,
288
            wrap_list(range, std::move(targets)),
289
            wrap_list(range, std::move(body)));
290
      }));
291
  py::class_<For, Stmt>(m, "For").def(py::init([](const SourceRange& range,
292
                                                  std::vector<Expr>& targets,
293
                                                  std::vector<Expr>& itrs,
294
                                                  std::vector<Stmt> body) {
295
    return For::create(
296
        range,
297
        wrap_list(range, std::move(targets)),
298
        wrap_list(range, std::move(itrs)),
299
        wrap_list(range, std::move(body)));
300
  }));
301
  py::class_<ExprStmt, Stmt>(m, "ExprStmt").def(py::init([](const Expr& expr) {
302
    return ExprStmt::create(expr.range(), expr);
303
  }));
304

305
  py::class_<Var, Expr>(m, "Var")
306
      .def(py::init(
307
          [](const Ident& name) { return Var::create(name.range(), name); }))
308
      .def_property_readonly("name", [](const Var& var) { return var.name(); });
309
  py::class_<BinOp, Expr>(m, "BinOp")
310
      .def(py::init(
311
          [](const std::string& kind, const Expr& lhs, const Expr& rhs) {
312
            return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs);
313
          }));
314
  // NB: we take range here, because unary ops precede their exprs, so we need
315
  // to include them
316
  py::class_<UnaryOp, Expr>(m, "UnaryOp")
317
      .def(py::init([](const SourceRange& range,
318
                       const std::string& kind,
319
                       const Expr& expr) {
320
        auto resolved_kind = stringToKind(kind);
321
        resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind;
322
        return UnaryOp::create(range, resolved_kind, expr);
323
      }));
324
  py::class_<Const, Expr>(m, "Const")
325
      .def(py::init([](const SourceRange& range, const std::string& value) {
326
        return Const::create(range, value);
327
      }));
328
  py::class_<StringLiteral, Expr>(m, "StringLiteral")
329
      .def(py::init([](const SourceRange& range, const std::string& value) {
330
        return StringLiteral::create(range, value);
331
      }));
332
  py::class_<Apply, Expr>(m, "Apply")
333
      .def(py::init([](const Expr& expr,
334
                       std::vector<Expr> args,
335
                       std::vector<Attribute> kwargs) {
336
        const auto& r = expr.range();
337
        return Apply::create(
338
            expr.range(),
339
            expr,
340
            wrap_list(r, std::move(args)),
341
            wrap_list(r, std::move(kwargs)));
342
      }));
343
  py::class_<Select, Expr>(m, "Select")
344
      .def(py::init([](const Expr& expr, const Ident& field) {
345
        return Select::create(expr.range(), expr, field);
346
      }));
347
  py::class_<TernaryIf, Expr>(m, "TernaryIf")
348
      .def(py::init(
349
          [](const Expr& cond, const Expr& true_expr, const Expr& false_expr) {
350
            return TernaryIf::create(cond.range(), cond, true_expr, false_expr);
351
          }));
352
  py::class_<ListComp, Expr>(m, "ListComp")
353
      .def(py::init([](const SourceRange& range,
354
                       const Expr& elt,
355
                       const Expr& target,
356
                       const Expr& iter) {
357
        return ListComp::create(range, elt, target, iter);
358
      }));
359
  py::class_<DictComp, Expr>(m, "DictComp")
360
      .def(py::init([](const SourceRange& range,
361
                       const Expr& key,
362
                       const Expr& value,
363
                       const Expr& target,
364
                       const Expr& iter) {
365
        return DictComp::create(range, key, value, target, iter);
366
      }));
367
  py::class_<ListLiteral, Expr>(m, "ListLiteral")
368
      .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
369
        return ListLiteral::create(range, wrap_list(range, std::move(args)));
370
      }));
371
  py::class_<TupleLiteral, Expr>(m, "TupleLiteral")
372
      .def(py::init([](const SourceRange& range, std::vector<Expr> args) {
373
        return TupleLiteral::create(range, wrap_list(range, std::move(args)));
374
      }));
375
  py::class_<DictLiteral, Expr>(m, "DictLiteral")
376
      .def(py::init([](const SourceRange& range,
377
                       std::vector<Expr> keys,
378
                       std::vector<Expr> values) {
379
        return DictLiteral::create(
380
            range,
381
            wrap_list(range, std::move(keys)),
382
            wrap_list(range, std::move(values)));
383
      }));
384
  py::class_<Subscript, Expr>(m, "Subscript")
385
      .def(py::init([](const Expr& base, std::vector<Expr> subscript_exprs) {
386
        return Subscript::create(
387
            base.range(),
388
            base,
389
            wrap_list(base.range(), std::move(subscript_exprs)));
390
      }));
391
  py::class_<SliceExpr, Expr>(m, "SliceExpr")
392
      .def(py::init(
393
          [](const SourceRange& range, Expr* lower, Expr* upper, Expr* step) {
394
            return SliceExpr::create(
395
                range,
396
                wrap_maybe(range, lower),
397
                wrap_maybe(range, upper),
398
                wrap_maybe(range, step));
399
          }));
400
  py::class_<Starred, Expr>(m, "Starred")
401
      .def(py::init([](const SourceRange& range, const Expr& expr) {
402
        return Starred::create(range, expr);
403
      }));
404
  py::class_<Maybe<Expr>, TreeView>(m, "EmptyTypeAnnotation")
405
      .def(py::init(
406
          [](const SourceRange& range) { return Maybe<Expr>::create(range); }));
407
}
408

409
} // namespace torch::jit
410

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.