pytorch
49 lines · 1.7 KB
#include <cstddef>

#include <gtest/gtest.h>

#include "caffe2/core/net.h"
#include "caffe2/core/operator.h"
#include "caffe2/transforms/conv_to_nnpack_transform.h"

6namespace caffe2 {
7
8namespace {
9
10using transform::Graph;
11
12TEST(ConvToNNPackTest, TestSimple) {
13NetDef netdef;
14// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
15OperatorDef* op;
16// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
17op = AddOp(&netdef, "Conv", {"in"}, {"out"});
18// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
19op = AddOp(&netdef, "Relu", {"out"}, {"out"});
20op = AddOp(&netdef, "Conv", {"out"}, {"out"}); // if not CPU, won't transform
21op->mutable_device_option()->set_device_type(PROTO_CUDA);
22// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
23op = AddOp(&netdef, "Relu", {"out"}, {"out"});
24op = AddOp(&netdef, "Conv", {"out"}, {"out"});
25op->set_engine("NNPACK"); // does not need to be transformed
26// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
27op = AddOp(&netdef, "Relu", {"out"}, {"out"});
28// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
29op = AddOp(&netdef, "Conv", {"out"}, {"out"});
30// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
31op = AddOp(&netdef, "Relu", {"out"}, {"out"});
32
33auto t = TransformRegistry()->Create("ConvToNNPack");
34NetDef transformed_netdef = t->ApplyTo(netdef);
35
36int nnpack_count = 0;
37for (auto& op : transformed_netdef.op()) {
38if (op.type() == "Conv" && op.device_option().device_type() == PROTO_CPU) {
39EXPECT_EQ(op.engine(), "NNPACK");
40nnpack_count++;
41}
42}
43EXPECT_EQ(nnpack_count, 3);
44EXPECT_EQ(t->PatternMatch(Graph(netdef)).size(), 2); // should get 2 matches
45}
46
47} // namespace
48
49} // namespace caffe2
50