1
#include <gtest/gtest.h>
3
#include <ATen/MetaFunctions.h>
4
#include <torch/torch.h>
8
// Exercises the "meta device" route to a meta tensor: the inputs are moved to
// c10::kMeta with Tensor::to() and then passed to the regular at::add().
// The test checks the device of the inputs and the output, and the output shape.
// NOTE(review): the bare integer lines interleaved with the statements below look
// like line-number residue from extraction, and no closing `}` is visible for this
// test body -- confirm against the original file before relying on this text.
TEST(MetaTensorTest, MetaDeviceApi) {
9
// Two float tensors of ones, with shapes {4} and {3, 4}.
auto a = at::ones({4}, at::kFloat);
10
auto b = at::ones({3, 4}, at::kFloat);
11
// at::add() will return a meta tensor if its inputs are also meta tensors.
12
auto out_meta = at::add(a.to(c10::kMeta), b.to(c10::kMeta));
14
// `a` and `b` themselves are expected to still report the CPU device; only the
// results of the .to(c10::kMeta) calls above were handed to at::add().
ASSERT_EQ(a.device(), c10::kCPU);
15
ASSERT_EQ(b.device(), c10::kCPU);
16
ASSERT_EQ(out_meta.device(), c10::kMeta);
17
// The output is expected to have shape {3, 4}, which is consistent with
// combining the {4}-shaped `a` with the {3, 4}-shaped `b` via broadcasting.
c10::IntArrayRef sizes_actual = out_meta.sizes();
18
std::vector<int64_t> sizes_expected = std::vector<int64_t>{3, 4};
19
ASSERT_EQ(sizes_actual, sizes_expected);
22
// Exercises the at::meta:: namespace route to a meta tensor: at::meta::add() is
// called directly on the original tensors, with no explicit .to(c10::kMeta) step.
// The test checks the device of the inputs and the output, and the output shape.
// NOTE(review): the bare integer lines interleaved with the statements below look
// like line-number residue from extraction, and the closing `}` of this test body
// is not visible here -- confirm against the original file.
TEST(MetaTensorTest, MetaNamespaceApi) {
23
// Two float tensors of ones, with shapes {4} and {3, 4}.
auto a = at::ones({4}, at::kFloat);
24
auto b = at::ones({3, 4}, at::kFloat);
25
// Functions in the at::meta:: namespace take in tensors from any backend
26
// and return a meta tensor.
27
auto out_meta = at::meta::add(a, b);
29
// The inputs are expected to still report the CPU device, while the output is
// expected to report the meta device.
ASSERT_EQ(a.device(), c10::kCPU);
30
ASSERT_EQ(b.device(), c10::kCPU);
31
ASSERT_EQ(out_meta.device(), c10::kMeta);
32
// The output is expected to have shape {3, 4}, which is consistent with
// combining the {4}-shaped `a` with the {3, 4}-shaped `b` via broadcasting.
c10::IntArrayRef sizes_actual = out_meta.sizes();
33
std::vector<int64_t> sizes_expected = std::vector<int64_t>{3, 4};
34
ASSERT_EQ(sizes_actual, sizes_expected);