import unittest
from typing import Dict, List, Optional

import numpy as np
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests


class StaticModule:
    """Thin Python wrapper around the C++ static runtime bindings."""

    def __init__(self, scripted):
        # A scripted nn.Module carries its C++ module in `_c`; a scripted
        # function only has a graph.
        if hasattr(scripted, "_c"):
            self.static_module = torch._C._jit_to_static_module(scripted._c)
        else:
            self.static_module = torch._C._jit_to_static_module(scripted.graph)

    def __call__(self, *args, **kwargs):
        return self.static_module(*args, **kwargs)

    def benchmark(self, args, kwargs, warmup_runs, main_runs):
        self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)

    def runAsync(self, args, kwargs):
        return self.static_module.runAsync(args, kwargs)

    def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
        return self.static_module.benchmark_individual_ops(
            args, kwargs, warmup_runs, main_runs
        )
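
# Illustrative usage of the wrapper above (a sketch, not one of the tests):
# script a callable, wrap it, and compare against the TorchScript output.
#
#   g = torch.jit.script(trivial_graph)
#   sg = StaticModule(g)
#   s = torch.randn(2, 2)
#   torch.testing.assert_close(sg(s, s, s), g(s, s, s))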


def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output


# Route nn.functional.linear through the matmul-based shim so that scripted
# linear layers decompose into plain matmul/add ops.
torch.nn.functional.linear = linear_shim


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        # Project the inputs, then split hid_dim into (n_heads, head_dim).
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, V)
        # Merge the heads back into hid_dim and apply the output projection.
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention


def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]

        LL = nn.Linear(int(n), int(m), bias=True)

        # Xavier-style init for the weight, smaller std for the bias.
        mean = 0.0
        std_dev = np.sqrt(2 / (m + n))
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)

        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())

    with torch.no_grad():
        s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s
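
# For example (illustrative): create_mlp([512, 512, 64], sigmoid_layer=-1)
# scripts Linear(512, 512) -> ReLU -> Linear(512, 64) -> ReLU, while passing
# a valid layer index instead of -1 puts a Sigmoid after that layer.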


def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


def elementwise_square_addition(input1, input2):
    return input1 * input1 + input2 * input2


def fork_wait_graph1(input1, input2):
    fut = torch.jit.fork(elementwise_square_addition, input1, input2)
    return torch.jit.wait(fut)


def fork_wait_graph2(input1, input2):
    fut = torch.jit.fork(loop_graph, input1, input2, 5)
    return torch.jit.wait(fut)
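
# torch.jit.fork schedules its callee asynchronously and returns a Future;
# torch.jit.wait blocks until the result is ready. A minimal sketch:
#
#   fut = torch.jit.fork(torch.neg, torch.ones(2))
#   torch.jit.wait(fut)  # -> tensor([-1., -1.])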

"""
   graph with multiple fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, input))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

"""
   graph with multi-level fork/wait operations
   :param input: torch.tensor input to forked subgraph
   :param num_forks: number of top level forks
   :param num_child_forks: number of child forks per parent fork
"""
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

def add_tensor(input1, input2):
    return input1 + input2


def fork_wait_graph_exception(input1, input2):
    fut = torch.jit.fork(add_tensor, input1, input2)
    return torch.jit.wait(fut)

def loop_graph(a, b, iters: int):
    # Mix of out-of-place and in-place ops inside the loop.
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c


def output_graph(a, b, c, iters: int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d: Dict[int, torch.Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d


class SubModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Attribute values chosen to be consistent with the frozen IR shown
        # in test_attrs (sub1 contributes the constant 13 = 11 + 2).
        self.a = 11
        self.b = 2

    def forward(self, x):
        return self.a + self.b + x


class SubModule2(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = 12
        self.b = 2  # overwritten to 30 in forward

    def forward(self, x):
        self.b = 30
        return self.a + self.b + x


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 2  # overwritten to 20 in forward

    def forward(self, x):
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)
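
# Note: SubModule2.forward and TestModule.forward mutate module attributes
# (self.b), which surfaces as prim::SetAttr/prim::GetAttr in the frozen
# TorchScript IR reproduced in test_attrs below.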


class TestStaticModule(TestCase):

    """
    Test Case: To test a simple fork/wait operation in a graph;
    fork is called on a simple addition operation on the input tensors
    """
    def test_fork_wait_1(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a simple fork/wait operation with the
    StaticRuntime runAsync API, which returns a future
    """
    def test_fork_wait_1_async(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a fork/wait operation in a graph on
    a loop subgraph performing a mix of operations
    """
    def test_fork_wait_2(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a fork/wait operation on a loop
    subgraph with the StaticRuntime runAsync API, which returns a future
    """
    def test_fork_wait_2_async(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a graph containing multiple fork/wait operations
    """
    def test_fork_wait_3(self):
        input = torch.ones(3, 3)
        num_forks = 10  # number of fork/wait pairs (illustrative count)
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(input, num_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple fork/wait operations
    through the runAsync API, which returns a future
    """
    def test_fork_wait_3_async(self):
        input = torch.ones(3, 3)
        num_forks = 10  # number of fork/wait pairs (illustrative count)
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((input, num_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test a graph with multiple nested fork/wait operations
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4(self):
        input = torch.ones(3, 3)
        num_forks = 10  # number of top-level forks (illustrative)
        num_child_forks = 10  # child forks per parent (illustrative)
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple nested fork/wait
    operations through the runAsync API, which returns a future
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4_async(self):
        input = torch.ones(3, 3)
        num_forks = 10  # number of top-level forks (illustrative)
        num_child_forks = 10  # child forks per parent (illustrative)
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test exception handling in a fork/wait
    operation. add.Tensor is called on tensors with
    non-matching dims in the forked subgraph, and the
    exception raised by the subgraph is set on the future returned
    by prim::fork to the parent graph. The returned exception is
    checked for the substring expected_error_msg as declared below.
    """
    def test_fork_wait_exception(self):
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module(input1, input2)
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # The test fails if the error does not contain the expected substring.
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f"not contain expected substring: \"{expected_error_msg}\""
                ) from error

    """
    Test Case: To test exception handling in a fork/wait
    operation with the runAsync API. add.Tensor is called on
    tensors with non-matching dims in the forked subgraph,
    and the exception raised by the subgraph is set on the future returned
    by prim::fork to the parent graph. The returned exception is
    checked for the substring expected_error_msg as declared below.
    """
    def test_fork_wait_exception_async(self):
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # The test fails if the error does not contain the expected substring.
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f"not contain expected substring: \"{expected_error_msg}\""
                ) from error

    def test_multihead_attention_layer(self):
        # Model/input sizes are illustrative; the original constants are elided.
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 8
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_close(a, b)

    def test_multihead_attention_layer_benchmark(self):
        # Same illustrative sizes as test_multihead_attention_layer.
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 8
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        attention_a = StaticModule(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        metrics = attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 2, 2
        )

    def test_mlp(self):
        # MLP sizes follow a DLRM-style benchmark; the activation indices
        # sigmoid_bot/sigmoid_top are reconstructed and illustrative.
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticModule(top_l)
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)
            top_inp = torch.randn(2048, 100)
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)
        # Deliberate second pass with fresh inputs: runs the static modules
        # again after their first-run memory planning is in place.
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)
            top_inp = torch.randn(2048, 100)
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    def test_leaky_relu(self):
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    def test_attrs(self):
        """
        TorchScript IR of TestModule() after freezing:

        graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
              %x.1 : Tensor):
            %18 : int = prim::Constant[value=30]()
            %30 : int = prim::Constant[value=13]()
            %3 : int = prim::Constant[value=20]()
            %2 : int = prim::Constant[value=1]()
            %self.sub2.a : int = prim::Constant[value=12]()
            %self.a : int = prim::Constant[value=3]()
            = prim::SetAttr[name="b"](%self, %3)
            %17 : Tensor = aten::add(%x.1, %30, %2)
            %7 : Tensor = aten::add(%17, %self.a, %2)
            %b.1 : int = prim::GetAttr[name="b"](%self)
            %9 : Tensor = aten::add(%7, %b.1, %2)
            %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
            = prim::SetAttr[name="b"](%sub2, %18)
            %b : int = prim::GetAttr[name="b"](%sub2)
            %22 : int = aten::add(%self.sub2.a, %b)
            %23 : Tensor = aten::add(%x.1, %22, %2)
            %12 : Tensor = aten::add(%9, %23, %2)
            return (%12)
        """
        m = TestModule()
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        # Same illustrative sizes as test_multihead_attention_layer.
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 8
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention.eval()
        attention = torch.jit.script(attention)
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4  # loop iteration count (illustrative)
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4  # loop iteration count (illustrative)
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        # The class and module bodies are reconstructed minimally: the point
        # of the test is that forward() creates a fresh object.
        class Foo:
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return y * foo.x

        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1, ))
        expected = mod(y)

        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)


if __name__ == "__main__":
    run_tests()