pytorch

Форк
0
/
ivalue.cpp 
63 строки · 2.3 Кб
1
#include <gtest/gtest.h>
2

3
#include <ATen/core/ivalue.h>
4

5
#include <c10/util/flat_hash_map.h>
6
#include <c10/util/irange.h>
7
#include <c10/util/tempfile.h>
8

9
#include <torch/torch.h>
10

11
#include <test/cpp/api/support.h>
12

13
#include <cstdio>
14
#include <memory>
15
#include <sstream>
16
#include <string>
17
#include <vector>
18

19
using namespace torch::test;
20
using namespace torch::nn;
21
using namespace torch::optim;
22

23
TEST(IValueTest, DeepcopyTensors) {
24
  torch::Tensor t0 = torch::randn({2, 3});
25
  torch::Tensor t1 = torch::randn({3, 4});
26
  torch::Tensor t2 = t0.detach();
27
  torch::Tensor t3 = t0;
28
  torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2);
29
  std::vector<torch::Tensor> tensor_vector = {t0, t1, t2, t3, t4};
30
  c10::List<torch::Tensor> tensor_list(tensor_vector);
31
  torch::IValue tensor_list_ivalue(tensor_list);
32

33
  c10::IValue::CompIdentityIValues ivalue_compare;
34

35
  // Make sure our setup configuration is correct
36
  ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get()));
37
  ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get()));
38
  ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get()));
39
  ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get()));
40
  ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get()));
41

42
  c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy();
43
  c10::List<torch::IValue> copied_list = copied_ivalue.toList();
44

45
  // Make sure our setup configuration is correct
46
  ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get()));
47
  ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get()));
48
  ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get()));
49
  ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get()));
50
  // NOTE: this is actually incorrect. Ideally, these _should_ be aliases.
51
  ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get()));
52

53
  ASSERT_TRUE(copied_list[0].get().toTensor().allclose(
54
      tensor_list[0].get().toTensor()));
55
  ASSERT_TRUE(copied_list[1].get().toTensor().allclose(
56
      tensor_list[1].get().toTensor()));
57
  ASSERT_TRUE(copied_list[2].get().toTensor().allclose(
58
      tensor_list[2].get().toTensor()));
59
  ASSERT_TRUE(copied_list[3].get().toTensor().allclose(
60
      tensor_list[3].get().toTensor()));
61
  ASSERT_TRUE(copied_list[4].get().toTensor().allclose(
62
      tensor_list[4].get().toTensor()));
63
}
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.