pytorch

Форк
0
/
tensor_flatten.cpp 
127 строк · 3.8 Кб
1
#include <torch/csrc/utils/tensor_flatten.h>
2

3
#include <map>
4
#include <unordered_map>
5

6
namespace torch {
7
namespace utils {
8

9
using namespace at;
10

11
std::vector<TensorGroup> take_tensors(
12
    TensorList tensors,
13
    size_t size_limit,
14
    bool fine_grained) {
15
  std::vector<TensorGroup> results;
16
  // an overapproximation, but at least we won't have to copy stuff around
17
  results.reserve(tensors.size());
18
  std::map<int64_t, TensorGroup> groups;
19
  size_t cur_group_size = 0;
20

21
  for (const auto& tensor : tensors) {
22
    size_t tensor_size = 0;
23
    if (tensor.is_sparse()) {
24
      const auto& indices = tensor._indices();
25
      const auto& values = tensor._values();
26
      tensor_size = indices.numel() * indices.element_size() +
27
          values.numel() * indices.element_size();
28
    } else {
29
      tensor_size = tensor.numel() * tensor.element_size();
30
    }
31

32
    auto& type_group = groups[static_cast<int64_t>(type_id(tensor))];
33
    type_group.tensors.push_back(tensor);
34

35
    if (fine_grained) {
36
      cur_group_size += tensor_size;
37
      // Regardless the type, the current total size exceeds the limit
38
      if (cur_group_size >= size_limit) {
39
        // Spill all types to separate groups in results
40
        for (auto& entry : groups) {
41
          auto& group = entry.second;
42
          results.emplace_back(std::move(group));
43
        }
44
        cur_group_size = 0;
45
        groups.clear();
46
      }
47
    } else {
48
      type_group.size += tensor_size;
49
      if (type_group.size >= size_limit) {
50
        results.emplace_back();
51
        std::swap(results.back(), type_group);
52
      }
53
    }
54
  }
55
  // End case. Look for any remaining groups and return them.
56
  for (auto& entry : groups) {
57
    auto& group = entry.second;
58
    if (group.tensors.empty()) {
59
      continue;
60
    }
61
    results.emplace_back(std::move(group));
62
  }
63
  return results;
64
}
65

66
void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
67
  AT_ASSERT(tensors.size() == order.size());
68
  std::unordered_map<size_t, std::vector<size_t>> type_id_to_indices;
69
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
70
    type_id_to_indices[type_id(tensors[i])].push_back(i);
71

72
  std::unordered_map<size_t, size_t> type_id_to_type_used;
73
  std::vector<Tensor> ordered_tensors;
74
  ordered_tensors.reserve(tensors.size());
75
  for (auto& tmpl_tensor : order) {
76
    size_t tmpl_type_id = type_id(tmpl_tensor);
77
    auto& indices = type_id_to_indices[tmpl_type_id];
78
    auto& used = type_id_to_type_used[tmpl_type_id];
79
    ordered_tensors.push_back(tensors[indices[used++]]);
80
  }
81
  std::swap(tensors, ordered_tensors);
82
}
83

84
namespace {

// Projection helpers, used with fmap() below to split a list of sparse COO
// tensors into their index and value components.

// Returns the (possibly uncoalesced) indices tensor of a sparse tensor.
at::Tensor get_indices(const at::Tensor& t) {
  return t._indices();
}

// Returns the (possibly uncoalesced) values tensor of a sparse tensor.
at::Tensor get_values(const at::Tensor& t) {
  return t._values();
}

} // namespace
95

96
std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
97
    at::TensorList tensors) {
98
  auto flat_indices = utils::flatten_dense_tensors(fmap(tensors, &get_indices));
99
  auto flat_values = utils::flatten_dense_tensors(fmap(tensors, &get_values));
100
  return std::make_pair(flat_indices, flat_values);
101
}
102

103
std::vector<at::Tensor> unflatten_sparse_tensors(
104
    const at::Tensor& flat_indices,
105
    const at::Tensor& flat_values,
106
    at::TensorList tensors) {
107
  if (tensors.empty())
108
    return {};
109

110
  auto indices =
111
      utils::unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices));
112
  auto values =
113
      utils::unflatten_dense_tensors(flat_values, fmap(tensors, &get_values));
114

115
  std::vector<at::Tensor> outputs;
116
  outputs.reserve(tensors.size());
117
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) {
118
    auto& ref_t = tensors[i];
119
    auto t =
120
        at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes());
121
    outputs.emplace_back(t._coalesced_(ref_t.is_coalesced()));
122
  }
123
  return outputs;
124
}
125

126
} // namespace utils
127
} // namespace torch
128

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.