import datetime

import numpy as np
from hypothesis import given, settings
import hypothesis.strategies as st

# NOTE(review): init_shared_libs was imported first among the caffe2 modules
# in the original; presumably it must load before the other caffe2.python
# imports, so keep it at the top of this group.
import caffe2.python.fakelowp.init_shared_libs  # noqa: F401
from caffe2.proto import caffe2_pb2
from caffe2.python import core, workspace
from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net
from caffe2.python.fakelowp.test_utils import print_test_debug_info
import caffe2.python.serialized_test.serialized_test_util as serial
# One-time caffe2 initialisation at import time. The first entry is the
# program name; the flags lower the log level to -3 and set
# --glow_global_fp16=1 (presumably runs the onnxified net in fp16 — confirm).
core.GlobalInit(
    [
        "caffe2",
        "--caffe2_log_level=-3",
        "--glow_global_fp16=1",
    ]
)
class TestBatchMatMul(serial.SerializedTestCase):
    """Compare the onnxified BatchMatMul against the BatchMatMulFP16Acc32Fake reference."""

    @given(
        C=st.integers(min_value=1, max_value=10),
        M=st.integers(min_value=1, max_value=50),
        K=st.integers(min_value=1, max_value=512),
        N=st.integers(min_value=1, max_value=50),
        rand_seed=st.integers(0, 65534),
        trans_a=st.booleans(),
        trans_b=st.booleans(),
        run_ints=st.booleans(),
    )
    @settings(deadline=datetime.timedelta(seconds=10))
    def test_batch_matmul(self, M, K, N, C, rand_seed, trans_a, trans_b, run_ints):
        """Run one batched matmul through both nets and require matching output.

        The logical product is (C, M, K) @ (C, K, N).

        Args:
            M, K, N: matrix dimensions of the product above.
            C: size of the single leading batch dimension.
            rand_seed: seed for numpy's global RNG so a failing case replays.
            trans_a, trans_b: when set, X / Y are fed with their last two axes
                swapped and the operator is told to transpose them back.
            run_ints: draw integer-valued floats in {1, 2} instead of uniform
                floats in [-50, 50).
        """
        np.random.seed(rand_seed)
        workspace.ResetWorkspace()

        batch_dims = [C]

        def _operand(dims, transpose):
            # Builds one input; X and Y are drawn in that order so the RNG
            # stream is reproducible for a given seed.
            if run_ints:
                data = np.random.randint(low=1, high=3, size=dims).astype(np.float32)
            else:
                data = 100 * (np.random.rand(*dims).astype(np.float32) - 0.5)
            # Store the operand pre-transposed; trans_a / trans_b on the op
            # undoes the swap so the logical product is unchanged.
            return data.swapaxes(-1, -2) if transpose else data

        X = _operand(batch_dims + [M, K], trans_a)
        Y = _operand(batch_dims + [K, N], trans_b)

        pred_net = caffe2_pb2.NetDef()
        pred_net.name = "pred"
        pred_net.external_input.extend(["X", "Y"])
        pred_net.external_output.append("out")
        pred_net.op.add().CopyFrom(
            core.CreateOperator(
                'BatchMatMul', ['X', 'Y'], 'out', trans_a=trans_a, trans_b=trans_b
            )
        )

        # Reference net: the fake fp16-input / fp32-accumulate kernel (per its
        # op name) computing the same product into the same 'out' blob.
        pred_net_ref = core.Net("pred_net_ref")
        pred_net_ref.BatchMatMulFP16Acc32Fake(
            ["X", "Y"], ['out'], trans_a=trans_a, trans_b=trans_b)

        print("dims", batch_dims, X.shape, Y.shape)
        # NOTE(review): the keyword arguments after the shape hints were elided
        # in this copy of the file; restored to match the other fakelowp
        # tests — confirm.
        pred_net_onnxified = onnxifi_caffe2_net(
            pred_net,
            {"X": X.shape, "Y": Y.shape},
            debug=True,
            adjust_batch=False,
            use_onnx=False,
        )
        # Exactly one Onnxifi op must appear in the transformed net.
        num_onnxified_ops = sum(o.type == "Onnxifi" for o in pred_net_onnxified.op)
        np.testing.assert_equal(num_onnxified_ops, 1)

        workspace.FeedBlob("X", X)
        workspace.FeedBlob("Y", Y)
        workspace.CreateNet(pred_net_onnxified)
        workspace.CreateNet(pred_net_ref)

        # Both nets write blob 'out', so fetch each result right after its run.
        workspace.RunNet(pred_net_onnxified.name)
        out_glow = workspace.FetchBlob('out')

        workspace.RunNet(pred_net_ref)
        out_c2_fakefp16 = workspace.FetchBlob('out')

        if not np.allclose(out_glow, out_c2_fakefp16):
            diff = np.abs(out_c2_fakefp16 - out_glow)
            print_test_debug_info("bmm", {
                "seed": rand_seed,
                "m": M, "k": K,
                "n": N, "X": X.shape, "Y": Y.shape,
                "trans_a": trans_a,
                "trans_b": trans_b,
                "run_ints": run_ints,
                "out_glow": out_glow,
                "out_c2_fakefp16": out_c2_fakefp16,
                "diff": diff,
            })
            # self.fail (unlike a bare assert) still fires under `python -O`.
            self.fail("BatchMatMul output differs between onnxified and reference nets")
if __name__ == "__main__":