1
#include "caffe2/core/init.h"
2
#include "caffe2/core/operator.h"
3
#include "caffe2/core/tensor.h"
4
#include "caffe2/utils/math.h"
5
#include "caffe2/utils/proto_utils.h"
6
#include "gtest/gtest.h"
15
void AddNoiseInput(const vector<int64_t>& shape, const string& name, Workspace* ws) {
17
CPUContext context(option);
18
Blob* blob = ws->CreateBlob(name);
19
auto* tensor = BlobGetMutableTensor(blob, CPU);
20
tensor->Resize(shape);
22
math::RandGaussian<float, CPUContext>(
23
tensor->size(), 0.0f, 3.0f, tensor->mutable_data<float>(), &context);
24
for (auto i = 0; i < tensor->size(); ++i) {
25
tensor->mutable_data<float>()[i] =
26
std::min(-5.0f, std::max(5.0f, tensor->mutable_data<float>()[i]));
30
// Builds a Caffe2 "MaxPool" operator with the given kernel/stride/padding
// arguments, runs it on a freshly generated N x C x H x W input blob "X"
// (see AddNoiseInput), and checks every element of the output blob "Y"
// against a naive reference max computed below.
// Argument order, as used by the callers in this file:
//   N, C, H, W, kernelH, kernelW, strideH, strideW, padT, padL, padB, padR,
//   then the optional tolerances.
// NOTE(review): maxRelErr and absErrForRelErrFailure do not appear to be used;
// outputs are compared exactly with EXPECT_EQ — confirm whether the tolerance
// parameters are still needed.
void compareMaxPooling(int N,
42
float maxRelErr = 1.0e-5f,
43
float absErrForRelErrFailure = 1.0e-5f) {
47
def1.set_name("test");
48
def1.set_type("MaxPool");
52
def1.add_arg()->CopyFrom(MakeArgument("kernel_h", kernelH));
53
def1.add_arg()->CopyFrom(MakeArgument("kernel_w", kernelW));
54
def1.add_arg()->CopyFrom(MakeArgument("stride_h", strideH));
55
def1.add_arg()->CopyFrom(MakeArgument("stride_w", strideW));
56
def1.add_arg()->CopyFrom(MakeArgument("pad_t", padT));
57
def1.add_arg()->CopyFrom(MakeArgument("pad_l", padL));
58
def1.add_arg()->CopyFrom(MakeArgument("pad_b", padB));
59
def1.add_arg()->CopyFrom(MakeArgument("pad_r", padR));
61
AddNoiseInput(vector<int64_t>{N, C, H, W}, "X", &ws);
63
unique_ptr<OperatorBase> op1(CreateOperator(def1, &ws));
64
EXPECT_NE(nullptr, op1.get());
65
EXPECT_TRUE(op1->Run());
67
const auto& X = ws.GetBlob("X")->Get<TensorCPU>();
68
const auto& Y = ws.GetBlob("Y")->Get<TensorCPU>();
70
// Compare all output points
71
for (int n = 0; n < Y.dim32(0); ++n) {
72
for (int c = 0; c < Y.dim32(1); ++c) {
73
for (int ph = 0; ph < Y.dim32(2); ++ph) {
74
for (int pw = 0; pw < Y.dim32(3); ++pw) {
75
// Reference implementation: the window for output (ph, pw) starts at
// ph * strideH - padT (resp. pw * strideW - padL). Its end is clipped to the
// input extent and its start clamped to 0, so padded cells never contribute.
76
int hstart = ph * strideH - padT;
77
int wstart = pw * strideW - padL;
78
int hend = std::min(hstart + kernelH, H);
79
int wend = std::min(wstart + kernelW, W);
80
hstart = std::max(hstart, 0);
81
wstart = std::max(wstart, 0);
82
// Flat offset of (ph, pw) within a single output plane.
const int pool_index = ph * Y.dim32(3) + pw;
83
// Start below every finite float so the first in-bounds value always wins.
float v = std::numeric_limits<float>::lowest();
84
// Scan the in-bounds window. X is indexed as contiguous NCHW: a plane
// offset for (n, c), then h * W + w inside the plane.
for (int h = hstart; h < hend; ++h) {
85
for (int w = wstart; w < wend; ++w) {
87
X.data<float>() + n * X.dim(1) * X.dim(2) * X.dim(3) + c * X.dim(2) * X.dim(3);
88
const int input_index = h * W + w;
89
v = std::max(v, Xdata[input_index]);
92
// The reference v is always one of the input values, so an exact
// comparison with the operator's output is expected to hold.
EXPECT_EQ(Y.data<float>()[n * Y.dim(1) * Y.dim(2) * Y.dim(3) + c * Y.dim(2) * Y.dim(3) +
101
// Returns an integer drawn uniformly from the closed range [a, b].
// All calls share one Mersenne Twister engine, seeded once from
// std::random_device the first time this function runs.
int randInt(int a, int b) {
  static std::random_device seedSource;
  static std::mt19937 engine(seedSource());
  std::uniform_int_distribution<int> dist(a, b);
  return dist(engine);
}
107
void runMaxPool(int kernel, int stride, int pad) {
108
int N = randInt(1, 2);
109
int C = randInt(1, 12);
110
int H = randInt(50, 100);
111
int W = randInt(50, 100);
112
int planesOut = randInt(1, 6);
114
compareMaxPooling(N, C, H, W, kernel, kernel, stride, stride, pad, pad, pad, pad);
117
// Runs 40 randomized max-pool checks.
// NOTE(review): loop body not shown here; by the test name it presumably calls
// runMaxPool(2, 2, 0) (2x2 kernel, stride 2, pad 0) — confirm.
TEST(PoolOp, MaxPool2x2s2p0Randomized) {
118
for (int i = 0; i < 40; ++i) {
123
// Runs 40 randomized max-pool checks.
// NOTE(review): loop body not shown here; by the test name it presumably calls
// runMaxPool(4, 3, 2) (4x4 kernel, stride 3, pad 2) — confirm.
TEST(PoolOp, MaxPool4x4s3p2Randomized) {
124
for (int i = 0; i < 40; ++i) {
129
// Deterministic 2x2 kernel / stride-2 / unpadded cases. Each spatial size
// covers a different divisibility case:
//   40 -> H/W % 4 == 0
//   39 -> H/W % 4 != 0
//   64 -> H/W % 16 == 0
TEST(PoolOp, MaxPool2x2s2p0Special) {
  const int kSizes[] = {40, 39, 64};
  for (const int size : kSizes) {
    compareMaxPooling(2, 10, size, size, 2, 2, 2, 2, 0, 0, 0, 0, 0.05f, 0.1f);
  }
}
140
// 40 trials, each with an independently drawn kernel, stride and padding on
// every axis. Padding on each side stays below the kernel extent along that
// axis. H is drawn so that H + padT + padB >= kernelH (and W likewise), so the
// padded input always fits at least one kernel window.
// The order of the randInt calls fixes which value each parameter receives;
// keep it stable.
TEST(PoolOp, MaxPoolFullyRandomized) {
  for (int trial = 0; trial < 40; ++trial) {
    const int kernelH = randInt(1, 5);
    const int kernelW = randInt(1, 5);
    const int strideH = randInt(1, 5);
    const int strideW = randInt(1, 5);
    const int padL = randInt(0, kernelW - 1);
    const int padR = randInt(0, kernelW - 1);
    const int padT = randInt(0, kernelH - 1);
    const int padB = randInt(0, kernelH - 1);

    const int minH = std::max(1, kernelH - padT - padB);
    const int minW = std::max(1, kernelW - padL - padR);
    const int H = randInt(minH, 100);
    const int W = randInt(minW, 100);
    const int C = randInt(1, 10);
    const int N = randInt(1, 2);

    compareMaxPooling(
        N, C, H, W, kernelH, kernelW, strideH, strideW, padT, padL, padB, padR);
  }
}
158
} // unnamed namespace