pytorch

Форк
0
42 строки · 1.2 Кб
1
#include <c10/util/Exception.h>
2
#include <torch/csrc/jit/mobile/model_tracer/TensorUtils.h>
3

4
namespace torch {
5
namespace jit {
6
namespace mobile {
7
void for_each_tensor_in_ivalue(
8
    const c10::IValue& iv,
9
    std::function<void(const ::at::Tensor&)> const& func) {
10
  const bool is_leaf_type = iv.isString() || iv.isNone() || iv.isScalar() ||
11
      iv.isDouble() || iv.isInt() || iv.isBool() || iv.isDevice() ||
12
      iv.isIntList() || iv.isDoubleList() || iv.isBoolList();
13
  if (is_leaf_type) {
14
    // Do Nothing.
15
    return;
16
  }
17

18
  if (iv.isTensor()) {
19
    func(iv.toTensor());
20
  } else if (iv.isTuple()) {
21
    c10::intrusive_ptr<at::ivalue::Tuple> tup_ptr = iv.toTuple();
22
    for (const auto& e : tup_ptr->elements()) {
23
      for_each_tensor_in_ivalue(e, func);
24
    }
25
  } else if (iv.isList()) {
26
    c10::List<c10::IValue> l = iv.toList();
27
    for (auto&& i : l) {
28
      c10::IValue item = i;
29
      for_each_tensor_in_ivalue(item, func);
30
    }
31
  } else if (iv.isGenericDict()) {
32
    c10::Dict<c10::IValue, c10::IValue> dict = iv.toGenericDict();
33
    for (auto& it : dict) {
34
      for_each_tensor_in_ivalue(it.value(), func);
35
    }
36
  } else {
37
    AT_ERROR("Unhandled type of IValue. Got ", iv.tagKind());
38
  }
39
}
40
} // namespace mobile
41
} // namespace jit
42
} // namespace torch
43

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.