#include <torch/csrc/utils/tensor_flatten.h>

#include <map>
#include <unordered_map>

namespace torch::utils {

using namespace at;

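// Partitions `tensors` into groups keyed by tensor type. In the default
// (coarse-grained) mode a per-type group is flushed to the results once its
// accumulated byte size reaches `size_limit`; in `fine_grained` mode the
// running total across all types is checked instead, and every open group
// is flushed when it trips.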
std::vector<TensorGroup> take_tensors(
    TensorList tensors,
    size_t size_limit,
    bool fine_grained) {
  std::vector<TensorGroup> results;
  // an overapproximation, but at least we won't have to copy stuff around
  results.reserve(tensors.size());
  std::map<int64_t, TensorGroup> groups;
  size_t cur_group_size = 0;

  for (const auto& tensor : tensors) {
    size_t tensor_size = 0;
    if (tensor.is_sparse()) {
      const auto& indices = tensor._indices();
      const auto& values = tensor._values();
      tensor_size = indices.numel() * indices.element_size() +
          values.numel() * values.element_size();
    } else {
      tensor_size = tensor.numel() * tensor.element_size();
    }

    auto& type_group = groups[static_cast<int64_t>(type_id(tensor))];
    type_group.tensors.push_back(tensor);

    if (fine_grained) {
      cur_group_size += tensor_size;
      // Regardless of the type, the current total size exceeds the limit
      if (cur_group_size >= size_limit) {
        // Spill all types to separate groups in results
        for (auto& entry : groups) {
          auto& group = entry.second;
          results.emplace_back(std::move(group));
        }
        cur_group_size = 0;
        groups.clear();
      }
    } else {
      type_group.size += tensor_size;
      if (type_group.size >= size_limit) {
        results.emplace_back();
        std::swap(results.back(), type_group);
      }
    }
  }

  // End case. Look for any remaining groups and return them.
  for (auto& entry : groups) {
    auto& group = entry.second;
    if (group.tensors.empty()) {
      continue;
    }
    results.emplace_back(std::move(group));
  }
  return results;
}

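// Permutes `tensors` in place so its sequence of tensor types matches that of
// `order`, preserving the relative order of tensors that share a type. The
// caller must ensure both lists contain the same multiset of types.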
void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
  AT_ASSERT(tensors.size() == order.size());
  std::unordered_map<size_t, std::vector<size_t>> type_id_to_indices;
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
    type_id_to_indices[type_id(tensors[i])].push_back(i);

  std::unordered_map<size_t, size_t> type_id_to_type_used;
  std::vector<Tensor> ordered_tensors;
  ordered_tensors.reserve(tensors.size());
  for (auto& tmpl_tensor : order) {
    size_t tmpl_type_id = type_id(tmpl_tensor);
    auto& indices = type_id_to_indices[tmpl_type_id];
    auto& used = type_id_to_type_used[tmpl_type_id];
    ordered_tensors.push_back(tensors[indices[used++]]);
  }
  std::swap(tensors, ordered_tensors);
}

namespace {

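// File-local accessors for the components of a sparse COO tensor, so they can
// be passed as function pointers to fmap below.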
at::Tensor get_indices(const at::Tensor& t) {
  return t._indices();
}

at::Tensor get_values(const at::Tensor& t) {
  return t._values();
}

} // namespace

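// Flattens a list of sparse COO tensors into two contiguous dense buffers:
// one holding every tensor's indices and one holding every tensor's values.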
std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
    at::TensorList tensors) {
  auto flat_indices = utils::flatten_dense_tensors(fmap(tensors, &get_indices));
  auto flat_values = utils::flatten_dense_tensors(fmap(tensors, &get_values));
  return std::make_pair(flat_indices, flat_values);
}

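// Splits the flat indices/values buffers produced by flatten_sparse_tensors()
// back into one sparse COO tensor per reference tensor, reusing each
// reference's sizes and coalesced flag. A typical round trip (sketch):
//   auto [flat_i, flat_v] = flatten_sparse_tensors(tensors);
//   auto restored = unflatten_sparse_tensors(flat_i, flat_v, tensors);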
std::vector<at::Tensor> unflatten_sparse_tensors(
    const at::Tensor& flat_indices,
    const at::Tensor& flat_values,
    at::TensorList tensors) {
  if (tensors.empty())
    return {};

  auto indices =
      utils::unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices));
  auto values =
      utils::unflatten_dense_tensors(flat_values, fmap(tensors, &get_values));

  std::vector<at::Tensor> outputs;
  outputs.reserve(tensors.size());
  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) {
    auto& ref_t = tensors[i];
    auto t =
        at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes());
    outputs.emplace_back(t._coalesced_(ref_t.is_coalesced()));
  }
  return outputs;
}

} // namespace torch::utils