1
#include "caffe2/core/init.h"
2
#include "caffe2/core/operator.h"
3
#include "caffe2/core/tensor.h"
4
#include "caffe2/utils/math.h"
5
#include "caffe2/utils/proto_utils.h"
6
#include "gtest/gtest.h"
15
void AddNoiseInput(const vector<int64_t>& shape, const string& name, Workspace* ws) {
17
CPUContext context(option);
18
Blob* blob = ws->CreateBlob(name);
19
auto* tensor = BlobGetMutableTensor(blob, CPU);
20
tensor->Resize(shape);
22
math::RandGaussian<float, CPUContext>(
23
tensor->size(), 0.0f, 3.0f, tensor->mutable_data<float>(), &context);
24
for (auto i = 0; i < tensor->size(); ++i) {
25
tensor->mutable_data<float>()[i] =
26
std::min(-5.0f, std::max(5.0f, tensor->mutable_data<float>()[i]));
30
// Runs a "ResizeNearest" operator on a random input blob "X" of shape
// {N, C, H, W} and compares every element of the output blob "Y" against a
// reference nearest-neighbour lookup using the given width/height scales.
// NOTE(review): this fragment does not show the remaining parameters
// (C, H, W, wscale, hscale), the declarations of `ws` and `def1`, or the
// closing braces; the caller in this file passes (N, C, H, W, 2.0f, 2.0f).
// Confirm against the complete source.
void compareResizeNeareast(int N,
39
def1.set_name("test");
40
def1.set_type("ResizeNearest");
44
// The scale factors are forwarded to the operator as named arguments.
def1.add_arg()->CopyFrom(MakeArgument("width_scale", wscale));
45
def1.add_arg()->CopyFrom(MakeArgument("height_scale", hscale));
47
AddNoiseInput(vector<int64_t>{N, C, H, W}, "X", &ws);
49
unique_ptr<OperatorBase> op1(CreateOperator(def1, &ws));
50
// NOTE(review): EXPECT_NE is non-fatal, so a null op1 would still be
// dereferenced by op1->Run() below; ASSERT_NE would stop the test instead.
EXPECT_NE(nullptr, op1.get());
51
EXPECT_TRUE(op1->Run());
53
const auto& X = ws.GetBlob("X")->Get<TensorCPU>();
54
const auto& Y = ws.GetBlob("Y")->Get<TensorCPU>();
56
const float* Xdata = X.data<float>();
57
const float* Ydata = Y.data<float>();
59
// Expected output spatial size: input size times scale, truncated to int.
int outW = W * wscale;
60
int outH = H * hscale;
63
for (int n = 0; n < N; ++n) {
64
for (int c = 0; c < C; ++c) {
65
for (int ph = 0; ph < outH; ++ph) {
66
// Source row: output row divided by the scale, truncated, and clamped to
// the last valid input row.
const int iny = std::min((int)(ph / hscale), (H - 1));
67
for (int pw = 0; pw < outW; ++pw) {
68
// Source column, computed the same way as the source row.
const int inx = std::min((int)(pw / wscale), (W - 1));
69
// NOTE(review): neither index below uses n or c, so every (n, c)
// iteration re-checks only the first H x W plane of X against the first
// outH x outW plane of Y. Per-plane offsets such as (n * C + c) * H * W
// and (n * C + c) * outH * outW appear to be missing — confirm Y's layout.
const float v = Xdata[iny * W + inx];
70
EXPECT_EQ(Ydata[outW * ph + pw], v);
// Returns a pseudo-random int drawn uniformly from the closed range spanned
// by `a` and `b` (both endpoints inclusive). The bounds may be given in
// either order: std::uniform_int_distribution requires lower <= upper, so
// they are normalized first instead of triggering undefined behaviour.
//
// The generator is a function-local std::mt19937 seeded once from
// std::random_device, so the sequence differs between runs.
// NOTE(review): for reproducible test failures consider a fixed or logged seed.
int randInt(int a, int b) {
  static std::random_device rd;
  static std::mt19937 gen(rd());
  const int lo = a < b ? a : b;
  const int hi = a < b ? b : a;
  return std::uniform_int_distribution<int>(lo, hi)(gen);
}
85
// Randomized check of ResizeNearest with 2x upscaling in both width and
// height: 40 trials, each with randInt-drawn (inclusive) sizes
// H, W in [1, 100], C in [1, 10] and N in [1, 2].
TEST(ResizeNearestOp, ResizeNearest2x) {
86
for (auto i = 0; i < 40; ++i) {
87
auto H = randInt(1, 100);
88
auto W = randInt(1, 100);
89
auto C = randInt(1, 10);
90
auto N = randInt(1, 2);
91
// NOTE(review): the sizes are not logged, so a failing trial cannot be
// reproduced; consider a SCOPED_TRACE with N, C, H, W here.
compareResizeNeareast(N, C, H, W, 2.0f, 2.0f);