import unittest
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn

from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.static_module import StaticModule

def linear_shim(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output


torch.nn.functional.linear = linear_shim
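# The shim above replaces F.linear with an explicit matmul/add decomposition,
# so the scripted graphs below contain primitive aten ops rather than a single
# aten::linear call (rationale inferred; the file itself does not state it).
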
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        # self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]
        Q = self.fc_q(query)
        K = self.fc_k(key)
        V = self.fc_v(value)
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)
        # x = torch.matmul(self.dropout(attention), V)
        x = torch.matmul(attention, V)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.hid_dim)
        x = self.fc_o(x)
        return x, attention

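# Shape flow through forward, with illustrative sizes (batch 128, sequence
# length 8, hid_dim 256, 8 heads, head_dim 32): query [128, 8, 256] becomes
# Q [128, 8, 8, 32] after the view/permute; energy and attention are
# [128, 8, 8, 8]; merging heads gives x [128, 8, 256], which fc_o maps back
# to [128, 8, 256].
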
def create_mlp(ln, sigmoid_layer):
    layers = nn.ModuleList()
    for i in range(0, len(ln) - 1):
        n = ln[i]
        m = ln[i + 1]
        LL = nn.Linear(int(n), int(m), bias=True)
        mean = 0.0
        std_dev = np.sqrt(2 / (m + n))
        W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
        std_dev = np.sqrt(1 / m)
        bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
        LL.weight.data = torch.tensor(W, requires_grad=True)
        LL.bias.data = torch.tensor(bt, requires_grad=True)
        layers.append(LL)
        if i == sigmoid_layer:
            layers.append(nn.Sigmoid())
        else:
            layers.append(nn.ReLU())
    s = torch.jit.script(torch.nn.Sequential(*layers))
    s.eval()
    return s

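# Illustrative calls (the same values appear in test_mlp below):
#   bot_l = create_mlp([512, 512, 64], -1)             # Linear/ReLU only
#   top_l = create_mlp([100, 1024, 1024, 1024, 1], 3)  # Sigmoid after the last Linear
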
def trivial_graph(a, b, c):
    s = torch.tensor([[3, 3], [3, 3]])
    return a + b * c + s


def elementwise_square_addition(input1, input2):
    return input1 * input1 + input2 * input2

def fork_wait_graph1(input1, input2):
    fut = torch.jit.fork(elementwise_square_addition, input1, input2)
    return torch.jit.wait(fut)

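# torch.jit.fork schedules the callee asynchronously and returns a Future;
# torch.jit.wait blocks until it completes. Under torch.jit.script these
# lower to prim::fork / aten::wait nodes, which the tests below exercise
# through StaticModule.
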
def fork_wait_graph2(input1, input2):
    fut = torch.jit.fork(loop_graph, input1, input2, 5)
    return torch.jit.wait(fut)

"""
    graph with multiple fork/wait operations
    :param input: torch.tensor input to forked subgraph
    :param iters: number of future/wait pairs to be created
"""
def fork_wait_graph3(input, iters: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(iters):
        futures.append(torch.jit.fork(torch.neg, input))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

"""
    graph with multi-level fork/wait operations
    :param input: torch.tensor input to forked subgraph
    :param num_forks: number of top level forks
    :param num_child_forks: number of child forks per parent fork
"""
def fork_wait_graph4(input, num_forks: int, num_child_forks: int):
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for _ in range(num_forks):
        futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks))
    results = []
    for future in futures:
        results.append(torch.jit.wait(future))
    return torch.sum(torch.stack(results))

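# Fan-out arithmetic for fork_wait_graph4: num_forks top-level futures, each
# forking num_child_forks torch.neg tasks; e.g. num_forks=10 and
# num_child_forks=10 give 10 parent futures and 100 leaf fork/wait pairs.
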
def add_tensor(input1, input2):
    return input1 + input2


def fork_wait_graph_exception(input1, input2):
    fut = torch.jit.fork(add_tensor, input1, input2)
    return torch.jit.wait(fut)

def loop_graph(a, b, iters: int):
    c = a + b * 2
    for i in range(iters):
        c = c + b
        c *= 2
        c -= a
    return c

def output_graph(a, b, c, iters: int):
    s = torch.tensor([[3, 3], [3, 3]])
    k = a + b * c + s
    d: Dict[int, torch.Tensor] = {}
    for i in range(iters):
        d[i] = k + i
    return d

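# Example (illustrative): output_graph(a, b, c, 3) returns
# {0: k, 1: k + 1, 2: k + 2} where k = a + b * c + s.
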
class SubModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = 11
        self.b = 2

    def forward(self, x):
        return self.a + self.b + x


class SubModule2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = 12
        self.b = 2

    def forward(self, x):
        self.b = 30
        return self.a + self.b + x


class TestModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sub1 = SubModule()
        self.sub2 = SubModule2()
        self.a = 3
        self.b = 4

    def forward(self, x):
        self.b = 20
        return self.sub1(x) + self.a + self.b + self.sub2(x)

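# When TestModule is frozen, these attributes appear as the constants shown
# in the IR dump inside test_attr below: 13 for sub1's a + b, 20 for self.b,
# 30 for sub2.b, and 12 / 3 for sub2.a / self.a.
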
class TestStaticModule(TestCase):

    """
    Test Case: To test a simple fork/wait operation in a graph;
    fork is called on a simple addition of the input tensors
    """
    def test_fork_wait_1(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a simple fork/wait operation with the
    StaticRuntime runAsync API returning a future
    """
    def test_fork_wait_1_async(self):
        inp1 = torch.ones(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph1)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

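    # runAsync takes positional args as a tuple and kwargs as a dict and
    # returns a future; wait() blocks for completion and value() yields the
    # graph output, as exercised above.
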
    """
    Test Case: To test a fork/wait operation on a loop
    subgraph performing a mix of operations
    """
    def test_fork_wait_2(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(inp1, inp2)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a fork/wait operation on a loop subgraph
    with the StaticRuntime runAsync API returning a future
    """
    def test_fork_wait_2_async(self):
        inp1 = torch.randn(5, 5)
        inp2 = torch.randn(5, 5)
        torch_graph = torch.jit.script(fork_wait_graph2)
        output_ref = torch_graph(inp1, inp2)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((inp1, inp2), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test fork/wait in a graph containing
    multiple fork/wait operations
    """
    def test_fork_wait_3(self):
        input = torch.ones(3, 3)
        num_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module(input, num_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple fork/wait
    operations via the runAsync API returning a future
    """
    def test_fork_wait_3_async(self):
        input = torch.ones(3, 3)
        num_forks = 100
        torch_graph = torch.jit.script(fork_wait_graph3)
        output_ref = torch_graph(input, num_forks)
        static_runtime_module = StaticModule(torch_graph)
        output_test = static_runtime_module.runAsync((input, num_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test fork/wait in a graph containing
    multiple nested fork/wait operations
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module(input, num_forks, num_child_forks)
        torch.testing.assert_close(output_test, output_ref)

    """
    Test Case: To test a graph with multiple nested fork/wait
    operations via the runAsync API returning a future
    """
    @unittest.skip("Broken test: https://github.com/pytorch/pytorch/issues/109782")
    def test_fork_wait_4_async(self):
        input = torch.ones(3, 3)
        num_forks = 10
        num_child_forks = 10
        torch_graph = torch.jit.script(fork_wait_graph4)
        static_runtime_module = StaticModule(torch_graph)
        output_ref = torch_graph(input, num_forks, num_child_forks)
        output_test = static_runtime_module.runAsync(
            (input, num_forks, num_child_forks), {})
        output_test.wait()
        torch.testing.assert_close(output_test.value(), output_ref)

    """
    Test Case: To test exception handling in a fork/wait
    operation. add.Tensor is called on tensors with
    non-matching dims in the forked subgraph, and the
    exception raised by the subgraph is set on the future returned
    by prim::fork to the parent graph. The returned exception is
    checked for the substring expected_error_msg as declared below.
    """
    def test_fork_wait_exception(self):
        # incompatible tensor sizes for add: dimension 1 differs (7 vs 5)
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module(input1, input2)
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # the test fails if the error does not contain the expected substring
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f'not contain expected substring: "{expected_error_msg}"'
                ) from error

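    # The expected message is PyTorch's standard broadcasting error for
    # aten::add: shapes (4, 7) and (4, 5) differ at non-singleton dimension 1,
    # so the forked add raises and the exception propagates through the
    # future created by prim::fork.
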
    """
    Test Case: To test exception handling in a fork/wait
    operation with the runAsync API. add.Tensor is called on
    tensors with non-matching dims in the forked subgraph,
    and the exception raised by the subgraph is set on the future returned
    by prim::fork to the parent graph. The returned exception is
    checked for the substring expected_error_msg as declared below.
    """
    def test_fork_wait_exception_async(self):
        # incompatible tensor sizes for add: dimension 1 differs (7 vs 5)
        input1 = torch.randn(4, 7)
        input2 = torch.randn(4, 5)
        torch_graph = torch.jit.script(fork_wait_graph_exception)
        try:
            static_runtime_module = StaticModule(torch_graph)
            output_test = static_runtime_module.runAsync(
                (input1, input2), {})
        except Exception as error:
            expected_error_msg = (
                "The size of tensor a (7) must match the size "
                "of tensor b (5) at non-singleton dimension 1"
            )
            # the test fails if the error does not contain the expected substring
            if str(error).find(expected_error_msg) == -1:
                raise RuntimeError(
                    "Tried execution of add.Tensors with incompatible shape. "
                    "Exception raised by forked runtime execution does "
                    f'not contain expected substring: "{expected_error_msg}"'
                ) from error

    def test_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention = torch.jit.script(attention)
        o_ref = attention(src, src, src, src_mask)

        attention_a = StaticModule(attention)
        o_test = attention_a(src, src, src, src_mask)
        o_test_kw = attention_a(src, src, value=src, mask=src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

        for a, b in zip(o_ref, o_test_kw):
            torch.testing.assert_close(a, b)

    def test_multihead_attention_layer_benchmark(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention = torch.jit.script(attention)
        attention_a = StaticModule(attention)

        attention_a.benchmark([src, src, src, src_mask], {}, 2, 2)
        metrics = attention_a.benchmark_individual_ops(
            [src, src, src, src_mask], {}, 2, 2
        )

    def test_mlp(self):
        # Arguments taken from benchmark script, ./bench/dlrm/s.sh
        ln_bot = [512, 512, 64]
        sigmoid_bot = -1
        ln_top = [100, 1024, 1024, 1024, 1]
        sigmoid_top = 3
        bot_l = create_mlp(ln_bot, sigmoid_bot)
        bot_l_acc = StaticModule(bot_l)
        top_l = create_mlp(ln_top, sigmoid_top)
        top_l_acc = StaticModule(top_l)
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)
            top_inp = torch.randn(2048, 100)
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)
        # run the same checks again on fresh random inputs
        with torch.no_grad():
            bot_inp = torch.randn(2048, 512)
            top_inp = torch.randn(2048, 100)
            ref_bot = bot_l(bot_inp)
            acc_bot = bot_l_acc(bot_inp)
            torch.testing.assert_close(acc_bot, ref_bot)
            ref_top = top_l(top_inp)
            acc_top = top_l_acc(top_inp)
            torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    def test_leaky_relu(self):
        s = torch.randn(5, 5)
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    """
    TorchScript IR of TestModule() after freezing:
    graph(%self : __torch__.test_static_runtime.___torch_mangle_0.TestModule,
          %x.1 : Tensor):
        %18 : int = prim::Constant[value=30]()
        %30 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=20]()
        %2 : int = prim::Constant[value=1]()
        %self.sub2.a : int = prim::Constant[value=12]()
        %self.a : int = prim::Constant[value=3]()
        = prim::SetAttr[name="b"](%self, %3)
        %17 : Tensor = aten::add(%x.1, %30, %2)
        %7 : Tensor = aten::add(%17, %self.a, %2)
        %b.1 : int = prim::GetAttr[name="b"](%self)
        %9 : Tensor = aten::add(%7, %b.1, %2)
        %sub2 : __torch__.test_static_runtime.___torch_mangle_2.SubModule2 = prim::GetAttr[name="sub2"](%self)
        = prim::SetAttr[name="b"](%sub2, %18)
        %b : int = prim::GetAttr[name="b"](%sub2)
        %22 : int = aten::add(%self.sub2.a, %b)
        %23 : Tensor = aten::add(%x.1, %22, %2)
        %12 : Tensor = aten::add(%9, %23, %2)
        return (%12)
    """
    def test_attr(self):
        m = TestModule()
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)
        sm.benchmark([], {"x": input}, 2, 2)
        sm.benchmark_individual_ops([], {"x": input}, 2, 2)

    @unittest.skip("Temporarily disabled")
    def test_fusion_trivial_graph(self):
        s = torch.full((2, 2), 2)
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        torch._C._fuse_to_static_module(tg.graph)
        assert "StaticSubgraph" in str(tg.graph)
        o_test = tg(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_multihead_attention_layer(self):
        HID_DIM = 256
        QUERY_LEN = 8
        BATCH_SIZE = 128
        LAYERS = 3
        HEADS = 8
        DROPOUT = 0.1
        device = torch.device("cpu")
        attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device)
        with torch.no_grad():
            src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device)
            src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device)

        attention = torch.jit.script(attention)
        o_ref = attention(src, src, src, src_mask)

        torch._C._fuse_to_static_module(attention._c)
        o_test = attention(src, src, src, src_mask)

        for a, b in zip(o_ref, o_test):
            torch.testing.assert_close(a, b)

    @unittest.skip("Temporarily disabled")
    def test_fusion_loop(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5)
        c = 4
        lg = torch.jit.script(loop_graph)
        o_ref = lg(a, b, c)
        torch._C._fuse_to_static_module(lg.graph)
        assert "StaticSubgraph" in str(lg.graph)
        o_test = lg(a, b, c)
        torch.testing.assert_close(o_ref, o_test)

    @unittest.skip("Temporarily disabled")
    def test_fusion_outputs(self):
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        c = 4
        og = torch.jit.script(output_graph)
        o_ref = og(a, b, b, c)
        torch._C._fuse_to_static_module(og.graph)
        assert "StaticSubgraph" in str(og.graph)
        o_test = og(a, b, b, c)
        for i in o_ref.keys():
            torch.testing.assert_close(o_ref[i], o_test[i])

    def test_create_object(self):
        class Foo:
            def __init__(self, x: torch.Tensor) -> None:
                self.x = x

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y: torch.Tensor) -> torch.Tensor:
                foo = Foo(y)
                return foo.x + y

        mod = torch.jit.script(Mod()).eval()
        y = torch.randn((1, ))
        expected = mod(y)

        static_mod = StaticModule(torch.jit.freeze(mod))
        actual = static_mod(y)

        self.assertEqual(expected, actual)

if __name__ == "__main__":
    run_tests()