pytorch

Форк
0
/
graph_utils.cpp 
93 строки · 2.9 Кб
1
#include <torch/csrc/jit/ir/graph_utils.h>
2

3
namespace torch {
4
namespace jit {
5

6
TypePtr getTensorType(const at::Tensor& t, bool complete) {
7
  auto r = TensorType::create(t);
8
  if (!complete) {
9
    r = r->dimensionedOnly();
10
  }
11
  return r;
12
}
13

14
TypePtr inferShapeAndTypeForInput(
15
    TypePtr input_type,
16
    Stack::const_iterator& s_iter,
17
    const Stack::const_iterator& s_iter_end,
18
    bool complete) {
19
  if (auto tuple_type = input_type->cast<TupleType>()) {
20
    std::vector<TypePtr> types;
21
    for (const auto& sub_type : tuple_type->containedTypes()) {
22
      TORCH_INTERNAL_ASSERT(s_iter != s_iter_end);
23
      types.emplace_back(
24
          inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete));
25
    }
26
    return TupleType::create(types);
27
  } else if (auto list_type = input_type->cast<ListType>()) {
28
    const TypePtr& sub_type = list_type->getElementType();
29
    auto elem_type =
30
        inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
31
    return ListType::create(elem_type);
32
  } else if (auto tensor_type = input_type->cast<TensorType>()) {
33
    auto type = getTensorType(s_iter->toTensor(), complete);
34
    s_iter++;
35
    return type;
36
  } else if (auto optional_type = input_type->cast<OptionalType>()) {
37
    const TypePtr& sub_type = optional_type->getElementType();
38
    auto elem_type =
39
        inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
40
    return OptionalType::create(elem_type);
41
  } else {
42
    // Primitive type, keep as is.
43
    s_iter++;
44
    return input_type;
45
  }
46
}
47

48
void setInputTensorTypes(
49
    Graph& g,
50
    const Stack& stack,
51
    bool complete,
52
    const std::vector<int>& param_count_list) {
53
  at::ArrayRef<Value*> input_values = g.inputs();
54
  auto s_iter = stack.begin();
55
  size_t list_idx = 0;
56
  if (!param_count_list.empty()) {
57
    TORCH_INTERNAL_ASSERT(
58
        input_values.size() == param_count_list.size(),
59
        " input_values:",
60
        input_values.size(),
61
        " vs param_count_list:",
62
        param_count_list.size());
63
  }
64
  for (auto v : input_values) {
65
    // Leave packed param types alone. This is needed for downstream passes
66
    // (like alias analysis) to work properly. This will be unpacked later
67
    // in unpackQuantizedWeights.
68
    if (auto named_type = v->type()->cast<c10::NamedType>()) {
69
      if (auto qualname = named_type->name()) {
70
        if (getCustomClass(qualname->qualifiedName())) {
71
          if (param_count_list.empty()) {
72
            AT_ASSERT(s_iter != stack.end());
73
            s_iter++;
74
          } else {
75
            if (param_count_list[list_idx] > 0) {
76
              AT_ASSERT(s_iter != stack.end());
77
            }
78
            s_iter += param_count_list[list_idx];
79
          }
80
          list_idx++;
81
          continue;
82
        }
83
      }
84
    }
85
    auto type =
86
        inferShapeAndTypeForInput(v->type(), s_iter, stack.end(), complete);
87
    v->setType(type);
88
    list_idx++;
89
  }
90
}
91

92
} // namespace jit
93
} // namespace torch
94

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.