pytorch

Форк
0
/
resize_test.cc 
96 строк · 2.6 Кб
1
#include "caffe2/core/init.h"
2
#include "caffe2/core/operator.h"
3
#include "caffe2/core/tensor.h"
4
#include "caffe2/utils/math.h"
5
#include "caffe2/utils/proto_utils.h"
6
#include "gtest/gtest.h"
7

8
#include <cmath>
9
#include <random>
10

11
namespace caffe2 {
12

13
namespace {
14

15
void AddNoiseInput(const vector<int64_t>& shape, const string& name, Workspace* ws) {
16
  DeviceOption option;
17
  CPUContext context(option);
18
  Blob* blob = ws->CreateBlob(name);
19
  auto* tensor = BlobGetMutableTensor(blob, CPU);
20
  tensor->Resize(shape);
21

22
  math::RandGaussian<float, CPUContext>(
23
      tensor->size(), 0.0f, 3.0f, tensor->mutable_data<float>(), &context);
24
  for (auto i = 0; i < tensor->size(); ++i) {
25
    tensor->mutable_data<float>()[i] =
26
        std::min(-5.0f, std::max(5.0f, tensor->mutable_data<float>()[i]));
27
  }
28
}
29

30
void compareResizeNeareast(int N,
31
                       int C,
32
                       int H,
33
                       int W,
34
                       float wscale,
35
                       float hscale) {
36
  Workspace ws;
37

38
  OperatorDef def1;
39
  def1.set_name("test");
40
  def1.set_type("ResizeNearest");
41
  def1.add_input("X");
42
  def1.add_output("Y");
43

44
  def1.add_arg()->CopyFrom(MakeArgument("width_scale", wscale));
45
  def1.add_arg()->CopyFrom(MakeArgument("height_scale", hscale));
46

47
  AddNoiseInput(vector<int64_t>{N, C, H, W}, "X", &ws);
48

49
  unique_ptr<OperatorBase> op1(CreateOperator(def1, &ws));
50
  EXPECT_NE(nullptr, op1.get());
51
  EXPECT_TRUE(op1->Run());
52

53
  const auto& X = ws.GetBlob("X")->Get<TensorCPU>();
54
  const auto& Y = ws.GetBlob("Y")->Get<TensorCPU>();
55

56
  const float* Xdata = X.data<float>();
57
  const float* Ydata = Y.data<float>();
58

59
  int outW = W * wscale;
60
  int outH = H * hscale;
61

62
  // Compare all output points
63
  for (int n = 0; n < N; ++n) {
64
    for (int c = 0; c < C; ++c) {
65
      for (int ph = 0; ph < outH; ++ph) {
66
        const int iny = std::min((int)(ph / hscale), (H - 1));
67
        for (int pw = 0; pw < outW; ++pw) {
68
          const int inx = std::min((int)(pw / wscale), (W - 1));
69
          const float v = Xdata[iny * W + inx];
70
          EXPECT_EQ(Ydata[outW * ph + pw], v);
71
        }
72
      }
73
      Xdata += H * W;
74
      Ydata += outW * outH;
75
    }
76
  }
77
}
78

79
int randInt(int a, int b) {
80
  static std::random_device rd;
81
  static std::mt19937 gen(rd());
82
  return std::uniform_int_distribution<int>(a, b)(gen);
83
}
84

85
TEST(ResizeNearestOp, ResizeNearest2x) {
86
  for (auto i = 0; i < 40; ++i) {
87
    auto H = randInt(1, 100);
88
    auto W = randInt(1, 100);
89
    auto C = randInt(1, 10);
90
    auto N = randInt(1, 2);
91
    compareResizeNeareast(N, C, H, W, 2.0f, 2.0f);
92
  }
93
}
94

95
} // unnamed namespace
96
} // namespace caffe2
97

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.