5
#include <torch/csrc/Export.h>
6
#include <torch/csrc/jit/frontend/inline_loop_condition.h>
7
#include <torch/csrc/jit/ir/ir.h>
11
// Relocates the nodes of `block` to immediately before `before_node`, visiting
// them in the block's iteration order and calling moveBefore() on each one.
void InlineBlockBeforeNode(Node* before_node, Block* block) {
12
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
13
// Post-increment: the iterator steps forward before the current node is moved
// (presumably so that relocating the node cannot disturb the traversal).
auto block_node = *it++;
14
block_node->moveBefore(before_node);
18
// The loop node is initially emitted as:
19
// Loop(max_trip_count)
20
// block0(loop_counter) {
27
// Here, we inline the loop condition and convert the loop to the form:
28
// Loop(max_trip_count, start_condition)
29
// block0(loop_counter, loop_carried_block*) {
31
// BlockExit(continue_condition, loop_carried_block*)
33
// Rewrites a single loop node `n` in place: the computation held in the node's
// second block is placed once in front of `n` (feeding a new input at index 1)
// and once at the end of the first block (feeding input 0 of that block's
// return node).
// NOTE(review): assumes `n` has at least two blocks on entry (`.at(1)` is
// used) -- the only visible caller passes nodes whose kind is prim::Loop.
static void inlineLoopCondition(Node* n) {
34
// Block 0 is the loop body.
Block* body_block = n->blocks().at(0);
36
// Block 1 -- presumably the block that computes the loop condition; confirm
// against the code that emits the loop node.
auto pre_header = n->blocks().at(1);
37
// Scratch block appended to `n`; it receives a clone of `pre_header`. The
// value-map callback handed to cloneFrom returns its argument unchanged.
auto temp_block = n->addBlock();
38
temp_block->cloneFrom(pre_header, [](Value* v) { return v; });
39
// Move the cloned nodes in front of the loop node itself.
InlineBlockBeforeNode(n, temp_block);
40
// The clone's first output becomes the loop's input at index 1
// (the start condition).
n->insertInput(/*start_condition_index*/ 1, temp_block->outputs().at(0));
43
// Move the original `pre_header` nodes to just before the body's return node,
// then make the block's first output the return node's first input
// (the continue condition).
// NOTE(review): no line visible here removes `temp_block` or `pre_header`
// from `n` after their nodes are moved out -- confirm they are erased.
InlineBlockBeforeNode(body_block->return_node(), pre_header);
44
body_block->return_node()->insertInput(0, pre_header->outputs().at(0));
48
// Walks the nodes of `block`, recursing into the blocks nested under each
// node, and applies the Node* overload to nodes whose kind is prim::Loop.
static void inlineLoopCondition(Block* block) {
49
for (Node* n : block->nodes()) {
50
// Descend into every block owned by this node.
for (Block* b : n->blocks()) {
51
inlineLoopCondition(b);
53
// Only prim::Loop nodes are rewritten.
if (n->kind() == prim::Loop) {
54
inlineLoopCondition(n);
59
// Public entry point of the pass: applies the loop-condition inlining to the
// block returned by graph->block(), which in turn recurses into nested blocks.
// The graph is modified in place; nothing is returned.
void InlineLoopCondition(std::shared_ptr<Graph>& graph) {
60
inlineLoopCondition(graph->block());
63
} // namespace torch::jit