3
# Must happen before importing caffe2.python.*
4
import caffe2.python.fakelowp.init_shared_libs # noqa
7
from hypothesis import given, settings
8
from hypothesis import strategies as st
9
from caffe2.proto import caffe2_pb2
10
from caffe2.python import core, workspace
11
from caffe2.python.onnx.onnxifi import onnxifi_caffe2_net
12
from caffe2.python.fakelowp.test_utils import print_test_debug_info
13
import caffe2.python.serialized_test.serialized_test_util as serial
18
"--glow_global_fp16=0",
19
"--glow_global_fused_scale_offset_fp16=0",
20
"--glow_global_force_sls_fp16_accum=0",
23
GLOW_MATMUL_ATOL = 1e-5
24
GLOW_MATMUL_RTOL = 1e-3
26
class SparseLengthsSum8BitFakeNNPIFp32Test(serial.SerializedTestCase):
28
seed=st.integers(0, 65535),
29
num_rows=st.integers(2, 20),
30
embedding_dim=st.sampled_from([8, 12, 16, 24, 32, 54, 64, 128]),
31
batch_size=st.integers(1, 5),
32
max_weight=st.integers(0, 100),
34
@settings(deadline=datetime.timedelta(seconds=10))
35
def test_slws_fused_8bit_rowwise_acc32_nnpi(
36
self, seed, num_rows, embedding_dim, batch_size, max_weight
41
"--glow_global_fp16=0",
42
"--glow_global_fused_scale_offset_fp16=0",
43
"--glow_global_force_sls_fp16_accum=0",
47
workspace.ResetWorkspace()
49
data = np.random.rand(num_rows, embedding_dim).astype(np.float32)
50
lengths = np.random.choice(np.arange(1, num_rows), batch_size).astype(np.int32)
53
for length in lengths:
54
_indices.extend(np.random.choice(np.arange(1, num_rows), length))
55
indices = np.asarray(_indices).astype(np.int64)
57
weights = np.random.uniform(
63
pred_net = caffe2_pb2.NetDef()
64
pred_net.name = "pred"
65
pred_net.external_input.extend(
66
["quantized_data", "weights", "indices", "lengths"]
68
pred_net.external_output.append("Y")
69
pred_net.op.add().CopyFrom(
71
"SparseLengthsWeightedSumFused8BitRowwise",
72
["quantized_data", "weights", "indices", "lengths"],
77
ref_net = caffe2_pb2.NetDef()
79
ref_net.external_input.extend(
80
["quantized_data", "weights", "indices", "lengths"]
82
ref_net.external_output.append("Y")
83
ref_net.op.add().CopyFrom(
85
"SparseLengthsWeightedSumFused8BitRowwiseFakeFP32NNPI",
86
["quantized_data", "weights", "indices", "lengths"],
91
workspace.FeedBlob("data", data)
92
workspace.RunOperatorOnce(
94
"FloatToFused8BitRowwiseQuantized",
99
onnxified_net = onnxifi_caffe2_net(
102
max_batch_size=batch_size,
103
max_seq_size=np.max(lengths),
108
num_onnxified_ops = sum(
109
1 if o.type == "Onnxifi" else 0 for o in onnxified_net.op)
110
np.testing.assert_equal(num_onnxified_ops, 1)
112
workspace.FeedBlob("indices", indices)
113
workspace.FeedBlob("lengths", lengths)
114
workspace.FeedBlob("weights", weights)
116
workspace.CreateNet(onnxified_net)
117
workspace.CreateNet(ref_net)
119
workspace.RunNet(onnxified_net.name)
120
Y_glow = workspace.FetchBlob("Y")
122
workspace.RunNet(ref_net.name)
123
Y_ref = workspace.FetchBlob("Y")
125
diff = np.abs((Y_ref - Y_glow) / (Y_ref + 1e-8))
126
max_err = np.max(diff, axis=1)
127
num_offenders = (max_err > 0).sum()
128
if num_offenders > 0:
129
print_test_debug_info(
130
"test_slws_fused_8bit_rowwise_acc32_nnpi",
133
"num_rows": num_rows,
134
"embedding_dim": embedding_dim,
135
"batch_size": batch_size,
143
"rowwise_diff": np.max(diff, axis=1),
149
@given(seed=st.integers(0, 65535))
150
@settings(deadline=datetime.timedelta(seconds=10))
151
def test_small_sls_acc32(self, seed):
152
workspace.GlobalInit(
155
"--glow_global_fp16=0",
156
"--glow_global_fused_scale_offset_fp16=0",
157
"--glow_global_force_sls_fp16_accum=0",
161
workspace.ResetWorkspace()
165
data = 4 * (np.random.random_sample((n, DIM)) + 1).astype(np.float32)
167
lengths = np.array([n], dtype=np.int32)
168
indices = np.array(range(n), dtype=np.int64)
169
weights = np.random.uniform(low=0.01, high=0.5, size=[n]).astype(np.float32)
171
pred_net = caffe2_pb2.NetDef()
172
pred_net.name = "pred"
173
pred_net.external_input.extend(
174
["quantized_data", "weights", "indices", "lengths"]
176
pred_net.external_output.append("Y")
177
pred_net.op.add().CopyFrom(
179
"SparseLengthsWeightedSumFused8BitRowwise",
180
["quantized_data", "weights", "indices", "lengths"],
185
ref_net = caffe2_pb2.NetDef()
187
ref_net.external_input.extend(
188
["quantized_data", "weights", "indices", "lengths"]
190
ref_net.external_output.append("Y")
191
ref_net.op.add().CopyFrom(
193
"SparseLengthsWeightedSumFused8BitRowwiseFakeFP32NNPI",
194
["quantized_data", "weights", "indices", "lengths"],
199
workspace.FeedBlob("data", data)
200
workspace.RunOperatorOnce(
202
"FloatToFused8BitRowwiseQuantized", ["data"], ["quantized_data"]
206
quantized_data = workspace.FetchBlob("quantized_data")
208
onnxified_net = onnxifi_caffe2_net(
217
num_onnxified_ops = sum(
218
1 if o.type == "Onnxifi" else 0 for o in onnxified_net.op)
219
np.testing.assert_equal(num_onnxified_ops, 1)
221
workspace.FeedBlob("indices", indices)
222
workspace.FeedBlob("lengths", lengths)
223
workspace.FeedBlob("weights", weights)
225
workspace.CreateNet(onnxified_net)
226
workspace.CreateNet(ref_net)
228
workspace.RunNet(onnxified_net.name)
229
Y_glow = workspace.FetchBlob("Y")
231
workspace.RunNet(ref_net.name)
232
Y_ref = workspace.FetchBlob("Y")
234
diff = np.abs((Y_ref - Y_glow) / (Y_ref + 1e-8))
235
max_err = np.max(diff, axis=1)
236
num_offenders = (max_err > 0).sum()
237
if num_offenders > 0:
238
np.set_printoptions(precision=12)
241
Y_ref.astype(np.float16).astype(np.float32),
243
Y_glow.astype(np.float16).astype(np.float32),
245
print_test_debug_info(
246
"test_small_sls_acc32",
251
"quantized_data": quantized_data,
257
"rowwise_diff": np.max(diff, axis=1),
263
if __name__ == '__main__':