pytorch

Fork
0
/
tensor.cpp 
1260 lines · 43.2 KB
1
#include <gtest/gtest.h>
2
#include <test/cpp/api/support.h>
3

4
#include <c10/util/irange.h>
5
#include <torch/torch.h>
6

7
#include <cmath>
8
#include <cstddef>
9
#include <vector>
10

11
#include <test/cpp/common/support.h>
12

13
using namespace torch::test;
14

15
template <typename T>
16
bool exactly_equal(at::Tensor left, T right) {
17
  return left.item<T>() == right;
18
}
19

20
template <typename T>
21
bool almost_equal(at::Tensor left, T right, double tolerance = 1e-4) {
22
  return std::abs(left.item<T>() - right) < tolerance;
23
}
24

25
// Fatally asserts that the variable named `tensor`, which must be in scope
// where the macro is expanded, has:
//   - the device type and index of at::Device(device_, index_),
//   - dtype `type_`,
//   - layout `layout_`.
//
// ASSERT_EQ is used rather than ASSERT_TRUE(a == b) so that a failure reports
// both the expected and the actual value. The device index (an 8-bit integer)
// is widened to int so it prints as a number rather than a character.
// The do/while(0) wrapper makes the expansion a single statement that still
// requires a trailing semicolon at the call site.
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)     \
  do {                                                              \
    const at::Device expected_device_((device_), (index_));         \
    ASSERT_EQ(tensor.device().type(), expected_device_.type());     \
    ASSERT_EQ(                                                      \
        static_cast<int>(tensor.device().index()),                  \
        static_cast<int>(expected_device_.index()));                \
    ASSERT_EQ(tensor.dtype(), (type_));                             \
    ASSERT_EQ(tensor.layout(), (layout_));                          \
  } while (0)
32

33
// Tensor::to changes the dtype and keeps the CPU device and strided layout.
// This holds whether the target dtype is given as a bare ScalarType or wrapped
// in TensorOptions.
TEST(TensorTest, ToDtype) {
  auto tensor = at::empty({3, 4});
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

  // Overload taking a ScalarType directly.
  for (const auto target : {at::kInt, at::kChar, at::kDouble}) {
    tensor = tensor.to(target);
    REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, target, at::kStrided);
  }

  // Overload taking TensorOptions built from a ScalarType.
  for (const auto target : {at::kInt, at::kChar, at::kDouble}) {
    tensor = tensor.to(at::TensorOptions(target));
    REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, target, at::kStrided);
  }
}
55

56
// Tensor::to accepts another tensor, or that tensor's dtype/device/options,
// as the conversion target.
TEST(TensorTest, ToTensorAndTensorAttributes) {
  auto tensor = at::empty({3, 4});
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

  // Whole tensor as the target.
  const auto int_reference = at::empty({3, 4}, at::kInt);
  tensor = tensor.to(int_reference);
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);

  // Only the dtype, then only the device, of a double tensor.
  const auto double_reference = at::empty({3, 4}, at::kDouble);
  tensor = tensor.to(double_reference.dtype());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);
  tensor = tensor.to(double_reference.device());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);

  // Device and dtype passed together.
  const auto long_reference = at::empty({3, 4}, at::kLong);
  tensor = tensor.to(long_reference.device(), long_reference.dtype());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kLong, at::kStrided);

  // Full TensorOptions of another tensor.
  const auto options_reference = at::empty({3, 4}, at::kInt);
  tensor = tensor.to(options_reference.options());
  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);
}
78

79
// Not currently supported.
80
// TEST(TensorTest, ToLayout) {
81
//   auto tensor = at::empty({3, 4});
82
//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
83
//
84
//   tensor = tensor.to(at::kSparse);
85
//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kSparse);
86
//
87
//   tensor = tensor.to(at::kStrided);
88
//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
89
// }
90

91
// The same checks run first for a tensor created with requires_grad and then
// for one created without it.
TEST(TensorTest, ToOptionsWithRequiresGrad) {
  for (const bool requires_grad : {true, false}) {
    auto tensor = requires_grad ? torch::empty({3, 4}, at::requires_grad())
                                : torch::empty({3, 4});
    ASSERT_EQ(tensor.requires_grad(), requires_grad);

    // A dtype-only conversion leaves the requires_grad flag as it was.
    tensor = tensor.to(at::kDouble);
    ASSERT_EQ(tensor.requires_grad(), requires_grad);

    // Passing TensorOptions with requires_grad(true) to to() throws.
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    ASSERT_THROW(
        tensor.to(at::TensorOptions().requires_grad(true)), c10::Error);

    // Leaving requires_grad unset, or setting it to false, must not throw.
    tensor.to(at::TensorOptions());
    tensor.to(at::TensorOptions().requires_grad(false));
  }
}
127

128
// Each conversion below targets options the tensor already has. The result
// must point at the same data as the input, i.e. no copy is made.
TEST(TensorTest, ToDoesNotCopyWhenOptionsAreAllTheSame) {
  using Conversion = at::Tensor (*)(const at::Tensor&);
  const Conversion conversions[] = {
      [](const at::Tensor& t) { return t.to(at::kFloat); },
      [](const at::Tensor& t) { return t.to(t.options()); },
      [](const at::Tensor& t) { return t.to(t.dtype()); },
      [](const at::Tensor& t) { return t.to(t.device()); },
      [](const at::Tensor& t) { return t.to(t); },
  };

  for (const Conversion convert : conversions) {
    auto source = at::empty({3, 4}, at::kFloat);
    auto converted = convert(source);
    ASSERT_EQ(converted.data_ptr<float>(), source.data_ptr<float>());
  }
}
155

156
// at::tensor(scalar) yields a one-element tensor. Its dtype follows the C++
// type of the scalar unless options specify otherwise.
TEST(TensorTest, AtTensorCtorScalar) {
  {
    const auto t = at::tensor(123);
    ASSERT_EQ(t.numel(), 1);
    ASSERT_EQ(t.dtype(), at::kInt);
    ASSERT_EQ(t[0].item<int32_t>(), 123);
  }
  {
    const auto t = at::tensor(123.456f);
    ASSERT_EQ(t.numel(), 1);
    ASSERT_EQ(t.dtype(), at::kFloat);
    ASSERT_TRUE(almost_equal(t[0], 123.456f));
  }
  {
    const auto t = at::tensor(123.456);
    ASSERT_EQ(t.numel(), 1);
    ASSERT_EQ(t.dtype(), at::kDouble);
    ASSERT_TRUE(almost_equal(t[0], 123.456));
  }
  {
    // An explicit dtype replaces the one inferred from the int literal.
    const auto t = at::tensor(123, at::dtype(at::kFloat)) + 0.5;
    ASSERT_EQ(t.numel(), 1);
    ASSERT_EQ(t.dtype(), at::kFloat);
    ASSERT_TRUE(almost_equal(t[0], 123.5));
  }
  {
    const auto t = at::tensor(c10::complex<float>(1.0, 2.0)) + 0.5;
    ASSERT_EQ(t.numel(), 1);
    ASSERT_EQ(t.dtype(), at::kComplexFloat);
    ASSERT_TRUE(almost_equal(t[0], c10::complex<float>(1.5, 2.0)));
  }
  {
    const auto t =
        at::tensor(c10::complex<float>(1.0, 2.0), at::dtype(at::kComplexFloat)) +
        0.5;
    ASSERT_EQ(t.numel(), 1);
    ASSERT_EQ(t.dtype(), at::kComplexFloat);
    ASSERT_TRUE(almost_equal(t[0], c10::complex<float>(1.5, 2.0)));
  }
  {
    const auto t = at::tensor(c10::complex<double>(1.0, 2.0)) + 0.5;
    ASSERT_EQ(t.numel(), 1);
    ASSERT_EQ(t.dtype(), at::kComplexDouble);
    ASSERT_TRUE(almost_equal(t[0], c10::complex<double>(1.5, 2.0)));
  }
  {
    // A complex<float> value stored with a kComplexDouble dtype.
    const auto t =
        at::tensor(c10::complex<float>(1.0, 2.0), at::dtype(at::kComplexDouble)) +
        0.5;
    ASSERT_EQ(t.numel(), 1);
    ASSERT_EQ(t.dtype(), at::kComplexDouble);
    ASSERT_TRUE(almost_equal(t[0], c10::complex<double>(1.5, 2.0)));
  }
}
201

202
// at::tensor built from braced lists and std::vectors of int, double,
// c10::complex<float> and c10::complex<double>. Checks the element count, the
// inferred (or explicitly requested) dtype, and every element's value.
//
// std::vector::size() returns size_t while Tensor::numel() returns int64_t.
// The size is cast to int64_t before ASSERT_EQ so the comparison is not a
// signed/unsigned mix.
TEST(TensorTest, AtTensorCtorSingleDim) {
  auto tensor = at::tensor({1, 2, 3});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = at::tensor(std::vector<int>({1, 2, 3}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = at::tensor({1.5, 2.25, 3.125});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = at::tensor(
      {c10::complex<float>(1.5, 0.15),
       c10::complex<float>(1.5, 0.15),
       c10::complex<float>(3.125, 0.3125)});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<float>(3.125, 0.3125)));

  tensor = at::tensor(
      {c10::complex<double>(1.5, 0.15),
       c10::complex<double>(1.5, 0.15),
       c10::complex<double>(3.125, 0.3125)});
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.125, 0.3125)));

  // Doubles stored into an explicitly requested kInt tensor are truncated.
  tensor = at::tensor({1.1, 2.2, 3.3}, at::dtype(at::kInt));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kInt);
  ASSERT_EQ(tensor.layout(), at::kStrided);
  ASSERT_TRUE(exactly_equal(tensor[0], 1));
  ASSERT_TRUE(exactly_equal(tensor[1], 2));
  ASSERT_TRUE(exactly_equal(tensor[2], 3));

  tensor = at::tensor(std::vector<double>({1.5, 2.25, 3.125}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));

  tensor = at::tensor(std::vector<c10::complex<float>>(
      {c10::complex<float>(1.5, 0.15),
       c10::complex<float>(1.5, 0.15),
       c10::complex<float>(3.125, 0.3125)}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexFloat);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<float>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<float>(3.125, 0.3125)));

  tensor = at::tensor(std::vector<c10::complex<double>>(
      {c10::complex<double>(1.5, 0.15),
       c10::complex<double>(1.5, 0.15),
       c10::complex<double>(3.125, 0.3125)}));
  ASSERT_EQ(tensor.numel(), 3);
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  ASSERT_TRUE(almost_equal(tensor[0], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[1], c10::complex<double>(1.5, 0.15)));
  ASSERT_TRUE(almost_equal(tensor[2], c10::complex<double>(3.125, 0.3125)));

  const std::vector<int> v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  tensor = at::tensor(v);
  ASSERT_EQ(tensor.numel(), static_cast<int64_t>(v.size()));
  ASSERT_EQ(tensor.dtype(), at::kInt);
  for (const auto i : c10::irange(v.size())) {
    ASSERT_TRUE(exactly_equal(tensor[i], v.at(i)));
  }

  const std::vector<double> w = {
      1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0};
  tensor = at::tensor(w);
  ASSERT_EQ(tensor.numel(), static_cast<int64_t>(w.size()));
  ASSERT_EQ(tensor.dtype(), at::kDouble);
  for (const auto i : c10::irange(w.size())) {
    ASSERT_TRUE(almost_equal(tensor[i], w.at(i)));
  }

  const std::vector<c10::complex<double>> x = {
      {1.1, -1.1},
      {2.2, -2.2},
      {3.3, -3.3},
      {4.4, -4.4},
      {5.5, -5.5},
      {6.6, -6.6},
      {7.7, -7.7},
      {8.8, -8.8},
      {9.9, -9.9},
      {10.0, -10.0}};
  tensor = at::tensor(x);
  ASSERT_EQ(tensor.numel(), static_cast<int64_t>(x.size()));
  ASSERT_EQ(tensor.dtype(), at::kComplexDouble);
  for (const auto i : c10::irange(x.size())) {
    ASSERT_TRUE(almost_equal(tensor[i], x.at(i)));
  }
}
313

314
// Real-valued input combined with a complex dtype yields the corresponding
// complex values. Covered for a std::vector, a braced list and a bare scalar.
TEST(TensorTest, AtTensorCastRealToComplex) {
  {
    const auto from_vector =
        at::tensor(std::vector<double>({1.5, 2.5, 3.5}), at::kComplexDouble);
    ASSERT_EQ(from_vector.numel(), 3);
    ASSERT_EQ(from_vector.dtype(), at::kComplexDouble);
    ASSERT_TRUE(almost_equal(from_vector[0], c10::complex<double>(1.5)));
    ASSERT_TRUE(almost_equal(from_vector[1], c10::complex<double>(2.5)));
    ASSERT_TRUE(almost_equal(from_vector[2], c10::complex<double>(3.5)));
  }
  {
    const auto from_list = at::tensor({1.5, 2.5, 3.5}, at::kComplexDouble);
    ASSERT_EQ(from_list.numel(), 3);
    ASSERT_EQ(from_list.dtype(), at::kComplexDouble);
    ASSERT_TRUE(almost_equal(from_list[0], c10::complex<double>(1.5)));
    ASSERT_TRUE(almost_equal(from_list[1], c10::complex<double>(2.5)));
    ASSERT_TRUE(almost_equal(from_list[2], c10::complex<double>(3.5)));
  }
  {
    const auto from_scalar = at::tensor(1.5, at::kComplexDouble);
    ASSERT_EQ(from_scalar.numel(), 1);
    ASSERT_EQ(from_scalar.dtype(), at::kComplexDouble);
    ASSERT_TRUE(almost_equal(from_scalar[0], c10::complex<double>(1.5)));
  }
}
335

336
// Asking at::tensor for a real dtype (kFloat) with complex input fails with
// the same message for the scalar, braced-list and std::vector overloads.
TEST(TensorTest, AtTensorCastComplexToRealErrorChecks) {
  const char* const expected_message =
      "\"tensor_cpu\" not implemented for 'Float'";
  {
    ASSERT_THROWS_WITH(
        at::tensor(c10::complex<float>(0.1, 0.2), at::kFloat),
        expected_message);
  }
  {
    ASSERT_THROWS_WITH(
        at::tensor({c10::complex<float>(0.1, 0.2)}, at::kFloat),
        expected_message);
  }
  {
    ASSERT_THROWS_WITH(
        at::tensor(
            std::vector<c10::complex<float>>{c10::complex<float>(0.1, 0.2)},
            at::kFloat),
        expected_message);
  }
}
355

356
// torch::tensor on a bare int literal yields a 0-dim kLong tensor.
TEST(TensorTest, TorchTensorCtorScalarIntegralType) {
  const auto scalar = torch::tensor(123);
  ASSERT_EQ(scalar.numel(), 1);
  ASSERT_EQ(scalar.sizes(), std::vector<int64_t>({}));
  ASSERT_EQ(scalar.dtype(), at::kLong);
  ASSERT_EQ(scalar.item<int64_t>(), 123);
}
363

364
void test_TorchTensorCtorScalarFloatingType_expected_dtype(
365
    c10::ScalarType default_dtype) {
366
  AutoDefaultDtypeMode dtype_mode(default_dtype);
367

368
  auto tensor = torch::tensor(123.456f);
369
  ASSERT_EQ(tensor.numel(), 1);
370
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
371
  ASSERT_EQ(tensor.dtype(), default_dtype);
372
  ASSERT_TRUE(almost_equal(tensor, 123.456f));
373

374
  tensor = torch::tensor(123.456);
375
  ASSERT_EQ(tensor.numel(), 1);
376
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({}));
377
  ASSERT_EQ(tensor.dtype(), default_dtype);
378
  ASSERT_TRUE(almost_equal(tensor, 123.456));
379

380
  tensor = torch::tensor({123.456});
381
  ASSERT_EQ(tensor.numel(), 1);
382
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1}));
383
  ASSERT_EQ(tensor.dtype(), default_dtype);
384
  ASSERT_TRUE(almost_equal(tensor[0], 123.456));
385
}
386

387
// Runs the floating-point scalar checks under each supported default dtype.
TEST(TensorTest, TorchTensorCtorScalarFloatingType) {
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorScalarFloatingType_expected_dtype(default_dtype);
  }
}
393

394
// A bare bool gives a 0-dim kBool tensor. A one-element braced list of bool
// gives a 1-D kBool tensor of size 1.
TEST(TensorTest, TorchTensorCtorScalarBoolType) {
  const auto scalar = torch::tensor(true);
  ASSERT_EQ(scalar.numel(), 1);
  ASSERT_EQ(scalar.sizes(), std::vector<int64_t>({}));
  ASSERT_EQ(scalar.dtype(), at::kBool);
  ASSERT_TRUE(exactly_equal(scalar, true));

  const auto list = torch::tensor({true});
  ASSERT_EQ(list.numel(), 1);
  ASSERT_EQ(list.sizes(), std::vector<int64_t>({1}));
  ASSERT_EQ(list.dtype(), at::kBool);
  ASSERT_TRUE(exactly_equal(list[0], true));
}
407

408
// Every way of handing the integers {1, 2, 3} to torch::tensor must produce
// a 1-D kLong tensor holding those values. The ways covered are a braced list,
// at::ArrayRef and std::vector, each of int or int64_t. Each factory runs
// right before its own assertions.
TEST(TensorTest, TorchTensorCtorSingleDimIntegralType) {
  using TensorFactory = at::Tensor (*)();
  const TensorFactory factories[] = {
      [] { return torch::tensor({1, 2, 3}); },
      [] { return torch::tensor(at::ArrayRef<int>({1, 2, 3})); },
      [] { return torch::tensor(std::vector<int>({1, 2, 3})); },
      [] { return torch::tensor(at::ArrayRef<int64_t>({1, 2, 3})); },
      [] { return torch::tensor(std::vector<int64_t>({1, 2, 3})); },
  };

  for (const TensorFactory make_tensor : factories) {
    const auto tensor = make_tensor();
    ASSERT_EQ(tensor.numel(), 3);
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
    ASSERT_EQ(tensor.dtype(), at::kLong);
    ASSERT_TRUE(exactly_equal(tensor[0], 1));
    ASSERT_TRUE(exactly_equal(tensor[1], 2));
    ASSERT_TRUE(exactly_equal(tensor[2], 3));
  }
}
449

450
void test_TorchTensorCtorSingleDimFloatingType_expected_dtype(
451
    c10::ScalarType default_dtype) {
452
  AutoDefaultDtypeMode dtype_mode(default_dtype);
453

454
  auto tensor = torch::tensor({1.5, 2.25, 3.125});
455
  ASSERT_EQ(tensor.numel(), 3);
456
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
457
  ASSERT_EQ(tensor.dtype(), default_dtype);
458
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
459
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
460
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));
461

462
  tensor = torch::tensor({1.5f, 2.25f, 3.125f});
463
  ASSERT_EQ(tensor.numel(), 3);
464
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
465
  ASSERT_EQ(tensor.dtype(), default_dtype);
466
  ASSERT_TRUE(almost_equal(tensor[0], 1.5f));
467
  ASSERT_TRUE(almost_equal(tensor[1], 2.25f));
468
  ASSERT_TRUE(almost_equal(tensor[2], 3.125f));
469

470
  tensor = torch::tensor(at::ArrayRef<float>({1.5f, 2.25f, 3.125f}));
471
  ASSERT_EQ(tensor.numel(), 3);
472
  ASSERT_EQ(tensor.dtype(), default_dtype);
473
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
474
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
475
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));
476

477
  tensor = torch::tensor(std::vector<float>({1.5f, 2.25f, 3.125f}));
478
  ASSERT_EQ(tensor.numel(), 3);
479
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
480
  ASSERT_EQ(tensor.dtype(), default_dtype);
481
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
482
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
483
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));
484

485
  tensor = torch::tensor(at::ArrayRef<double>({1.5, 2.25, 3.125}));
486
  ASSERT_EQ(tensor.numel(), 3);
487
  ASSERT_EQ(tensor.dtype(), default_dtype);
488
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
489
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
490
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));
491

492
  tensor = torch::tensor(std::vector<double>({1.5, 2.25, 3.125}));
493
  ASSERT_EQ(tensor.numel(), 3);
494
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({3}));
495
  ASSERT_EQ(tensor.dtype(), default_dtype);
496
  ASSERT_TRUE(almost_equal(tensor[0], 1.5));
497
  ASSERT_TRUE(almost_equal(tensor[1], 2.25));
498
  ASSERT_TRUE(almost_equal(tensor[2], 3.125));
499
}
500

501
// Runs the 1-D floating-point checks under each supported default dtype.
TEST(TensorTest, TorchTensorCtorSingleDimFloatingType) {
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorSingleDimFloatingType_expected_dtype(default_dtype);
  }
}
507

508
// A braced list and an at::ArrayRef of bools must both give the 1-D kBool
// tensor {true, false, true}.
TEST(TensorTest, TorchTensorCtorSingleDimBoolType) {
  using TensorFactory = at::Tensor (*)();
  const TensorFactory factories[] = {
      [] { return torch::tensor({true, false, true}); },
      [] { return torch::tensor(at::ArrayRef<bool>({true, false, true})); },
  };

  for (const TensorFactory make_tensor : factories) {
    const auto flags = make_tensor();
    ASSERT_EQ(flags.numel(), 3);
    ASSERT_EQ(flags.sizes(), std::vector<int64_t>({3}));
    ASSERT_EQ(flags.dtype(), at::kBool);
    ASSERT_TRUE(exactly_equal(flags[0], true));
    ASSERT_TRUE(exactly_equal(flags[1], false));
    ASSERT_TRUE(exactly_equal(flags[2], true));
  }
}
525

526
// Nested braced lists of ints give kLong tensors. The shape mirrors the
// nesting and the values equal an arange laid out in that shape. None of the
// results requires grad.
TEST(TensorTest, TorchTensorCtorMultiDimIntegralType) {
  {
    const auto t = torch::tensor({{1, 2}});
    ASSERT_EQ(t.dtype(), torch::kLong);
    ASSERT_EQ(t.sizes(), std::vector<int64_t>({1, 2}));
    const auto expected = torch::arange(1, 3, torch::kLong).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_FALSE(t.requires_grad());
  }
  {
    const auto t = torch::tensor({{1}, {2}});
    ASSERT_EQ(t.dtype(), torch::kLong);
    ASSERT_EQ(t.sizes(), std::vector<int64_t>({2, 1}));
    const auto expected = torch::arange(1, 3, torch::kLong).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_FALSE(t.requires_grad());
  }
  {
    const auto t = torch::tensor({{{1, 2}}});
    ASSERT_EQ(t.dtype(), torch::kLong);
    ASSERT_EQ(t.sizes(), std::vector<int64_t>({1, 1, 2}));
    const auto expected = torch::arange(1, 3, torch::kLong).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_FALSE(t.requires_grad());
  }
  {
    const auto t = torch::tensor({{{1}, {2}}});
    ASSERT_EQ(t.dtype(), torch::kLong);
    ASSERT_EQ(t.sizes(), std::vector<int64_t>({1, 2, 1}));
    const auto expected = torch::arange(1, 3, torch::kLong).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_FALSE(t.requires_grad());
  }
  {
    const auto t = torch::tensor({{1, 2}, {3, 4}});
    ASSERT_EQ(t.dtype(), torch::kLong);
    ASSERT_EQ(t.sizes(), std::vector<int64_t>({2, 2}));
    const auto expected = torch::arange(1, 5, torch::kLong).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_FALSE(t.requires_grad());
  }
  {
    // Ten levels of nesting around a single value.
    const auto t = torch::tensor({{{{{{{{{{1}}}}}}}}}});
    ASSERT_EQ(t.dtype(), torch::kLong);
    ASSERT_EQ(
        t.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 1}));
    const auto expected = torch::full({1}, 1, torch::kLong).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_FALSE(t.requires_grad());
  }
  {
    // Ten levels of nesting around two values.
    const auto t = torch::tensor({{{{{{{{{{1, 2}}}}}}}}}});
    ASSERT_EQ(t.dtype(), torch::kLong);
    ASSERT_EQ(
        t.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 2}));
    const auto expected = torch::arange(1, 3, torch::kLong).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_FALSE(t.requires_grad());
  }
}
586

587
void test_TorchTensorCtorMultiDimFloatingType_expected_dtype(
588
    c10::ScalarType default_dtype) {
589
  AutoDefaultDtypeMode dtype_mode(default_dtype);
590
  {
591
    auto tensor = torch::tensor({{1.0, 2.0}});
592
    ASSERT_EQ(tensor.dtype(), default_dtype);
593
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2}));
594
    ASSERT_TRUE(torch::allclose(
595
        tensor, torch::arange(1, 3, default_dtype).view(tensor.sizes())));
596
    ASSERT_FALSE(tensor.requires_grad());
597
  }
598
  {
599
    auto tensor = torch::tensor(
600
        {{{{{{{{1.0, 2.0, 3.0}}}}},
601
           {{{{{4.0, 5.0, 6.0}}}}},
602
           {{{{{7.0, 8.0, 9.0}}}}}}}});
603
    ASSERT_EQ(tensor.dtype(), default_dtype);
604
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
605
    ASSERT_TRUE(torch::allclose(
606
        tensor, torch::arange(1, 10, default_dtype).view(tensor.sizes())));
607
    ASSERT_FALSE(tensor.requires_grad());
608
  }
609
}
610

611
// Runs the multi-dimensional floating-point checks under each supported
// default dtype.
TEST(TensorTest, TorchTensorCtorMultiDimFloatingType) {
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorMultiDimFloatingType_expected_dtype(default_dtype);
  }
}
617

618
// Nested braced lists of bools give kBool tensors that match an expected
// tensor filled element by element.
TEST(TensorTest, TorchTensorCtorMultiDimBoolType) {
  {
    // One row, two columns.
    const auto row = torch::tensor({{true, false}});
    ASSERT_EQ(row.dtype(), torch::kBool);
    ASSERT_EQ(row.sizes(), std::vector<int64_t>({1, 2}));
    auto want = torch::empty(row.sizes(), torch::kBool);
    want[0][0] = true;
    want[0][1] = false;
    ASSERT_TRUE(torch::equal(row, want));
    ASSERT_FALSE(row.requires_grad());
  }
  {
    // Two rows, one column.
    const auto column = torch::tensor({{true}, {false}});
    ASSERT_EQ(column.dtype(), torch::kBool);
    ASSERT_EQ(column.sizes(), std::vector<int64_t>({2, 1}));
    auto want = torch::empty(column.sizes(), torch::kBool);
    want[0][0] = true;
    want[1][0] = false;
    ASSERT_TRUE(torch::equal(column, want));
    ASSERT_FALSE(column.requires_grad());
  }
}
640

641
// Options passed alongside nested lists control the dtype and requires_grad
// of the result.
TEST(TensorTest, TorchTensorCtorMultiDimWithOptions) {
  {
    // An explicit kInt dtype for int input.
    const auto t = torch::tensor({{1, 2}}, torch::dtype(torch::kInt));
    ASSERT_EQ(t.dtype(), torch::kInt);
    ASSERT_EQ(t.sizes(), std::vector<int64_t>({1, 2}));
    const auto expected = torch::arange(1, 3, torch::kInt).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_FALSE(t.requires_grad());
  }
  {
    // dtype and requires_grad combined in a single options object.
    const auto options = torch::dtype(torch::kFloat).requires_grad(true);
    const auto t = torch::tensor({{1, 2}, {3, 4}}, options);
    ASSERT_EQ(t.dtype(), torch::kFloat);
    ASSERT_EQ(t.sizes(), std::vector<int64_t>({2, 2}));
    const auto expected = torch::arange(1, 5, torch::kFloat).view(t.sizes());
    ASSERT_TRUE(torch::allclose(t, expected));
    ASSERT_TRUE(t.requires_grad());
  }
}
660

661
// torch::tensor rejects nested lists whose sub-lists have inconsistent sizes
// or whose elements mix scalar types.
TEST(TensorTest, TorchTensorCtorMultiDimErrorChecks) {
  const char* const ragged_sublist_error =
      "Expected all sub-lists to have sizes: 2 (e.g. {5, 6}), but got sub-list {7} with sizes: 1";
  const char* const int_then_double_error =
      "Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double";
  const char* const bool_then_double_error =
      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double";
  const char* const bool_then_int_error =
      "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Int";

  {
    ASSERT_THROWS_WITH(
        torch::tensor({{{2, 3, 4}, {{5, 6}, {7}}}}), ragged_sublist_error);
  }
  {
    ASSERT_THROWS_WITH(
        torch::tensor({{{1, 2.0}, {1, 2.0}}}), int_then_double_error);
  }
  {
    ASSERT_THROWS_WITH(
        torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}),
        bool_then_double_error);
  }
  {
    ASSERT_THROWS_WITH(torch::tensor({{{true}, {2}}}), bool_then_int_error);
  }
  {
    ASSERT_THROWS_WITH(torch::tensor({{{true, 2}}}), bool_then_int_error);
  }
}
688

689
// Requesting kComplexDouble for real input yields the corresponding complex
// values. Covered for a std::vector, a braced list and a bare scalar.
TEST(TensorTest, TorchTensorCastRealToComplex) {
  {
    const auto from_vector = torch::tensor(
        std::vector<double>({1.5, 2.5, 3.5}), torch::kComplexDouble);
    ASSERT_EQ(from_vector.numel(), 3);
    ASSERT_EQ(from_vector.dtype(), torch::kComplexDouble);
    ASSERT_TRUE(almost_equal(from_vector[0], c10::complex<double>(1.5)));
    ASSERT_TRUE(almost_equal(from_vector[1], c10::complex<double>(2.5)));
    ASSERT_TRUE(almost_equal(from_vector[2], c10::complex<double>(3.5)));
  }
  {
    const auto from_list =
        torch::tensor({1.5, 2.5, 3.5}, torch::kComplexDouble);
    ASSERT_EQ(from_list.numel(), 3);
    ASSERT_EQ(from_list.dtype(), torch::kComplexDouble);
    ASSERT_TRUE(almost_equal(from_list[0], c10::complex<double>(1.5)));
    ASSERT_TRUE(almost_equal(from_list[1], c10::complex<double>(2.5)));
    ASSERT_TRUE(almost_equal(from_list[2], c10::complex<double>(3.5)));
  }
  {
    const auto from_scalar = torch::tensor(1.5, torch::kComplexDouble);
    ASSERT_EQ(from_scalar.numel(), 1);
    ASSERT_EQ(from_scalar.dtype(), torch::kComplexDouble);
    ASSERT_TRUE(almost_equal(from_scalar, c10::complex<double>(1.5)));
  }
}
710

711
// torch::tensor refuses complex input when a real dtype (kFloat) is
// requested. The scalar and braced-list paths report a conversion overflow;
// the std::vector path reports a dedicated message.
TEST(TensorTest, TorchTensorCastComplexToRealErrorChecks) {
  const char* const overflow_message =
      "value cannot be converted to type float without overflow";
  const char* const vector_message =
      "can not do torch::tensor(complex, dtype=non-complex) because complex can not be casted to real number without loss of information";

  {
    ASSERT_THROWS_WITH(
        torch::tensor(c10::complex<float>(0.1, 0.2), torch::kFloat),
        overflow_message);
  }
  {
    ASSERT_THROWS_WITH(
        torch::tensor(
            {c10::complex<float>(0.1, 0.2), c10::complex<float>(0.3, 0.4)},
            torch::kFloat),
        overflow_message);
  }
  {
    ASSERT_THROWS_WITH(
        torch::tensor(
            std::vector<c10::complex<float>>{
                c10::complex<float>(0.1, 0.2), c10::complex<float>(0.3, 0.4)},
            torch::kFloat),
        vector_message);
  }
}
733

734
void test_TorchTensorCtorMultiDim_CUDA_expected_dtype(
735
    c10::ScalarType default_dtype) {
736
  AutoDefaultDtypeMode dtype_mode(default_dtype);
737

738
  auto tensor = torch::tensor(
739
      {{{{{{{{1.0, 2.0, 3.0}}}}},
740
         {{{{{4.0, 5.0, 6.0}}}}},
741
         {{{{{7.0, 8.0, 9.0}}}}}}}},
742
      torch::dtype(default_dtype).device(torch::kCUDA));
743
  ASSERT_TRUE(tensor.device().is_cuda());
744
  ASSERT_EQ(tensor.dtype(), default_dtype);
745
  ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 3, 1, 1, 1, 1, 3}));
746
  ASSERT_TRUE(torch::allclose(
747
      tensor,
748
      torch::arange(1, 10, default_dtype)
749
          .view(tensor.sizes())
750
          .to(torch::kCUDA)));
751
  ASSERT_FALSE(tensor.requires_grad());
752
}
753

754
// Runs the CUDA multi-dimensional construction checks under each supported
// default dtype.
TEST(TensorTest, TorchTensorCtorMultiDim_CUDA) {
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    test_TorchTensorCtorMultiDim_CUDA_expected_dtype(default_dtype);
  }
}
760

761
void test_TorchTensorCtorZeroSizedDim_expected_dtype(
762
    c10::ScalarType default_dtype) {
763
  AutoDefaultDtypeMode dtype_mode(default_dtype);
764
  {
765
    auto tensor = torch::tensor({});
766
    ASSERT_EQ(tensor.numel(), 0);
767
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({0}));
768
    ASSERT_EQ(tensor.dtype(), default_dtype);
769
    ASSERT_FALSE(tensor.requires_grad());
770
  }
771
  {
772
    auto tensor = torch::tensor({{}, {}});
773
    ASSERT_EQ(tensor.numel(), 0);
774
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({2, 0}));
775
    ASSERT_EQ(tensor.dtype(), default_dtype);
776
    ASSERT_FALSE(tensor.requires_grad());
777
  }
778
  {
779
    auto tensor = torch::tensor({{{}, {}}});
780
    ASSERT_EQ(tensor.numel(), 0);
781
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 2, 0}));
782
    ASSERT_EQ(tensor.dtype(), default_dtype);
783
    ASSERT_FALSE(tensor.requires_grad());
784
  }
785
  {
786
    auto tensor = torch::tensor({{{}}});
787
    ASSERT_EQ(tensor.numel(), 0);
788
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 0}));
789
    ASSERT_EQ(tensor.dtype(), default_dtype);
790
    ASSERT_FALSE(tensor.requires_grad());
791
  }
792
  {
793
    auto tensor = torch::tensor({{{{{{{{}}}}}}}});
794
    ASSERT_EQ(tensor.numel(), 0);
795
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 0}));
796
    ASSERT_EQ(tensor.dtype(), default_dtype);
797
    ASSERT_FALSE(tensor.requires_grad());
798
  }
799
  {
800
    auto tensor = torch::tensor({{{{{{{{}}}}, {{{{}}}}}}}});
801
    ASSERT_EQ(tensor.numel(), 0);
802
    ASSERT_EQ(tensor.sizes(), std::vector<int64_t>({1, 1, 1, 2, 1, 1, 1, 0}));
803
    ASSERT_EQ(tensor.dtype(), default_dtype);
804
    ASSERT_FALSE(tensor.requires_grad());
805
  }
806
  {
807
    auto tensor = torch::tensor({{{{{{{{{{}}}}}}}}}});
808
    ASSERT_EQ(tensor.numel(), 0);
809
    ASSERT_EQ(
810
        tensor.sizes(), std::vector<int64_t>({1, 1, 1, 1, 1, 1, 1, 1, 1, 0}));
811
    ASSERT_EQ(tensor.dtype(), default_dtype);
812
    ASSERT_FALSE(tensor.requires_grad());
813
  }
814
}
815

816
// Runs the zero-sized-dimension constructor checks under each default dtype.
// SCOPED_TRACE labels failures (e.g. a numel/shape mismatch, whose message
// does not mention the dtype) with the default dtype that was active.
TEST(TensorTest, TorchTensorCtorZeroSizedDim) {
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    SCOPED_TRACE(c10::toString(default_dtype));
    test_TorchTensorCtorZeroSizedDim_expected_dtype(default_dtype);
  }
}

void test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(
824
    c10::ScalarType default_dtype) {
825
  AutoDefaultDtypeMode dtype_mode(default_dtype);
826

827
  ASSERT_EQ(torch::tensor({1., 2., 3.}).dtype(), default_dtype);
828
  ASSERT_EQ(torch::tensor({{1., 2., 3.}}).dtype(), default_dtype);
829
  ASSERT_EQ(
830
      torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(),
831
      default_dtype);
832
  ASSERT_EQ(
833
      torch::tensor({{1., 2., 3.}}, torch::TensorOptions()).dtype(),
834
      default_dtype);
835
}
836

837
// Integral literals with no dtype produce kLong tensors. The floating-point
// cases are then checked under each default dtype, with SCOPED_TRACE naming
// the active dtype if an assertion inside the helper fails.
TEST(TensorTest, TorchTensorCtorWithoutSpecifyingDtype) {
  ASSERT_EQ(torch::tensor({1, 2, 3}).dtype(), torch::kLong);
  ASSERT_EQ(torch::tensor({{1, 2, 3}}).dtype(), torch::kLong);
  ASSERT_EQ(
      torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong);
  ASSERT_EQ(
      torch::tensor({{1, 2, 3}}, torch::TensorOptions()).dtype(), torch::kLong);

  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    SCOPED_TRACE(c10::toString(default_dtype));
    test_TorchTensorCtorWithoutSpecifyingDtype_expected_dtype(default_dtype);
  }
}

void test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(
852
    c10::ScalarType default_dtype) {
853
  AutoDefaultDtypeMode dtype_mode(default_dtype);
854

855
  ASSERT_EQ(
856
      torch::tensor({1, 2, 3}, torch::TensorOptions()).dtype(), torch::kLong);
857
  ASSERT_EQ(
858
      torch::tensor(at::ArrayRef<int>({1, 2, 3}), torch::TensorOptions())
859
          .dtype(),
860
      torch::kLong);
861
  ASSERT_EQ(
862
      torch::tensor(std::vector<int>({1, 2, 3}), torch::TensorOptions())
863
          .dtype(),
864
      torch::kLong);
865

866
  ASSERT_EQ(
867
      torch::tensor({1., 2., 3.}, torch::TensorOptions()).dtype(),
868
      default_dtype);
869
  ASSERT_EQ(
870
      torch::tensor(at::ArrayRef<double>({1., 2., 3.}), torch::TensorOptions())
871
          .dtype(),
872
      default_dtype);
873
  ASSERT_EQ(
874
      torch::tensor(std::vector<double>({1., 2., 3.}), torch::TensorOptions())
875
          .dtype(),
876
      default_dtype);
877

878
  ASSERT_EQ(
879
      torch::tensor({1.f, 2.f, 3.f}, torch::TensorOptions()).dtype(),
880
      default_dtype);
881
  ASSERT_EQ(
882
      torch::tensor(
883
          at::ArrayRef<float>({1.f, 2.f, 3.f}), torch::TensorOptions())
884
          .dtype(),
885
      default_dtype);
886
  ASSERT_EQ(
887
      torch::tensor(std::vector<float>({1.f, 2.f, 3.f}), torch::TensorOptions())
888
          .dtype(),
889
      default_dtype);
890
}
891

892
// Runs the non-dtype-options constructor checks under each default dtype, with
// SCOPED_TRACE recording which default dtype was active for any failure.
TEST(TensorTest, TorchTensorCtorWithNonDtypeOptions) {
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    SCOPED_TRACE(c10::toString(default_dtype));
    test_TorchTensorCtorWithNonDtypeOptions_expected_dtype(default_dtype);
  }
}

// arange() with a floating-point start value must produce the default dtype.
void test_Arange_expected_dtype(c10::ScalarType default_dtype) {
  AutoDefaultDtypeMode dtype_mode(default_dtype);

  const auto range = torch::arange(0., 5);
  ASSERT_EQ(range.dtype(), default_dtype);
}

// Integer arange() yields kLong. Floating-point arange() follows the default
// dtype and is checked under each default, with SCOPED_TRACE naming the active
// dtype on failure.
TEST(TensorTest, Arange) {
  {
    auto x = torch::arange(0, 5);
    ASSERT_EQ(x.dtype(), torch::kLong);
  }
  for (const auto default_dtype : {torch::kFloat, torch::kDouble}) {
    SCOPED_TRACE(c10::toString(default_dtype));
    test_Arange_expected_dtype(default_dtype);
  }
}

// c10::str() on a TensorDataContainer must reproduce the brace structure of its
// input. Covered inputs: a scalar, flat and deeply nested initializer lists
// (including an empty innermost list), an at::ArrayRef and a std::vector.
TEST(TensorTest, PrettyPrintTensorDataContainer) {
  using torch::detail::TensorDataContainer;

  // Scalar.
  ASSERT_EQ(c10::str(TensorDataContainer(1.1)), "1.1");

  // Flat and shallowly nested initializer lists.
  ASSERT_EQ(c10::str(TensorDataContainer({1.1, 2.2})), "{1.1, 2.2}");
  ASSERT_EQ(
      c10::str(TensorDataContainer({{1, 2}, {3, 4}})), "{{1, 2}, {3, 4}}");

  // Deeply nested initializer lists.
  ASSERT_EQ(
      c10::str(TensorDataContainer(
          {{{{{{{{1.1, 2.2, 3.3}}}}},
             {{{{{4.4, 5.5, 6.6}}}}},
             {{{{{7.7, 8.8, 9.9}}}}}}}})),
      "{{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}}");
  ASSERT_EQ(
      c10::str(TensorDataContainer({{{{{{{{{{1}}}}}}}}}})),
      "{{{{{{{{{{1}}}}}}}}}}");
  ASSERT_EQ(
      c10::str(TensorDataContainer({{{{{{{{{{}}}}}}}}}})),
      "{{{{{{{{{{}}}}}}}}}}");
  ASSERT_EQ(
      c10::str(TensorDataContainer({{{{{{{{{{1, 2}}}}}}}}}})),
      "{{{{{{{{{{1, 2}}}}}}}}}}");

  // Contiguous containers print like a flat list.
  ASSERT_EQ(
      c10::str(TensorDataContainer(at::ArrayRef<double>({1.1, 2.2}))),
      "{1.1, 2.2}");
  ASSERT_EQ(
      c10::str(TensorDataContainer(std::vector<double>({1.1, 2.2}))),
      "{1.1, 2.2}");
}

// Each TensorDataContainer accessor (scalar(), init_list(), tensor()) must throw
// a descriptive error when the container holds a different kind of data.
TEST(TensorTest, TensorDataContainerCallingAccessorOfWrongType) {
  using torch::detail::TensorDataContainer;

  // Container built from a scalar.
  ASSERT_THROWS_WITH(
      TensorDataContainer(1.1).init_list(),
      "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
  ASSERT_THROWS_WITH(
      TensorDataContainer(1.1).tensor(),
      "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");

  // Container built from an initializer list.
  ASSERT_THROWS_WITH(
      TensorDataContainer({1.1, 2.2}).scalar(),
      "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
  ASSERT_THROWS_WITH(
      TensorDataContainer({1.1, 2.2}).tensor(),
      "Can only call `tensor()` on a TensorDataContainer that has `is_tensor() == true`");

  // Container built from an at::ArrayRef.
  ASSERT_THROWS_WITH(
      TensorDataContainer(at::ArrayRef<double>({1.1, 2.2})).scalar(),
      "Can only call `scalar()` on a TensorDataContainer that has `is_scalar() == true`");
  ASSERT_THROWS_WITH(
      TensorDataContainer(at::ArrayRef<double>({1.1, 2.2})).init_list(),
      "Can only call `init_list()` on a TensorDataContainer that has `is_init_list() == true`");
}

// from_blob() wraps caller-owned memory. The tensor honours the requested dtype
// and requires_grad flag, exposes the buffer's values, and (with no deleter
// supplied) has a null deleter context.
TEST(TensorTest, FromBlob) {
  std::vector<double> buffer = {1.0, 2.0, 3.0};
  const auto options = torch::dtype(torch::kFloat64).requires_grad(true);
  auto tensor = torch::from_blob(buffer.data(), buffer.size(), options);

  ASSERT_TRUE(tensor.requires_grad());
  ASSERT_EQ(tensor.dtype(), torch::kFloat64);
  ASSERT_EQ(tensor.numel(), 3);
  for (const auto i : c10::irange(3)) {
    ASSERT_EQ(tensor[i].item<double>(), i + 1);
  }
  // Above syntax did not copy the data, and has nullptr deleter context.
  ASSERT_EQ(tensor.storage().data_ptr().get_context(), nullptr);
}

// A deleter passed to from_blob() must run once the tensor is destroyed.
TEST(TensorTest, FromBlobUsesDeleter) {
  bool deleter_ran = false;
  {
    std::vector<int32_t> values = {1, 2, 3};
    // The data pointer is intentionally unused; only the invocation matters.
    auto on_release = [&deleter_ran](void* /*data*/) { deleter_ran = true; };
    auto tensor = torch::from_blob(
        values.data(), values.size(), on_release, torch::kInt32);
  }
  ASSERT_TRUE(deleter_ran);
}

// from_blob() with explicit strides {1, 3} over a row-major 3x3 buffer yields a
// column-major view: element (row, col) reads offset row * 1 + col * 3.
TEST(TensorTest, FromBlobWithStrides) {
  // clang-format off
  std::vector<int32_t> buffer = {
    1, 2, 3,
    4, 5, 6,
    7, 8, 9
  };
  // clang-format on
  const std::vector<int64_t> expected_strides = {1, 3};
  auto tensor = torch::from_blob(
      buffer.data(),
      /*sizes=*/{3, 3},
      /*strides=*/expected_strides,
      torch::kInt32);
  ASSERT_EQ(tensor.dtype(), torch::kInt32);
  ASSERT_EQ(tensor.numel(), 9);
  ASSERT_EQ(tensor.strides(), expected_strides);
  for (const auto row : c10::irange(tensor.size(0))) {
    for (const auto col : c10::irange(tensor.size(1))) {
      const auto expected = 1 + (col * tensor.size(1)) + row;
      EXPECT_EQ(tensor[row][col].item<int32_t>(), expected);
    }
  }
}

// item() on a single-value CPU tensor returns a Scalar that converts back to
// the stored floating-point or integer value.
TEST(TensorTest, Item) {
  {
    auto value = torch::tensor(3.14).item();
    ASSERT_NEAR(value.to<float>(), 3.14, 1e-5);
  }
  {
    auto value = torch::tensor(123).item();
    ASSERT_EQ(value.to<int>(), 123);
  }
}

// Same as Item, but for single-value tensors created on CUDA.
TEST(TensorTest, Item_CUDA) {
  {
    auto value = torch::tensor(3.14, torch::kCUDA).item();
    ASSERT_NEAR(value.to<float>(), 3.14, 1e-5);
  }
  {
    auto value = torch::tensor(123, torch::kCUDA).item();
    ASSERT_EQ(value.to<int>(), 123);
  }
}

// to() with the tensor's own options returns a tensor with the same data
// pointer, through both the typed and untyped data_ptr accessors.
TEST(TensorTest, DataPtr) {
  auto original = at::empty({3, 4}, at::kFloat);
  auto same_options = original.to(original.options());
  ASSERT_EQ(same_options.data_ptr<float>(), original.data_ptr<float>());
  ASSERT_EQ(same_options.data_ptr(), original.data_ptr());
}

// .data() returns a tensor that compares equal to its source.
TEST(TensorTest, Data) {
  const auto source = torch::rand({3, 3});
  const auto data_tensor = source.data();
  ASSERT_TRUE(torch::equal(source, data_tensor));
}

// Backprop through y = x * x at x = 5 must give dy/dx = 2 * 5 = 10.
TEST(TensorTest, BackwardAndGrad) {
  const auto options = torch::dtype(torch::kFloat).requires_grad(true);
  auto x = torch::tensor({5}, options);
  auto squared = x * x;
  squared.backward();
  ASSERT_EQ(x.grad().item<float>(), 10.0);
}

// Calling backward() directly on a leaf fills its grad with ones.
TEST(TensorTest, BackwardCreatesOnesGrad) {
  const auto leaf =
      torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  leaf.backward();
  const auto ones = torch::ones_like(leaf);
  ASSERT_TRUE(torch::equal(leaf.grad(), ones));
}

// backward() with no explicit gradient is rejected for a 5x5 (non-scalar) output.
TEST(TensorTest, BackwardNonScalarOutputs) {
  auto input = torch::randn({5, 5}, torch::requires_grad());
  auto elementwise_square = input * input;
  ASSERT_THROWS_WITH(
      elementwise_square.backward(),
      "grad can be implicitly created only for scalar outputs");
}

// backward() with no explicit gradient is rejected for a complex scalar output.
TEST(TensorTest, BackwardComplexScalarOutput) {
  auto input = torch::randn({5, 5}, torch::requires_grad());
  const c10::Scalar half_i(c10::complex<float>(0, 0.5));
  auto complex_sum = (input * half_i).sum();
  ASSERT_THROWS_WITH(
      complex_sum.backward(),
      "grad can be computed only for real scalar outputs");
}

// A user-created tensor is a leaf; the result of an op on it is not.
TEST(TensorTest, IsLeaf) {
  auto input =
      torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto squared = input * input;
  ASSERT_TRUE(input.is_leaf());
  ASSERT_FALSE(squared.is_leaf());
}

// Both a leaf and the result of x * x report output_nr() == 0.
TEST(TensorTest, OutputNr) {
  auto input =
      torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto squared = input * input;
  ASSERT_EQ(input.output_nr(), 0);
  ASSERT_EQ(squared.output_nr(), 0);
}

// A fresh tensor starts at version 0, and each in-place op bumps it by one.
TEST(TensorTest, Version) {
  auto tracked = torch::ones(3);
  ASSERT_EQ(tracked._version(), 0);

  tracked.mul_(2);
  ASSERT_EQ(tracked._version(), 1);

  tracked.add_(1);
  ASSERT_EQ(tracked._version(), 2);
}

// detach() returns a new leaf that does not require grad, leaving the source
// non-leaf tensor untouched.
TEST(TensorTest, Detach) {
  auto input =
      torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto squared = input * input;
  const auto detached = squared.detach();
  ASSERT_FALSE(squared.is_leaf());
  ASSERT_TRUE(detached.is_leaf());
  ASSERT_FALSE(detached.requires_grad());
}

// detach_() detaches in place. Both the tensor itself and the handle it returns
// become leaves that do not require grad.
TEST(TensorTest, DetachInplace) {
  auto input =
      torch::tensor({5}, torch::dtype(torch::kFloat).requires_grad(true));
  auto squared = input * input;
  auto returned = squared.detach_();
  ASSERT_TRUE(squared.is_leaf());
  ASSERT_FALSE(squared.requires_grad());
  ASSERT_TRUE(returned.is_leaf());
  ASSERT_FALSE(returned.requires_grad());
}

// set_data() makes the target tensor use the source's data: afterwards the two
// tensors compare equal and report the same data pointer.
TEST(TensorTest, SetData) {
  auto target = torch::randn({5});
  auto source = torch::randn({5});
  ASSERT_FALSE(torch::equal(target, source));
  ASSERT_NE(target.data_ptr<float>(), source.data_ptr<float>());

  target.set_data(source);
  ASSERT_TRUE(torch::equal(target, source));
  ASSERT_EQ(target.data_ptr<float>(), source.data_ptr<float>());
}

// requires_grad_() toggles the flag on a floating-point leaf. It is rejected on
// non-leaf tensors and on integral dtypes.
TEST(TensorTest, RequiresGradInplace) {
  auto leaf = torch::tensor({5.0});
  leaf.requires_grad_(true);
  ASSERT_TRUE(leaf.requires_grad());

  auto product = leaf * leaf;
  ASSERT_THROWS_WITH(
      product.requires_grad_(false),
      "you can only change requires_grad flags of leaf variables.");

  leaf.requires_grad_(false);
  ASSERT_FALSE(leaf.requires_grad());

  const auto int_options = at::TensorOptions().dtype(torch::kInt);
  const auto int_tensor = torch::tensor({5}, int_options);
  ASSERT_THROWS_WITH(
      int_tensor.requires_grad_(true),
      "Only Tensors of floating point and complex dtype can require gradients");
}

// Test that std(0) doesn't select the std(unbiased=False) overload (gh-40287).
// Every var/std variant reducing dim 0 of a 4x3 input must yield 3 elements.
// (An unused `auto std = x.std(0);` local that shadowed the name `std` in this
// scope, which also uses std::get, was removed.)
TEST(TensorTest, StdDimension) {
  auto x = torch::randn({4, 3});

  ASSERT_EQ(x.var(0).numel(), 3);
  ASSERT_EQ(x.std(0).numel(), 3);

  ASSERT_EQ(x.var(0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(x.std(0, /*unbiased=*/true).numel(), 3);

  ASSERT_EQ(torch::var(x, 0).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::var_mean(x, 0)).numel(), 3);
  ASSERT_EQ(torch::std(x, 0).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::std_mean(x, 0)).numel(), 3);

  ASSERT_EQ(torch::var(x, 0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::var_mean(x, 0, /*unbiased=*/true)).numel(), 3);
  ASSERT_EQ(torch::std(x, 0, /*unbiased=*/true).numel(), 3);
  ASSERT_EQ(std::get<0>(torch::std_mean(x, 0, /*unbiased=*/true)).numel(), 3);
}

// Tests the behavior of the _reshape_alias private operator so that it matches
// the behavior of as_strided and view, in both forward results and gradients.
TEST(TensorTest, ReshapeAlias) {
  auto x = torch::randn({3, 3});
  const auto aliased = torch::_reshape_alias(x, {2, 2}, {1, 2});
  const auto strided = torch::as_strided(x, {2, 2}, {1, 2});
  ASSERT_TRUE(torch::equal(aliased, strided));
  ASSERT_TRUE(torch::equal(torch::_reshape_alias(x, {9}, {1}), x.view({-1})));

  // Gradients through view() and through _reshape_alias() must agree.
  auto via_view = torch::randn({3, 3}, torch::requires_grad(true));
  auto via_alias = torch::clone(via_view).detach().requires_grad_(true);
  (via_view * via_view).view({-1}).mean().backward();
  torch::_reshape_alias((via_alias * via_alias), {9}, {1}).mean().backward();
  ASSERT_TRUE(torch::equal(via_view.grad(), via_alias.grad()));
}

TEST(TensorTest, BackendMetadata) {
  // Tests ability to assign custom backend metadata to a tensor, and that
  // shallow_copy_from() propagates it to the destination via
  // BackendMeta::clone().

  struct CustomBackendMetadata : public c10::BackendMeta {
    // Flipped to true by clone() so the test can observe that
    // shallow_copy_from() cloned the metadata. `mutable` because clone() is a
    // const member function.
    mutable bool cloned_{false};
    c10::intrusive_ptr<c10::BackendMeta> clone(
        const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
      cloned_ = true;
      return c10::BackendMeta::clone(ptr);
    }
  };

  // The copy destination must be a real (defined) tensor. A default-constructed
  // at::Tensor refers to the process-wide UndefinedTensorImpl singleton, so
  // calling shallow_copy_from() on it would mutate state shared by every
  // undefined tensor in the test binary.
  at::Tensor y = torch::empty({0});
  c10::intrusive_ptr<c10::BackendMeta> tmeta{};
  CustomBackendMetadata* custom_tmeta{nullptr};

  {
    auto x = torch::ones({3, 3});
    auto* impl = x.unsafeGetTensorImpl();
    ASSERT_TRUE(impl != nullptr);

    // A fresh tensor carries no backend metadata.
    tmeta = impl->get_backend_meta_intrusive_ptr();
    ASSERT_TRUE(tmeta == nullptr);

    c10::intrusive_ptr<c10::BackendMeta> new_tmeta =
        c10::make_intrusive<CustomBackendMetadata>();
    impl->set_backend_meta(new_tmeta);
    tmeta = impl->get_backend_meta_intrusive_ptr();
    ASSERT_TRUE(tmeta == new_tmeta);
    custom_tmeta = dynamic_cast<CustomBackendMetadata*>(tmeta.get());
    ASSERT_TRUE(custom_tmeta != nullptr);
    ASSERT_TRUE(custom_tmeta->cloned_ == false);
    y.unsafeGetTensorImpl()->shallow_copy_from(x.getIntrusivePtr());
  }

  // `tmeta` still owns the metadata object after `x` goes out of scope, so
  // `custom_tmeta` remains valid here.
  ASSERT_TRUE(
      tmeta == y.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr());
  ASSERT_TRUE(tmeta.get() == y.unsafeGetTensorImpl()->get_backend_meta());
  ASSERT_TRUE(custom_tmeta->cloned_ == true);
}

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.