pytorch

Форк
0
/
meta_tensor.cpp 
35 строк · 1.2 Кб
1
#include <gtest/gtest.h>
2

3
#include <ATen/MetaFunctions.h>
4
#include <torch/torch.h>
5

6
#include <vector>
7

8
TEST(MetaTensorTest, MetaDeviceApi) {
9
  auto a = at::ones({4}, at::kFloat);
10
  auto b = at::ones({3, 4}, at::kFloat);
11
  // at::add() will return a meta tensor if its inputs are also meta tensors.
12
  auto out_meta = at::add(a.to(c10::kMeta), b.to(c10::kMeta));
13

14
  ASSERT_EQ(a.device(), c10::kCPU);
15
  ASSERT_EQ(b.device(), c10::kCPU);
16
  ASSERT_EQ(out_meta.device(), c10::kMeta);
17
  c10::IntArrayRef sizes_actual = out_meta.sizes();
18
  std::vector<int64_t> sizes_expected = std::vector<int64_t>{3, 4};
19
  ASSERT_EQ(sizes_actual, sizes_expected);
20
}
21

22
TEST(MetaTensorTest, MetaNamespaceApi) {
23
  auto a = at::ones({4}, at::kFloat);
24
  auto b = at::ones({3, 4}, at::kFloat);
25
  // The at::meta:: namespace take in tensors from any backend
26
  // and return a meta tensor.
27
  auto out_meta = at::meta::add(a, b);
28

29
  ASSERT_EQ(a.device(), c10::kCPU);
30
  ASSERT_EQ(b.device(), c10::kCPU);
31
  ASSERT_EQ(out_meta.device(), c10::kMeta);
32
  c10::IntArrayRef sizes_actual = out_meta.sizes();
33
  std::vector<int64_t> sizes_expected = std::vector<int64_t>{3, 4};
34
  ASSERT_EQ(sizes_actual, sizes_expected);
35
}
36

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.