pytorch

Форк
0
/
test_batchmatmul_nnpi_fp16.py 
108 строк · 3.7 Кб
1
# mypy: ignore-errors
2

3
import numpy as np
4
import unittest
5
import caffe2.python.fakelowp.init_shared_libs  # noqa
6

7
from caffe2.proto import caffe2_pb2
8
from caffe2.python import core, workspace
9
from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net
10
from caffe2.python.fakelowp.test_utils import print_test_debug_info
11
import datetime
12
from hypothesis import given, settings
13
import hypothesis.strategies as st
14
import caffe2.python.serialized_test.serialized_test_util as serial
15

16
core.GlobalInit(["caffe2", "--caffe2_log_level=-3", "--glow_global_fp16=1"])
17

18

19
class TestBatchMatMul(serial.SerializedTestCase):
20
    @given(
21
        C=st.integers(min_value=1, max_value=10),
22
        M=st.integers(min_value=1, max_value=50),
23
        K=st.integers(min_value=1, max_value=512),
24
        N=st.integers(min_value=1, max_value=50),
25
        rand_seed=st.integers(0, 65534),
26
        trans_a=st.booleans(),
27
        trans_b=st.booleans(),
28
        run_ints=st.booleans()
29
    )
30
    @settings(deadline=datetime.timedelta(seconds=10))
31
    def test_batch_matmul(self, M, K, N, C, rand_seed, trans_a, trans_b, run_ints):
32
        np.random.seed(rand_seed)
33
        workspace.ResetWorkspace()
34

35
        batch_dims = [C]
36

37
        if run_ints:
38
            X = np.random.randint(low=1, high=3, size=((C, M, K))).astype(np.float32)
39
        else:
40
            X = 100 * (np.random.rand(*(batch_dims + [M, K])).astype(np.float32) - 0.5)
41
        if trans_a:
42
            X = X.swapaxes(-1, -2)
43

44
        if run_ints:
45
            Y = np.random.randint(low=1, high=3, size=((C, K, N))).astype(np.float32)
46
        else:
47
            Y = 100 * (np.random.rand(*(batch_dims + [K, N])).astype(np.float32) - 0.5)
48
        if trans_b:
49
            Y = Y.swapaxes(-1, -2)
50

51
        pred_net = caffe2_pb2.NetDef()
52
        pred_net.name = "pred"
53
        pred_net.external_input.extend(["X", "Y"])
54
        pred_net.external_output.append("out")
55
        pred_net.op.add().CopyFrom(
56
            core.CreateOperator(
57
                'BatchMatMul', ['X', 'Y'], 'out', trans_a=trans_a, trans_b=trans_b
58
            )
59
        )
60

61
        pred_net_ref = core.Net("pred_net_ref")
62

63
        # Reference updated to fp16 with fp32 accumulation
64
        pred_net_ref.BatchMatMulFP16Acc32Fake(
65
            ["X", "Y"], ['out'], trans_a=trans_a, trans_b=trans_b)
66

67
        print("dims", batch_dims, X.shape, Y.shape)
68
        pred_net_onnxified = onnxifi_caffe2_net(pred_net,
69
                                                {"X": X.shape, "Y": Y.shape},
70
                                                debug=True,
71
                                                adjust_batch=False,
72
                                                use_onnx=False)
73
        num_onnxified_ops = sum(
74
            1 if o.type == "Onnxifi" else 0 for o in pred_net_onnxified.op)
75
        np.testing.assert_equal(num_onnxified_ops, 1)
76

77
        workspace.FeedBlob("X", X)
78
        workspace.FeedBlob("Y", Y)
79
        workspace.CreateNet(pred_net_onnxified)
80
        workspace.CreateNet(pred_net_ref)
81

82
        # Run Glow net
83
        workspace.RunNet(pred_net_onnxified.name)
84
        out_glow = workspace.FetchBlob('out')
85

86
        # Run caffe2 net
87
        workspace.RunNet(pred_net_ref)
88
        out_c2_fakefp16 = workspace.FetchBlob('out')
89

90
        diff = np.abs(out_c2_fakefp16 - out_glow)
91

92
        if not np.allclose(out_glow, out_c2_fakefp16):
93
            print_test_debug_info("bmm", {
94
                "seed": rand_seed,
95
                "m": M, "k": K,
96
                "n": N, "X": X.shape, "Y": Y.shape,
97
                "trans_a": trans_a,
98
                "trans_b": trans_b,
99
                "run_ints": run_ints,
100
                "out_glow": out_glow,
101
                "out_c2_fakefp16": out_c2_fakefp16,
102
                "diff": diff
103
            })
104
            assert(0)
105

106

107
if __name__ == "__main__":
108
    unittest.main()
109

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.