1
#include <torch/csrc/jit/ir/graph_utils.h>
6
TypePtr getTensorType(const at::Tensor& t, bool complete) {
7
auto r = TensorType::create(t);
9
r = r->dimensionedOnly();
14
TypePtr inferShapeAndTypeForInput(
16
Stack::const_iterator& s_iter,
17
const Stack::const_iterator& s_iter_end,
19
if (auto tuple_type = input_type->cast<TupleType>()) {
20
std::vector<TypePtr> types;
21
for (const auto& sub_type : tuple_type->containedTypes()) {
22
TORCH_INTERNAL_ASSERT(s_iter != s_iter_end);
24
inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete));
26
return TupleType::create(types);
27
} else if (auto list_type = input_type->cast<ListType>()) {
28
const TypePtr& sub_type = list_type->getElementType();
30
inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
31
return ListType::create(elem_type);
32
} else if (auto tensor_type = input_type->cast<TensorType>()) {
33
auto type = getTensorType(s_iter->toTensor(), complete);
36
} else if (auto optional_type = input_type->cast<OptionalType>()) {
37
const TypePtr& sub_type = optional_type->getElementType();
39
inferShapeAndTypeForInput(sub_type, s_iter, s_iter_end, complete);
40
return OptionalType::create(elem_type);
42
// Primitive type, keep as is.
48
void setInputTensorTypes(
52
const std::vector<int>& param_count_list) {
53
at::ArrayRef<Value*> input_values = g.inputs();
54
auto s_iter = stack.begin();
56
if (!param_count_list.empty()) {
57
TORCH_INTERNAL_ASSERT(
58
input_values.size() == param_count_list.size(),
61
" vs param_count_list:",
62
param_count_list.size());
64
for (auto v : input_values) {
65
// Leave packed param types alone. This is needed for downstream passes
66
// (like alias analysis) to work properly. This will be unpacked later
67
// in unpackQuantizedWeights.
68
if (auto named_type = v->type()->cast<c10::NamedType>()) {
69
if (auto qualname = named_type->name()) {
70
if (getCustomClass(qualname->qualifiedName())) {
71
if (param_count_list.empty()) {
72
AT_ASSERT(s_iter != stack.end());
75
if (param_count_list[list_idx] > 0) {
76
AT_ASSERT(s_iter != stack.end());
78
s_iter += param_count_list[list_idx];
86
inferShapeAndTypeForInput(v->type(), s_iter, stack.end(), complete);