pytorch

Форк
0
/
canonicalize_modified_loop.cpp 
68 строк · 2.3 Кб
1
#include <functional>
2
#include <memory>
3
#include <string>
4

5
#include <torch/csrc/Export.h>
6
#include <torch/csrc/jit/frontend/canonicalize_modified_loop.h>
7
#include <torch/csrc/jit/ir/ir.h>
8
#include <torch/csrc/jit/ir/ir_views.h>
9

10
namespace torch::jit {
11

12
// Transforms a Loop that has both a trip count specified and a loop
13
// body condition so that the iter count is no longer specified
14
// and it is recognizable as a python while loop.
15
static void canonicalizeModifiedLoop(Node* n) {
16
  LoopView loop(n);
17
  if (loop.loopType() != LoopView::ModifiedLoop) {
18
    return;
19
  }
20

21
  auto g = n->owningGraph();
22
  WithInsertPoint node_insert(n);
23
  auto zero = g->insertConstant(0);
24
  auto one = g->insertConstant(1);
25
  auto max_trip_count = loop.maxTripCount();
26
  auto condition = g->insert(aten::gt, {max_trip_count, zero});
27
  loop.replaceMaxTripCount(
28
      g->insertConstant(std::numeric_limits<int64_t>::max()));
29

30
  auto inp_condition = toIValue(loop.inputCond());
31
  if (inp_condition == c10::nullopt || inp_condition->toBool() == false) {
32
    condition = g->insert(aten::__and__, {condition, loop.inputCond()});
33
  }
34
  loop.replaceInputCondition(condition);
35
  n->addOutput()->setType(IntType::get());
36
  WithInsertPoint loop_insert(loop.bodyBlock());
37
  n->addInput(zero);
38
  auto new_iter = loop.bodyBlock()->addInput()->setType(IntType::get());
39
  // unset unique name for jitter, its replacement does not have a name
40
  loop.currentTripCount()->setDebugName("")->replaceAllUsesWith(new_iter);
41
  auto inc_iter = g->insert(aten::add, {new_iter, one});
42
  loop.bodyBlock()->registerOutput(inc_iter);
43
  auto less_than_max_trip = g->insert(aten::lt, {inc_iter, max_trip_count});
44
  auto loop_continue = loop.nextCond();
45
  auto new_condition =
46
      g->insert(aten::__and__, {less_than_max_trip, loop_continue});
47
  loop.bodyBlock()->eraseOutput(0);
48
  loop.bodyBlock()->insertOutput(0, new_condition);
49
}
50

51
static void canonicalizeModifiedLoops(Block* block) {
52
  for (Node* n : block->nodes()) {
53
    for (Block* b : n->blocks()) {
54
      canonicalizeModifiedLoops(b);
55
    }
56
    if (n->kind() == prim::Loop) {
57
      canonicalizeModifiedLoop(n);
58
    }
59
  }
60
}
61

62
// Transforms loops so that they can be represented as python
63
// for or while loops
64
TORCH_API void CanonicalizeModifiedLoops(std::shared_ptr<Graph>& graph) {
65
  canonicalizeModifiedLoops(graph->block());
66
}
67

68
} // namespace torch::jit
69

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.