pytorch

Форк
0
/
inline_loop_condition.cpp 
63 строки · 1.7 Кб
1
#include <functional>
2
#include <memory>
3
#include <string>
4

5
#include <torch/csrc/Export.h>
6
#include <torch/csrc/jit/frontend/inline_loop_condition.h>
7
#include <torch/csrc/jit/ir/ir.h>
8

9
namespace torch::jit {
10

11
void InlineBlockBeforeNode(Node* before_node, Block* block) {
12
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
13
    auto block_node = *it++;
14
    block_node->moveBefore(before_node);
15
  }
16
}
17

18
// The loop node is initially emitted as:
19
// Loop(max_trip_count)
20
//    block0(loop_counter) {
21
//      <body>
22
//    }
23
//    block1 {
24
//      <loop condition>
25
//      -> (condition)
26
//    }
27
// Here, we inline the loop condition and convert the loop to the form:
28
// Loop(max_trip_count, start_condition)
29
//    block0(loop_counter, loop_carried_block*) {
30
//      <body>
31
//       BlockExit(continue_condition, loop_carried_block*)
32
//    }
33
static void inlineLoopCondition(Node* n) {
34
  Block* body_block = n->blocks().at(0);
35

36
  auto pre_header = n->blocks().at(1);
37
  auto temp_block = n->addBlock();
38
  temp_block->cloneFrom(pre_header, [](Value* v) { return v; });
39
  InlineBlockBeforeNode(n, temp_block);
40
  n->insertInput(/*start_condition_index*/ 1, temp_block->outputs().at(0));
41
  n->eraseBlock(2);
42

43
  InlineBlockBeforeNode(body_block->return_node(), pre_header);
44
  body_block->return_node()->insertInput(0, pre_header->outputs().at(0));
45
  n->eraseBlock(1);
46
}
47

48
static void inlineLoopCondition(Block* block) {
49
  for (Node* n : block->nodes()) {
50
    for (Block* b : n->blocks()) {
51
      inlineLoopCondition(b);
52
    }
53
    if (n->kind() == prim::Loop) {
54
      inlineLoopCondition(n);
55
    }
56
  }
57
}
58

59
void InlineLoopCondition(std::shared_ptr<Graph>& graph) {
60
  inlineLoopCondition(graph->block());
61
}
62

63
} // namespace torch::jit
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.