import torch.nn.functional as F
import torch.utils.flop_counter
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
)
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)

try:
    from torchvision import models as torchvision_models

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

HAS_CUDA = torch.cuda.is_available()


def FlopCounterMode(*args, **kwargs):
    # Test-local wrapper: the real FlopCounterMode, but with display=False so
    # the FLOP table isn't printed during test runs.
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    return str(sum(v for _, v in mode.flop_counts["Global"].items()))


def T(*shape, requires_grad=False):
    return torch.randn(*shape, requires_grad=requires_grad)
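
# The counts asserted below follow the counter's 2-FLOPs-per-multiply-add
# convention; e.g. an (m, k) @ (k, n) matmul is counted as 2 * m * k * n:
#
#     with FlopCounterMode() as m:
#         torch.mm(T(4, 5), T(5, 6))  # counted as 2 * 4 * 5 * 6 = 240 FLOPs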


@unittest.skipIf(
    TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now"
)
class TestFlopCounter(TestCase):
    def test_flop_counter_variety(self):
        mod = torch.nn.Linear(9, 10)
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))
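
        # Expected total, term by term (2 FLOPs per multiply-add):
        #   mm:     2 * 4 * 5 * 6  = 240
        #   addmm:  2 * 4 * 5 * 6  = 240   (alpha/beta don't change the count)
        #   matmul: 2 * 5 * 6 * 7  = 420
        #   einsum: 2 * 6 * 7 * 8  = 672
        #   linear: 2 * 8 * 9 * 10 = 1440
        #                    total = 3012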
        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with FlopCounterMode() as mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
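        # bmm: 3 batches of (4, 5) @ (5, 6) -> 3 * 2 * 4 * 5 * 6 = 720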
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with FlopCounterMode() as mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))
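        # Three addmm calls; the broadcastable bias shape ((4, 6), (4, 1), or
        # (6,)) doesn't change the count: 3 * (2 * 4 * 5 * 6) = 720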
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with FlopCounterMode() as mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with FlopCounterMode() as mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)
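        # conv2d: output is (2, 6, 5, 5) for a 4x4 kernel with padding=1, so
        # 2 * batch(2) * out_ch(6) * out_pos(5 * 5) * in_ch(3) * k(4 * 4) = 28800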
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with FlopCounterMode() as mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)
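        # conv1d: output is (2, 6, 5), so
        # 2 * batch(2) * out_ch(6) * out_pos(5) * in_ch(3) * k(4) = 1440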
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()
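
        # Forward: mm 2*4*5*6 = 240 plus bmm 7 * (2*4*6*7) = 2352 -> 2592.
        # Backward needs one matmul per op (only `a` requires grad):
        # bmm grad 2352 + mm grad 240 = 2592 more. Total 5184.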
        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_backward_reset(self):
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a.mm(a.t()).sum().backward()
            a.mm(a.t()).sum().backward()
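
        # Each iteration: forward (4, 5) @ (5, 4) mm is 160 FLOPs; backward
        # flows through both operands (two mms, 320) -> 480 per iteration, 960 total.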
        self.assertExpectedInline(get_total_flops(mode), """960""")

    def test_torchscript(self):
        def foo(x):
            return torch.mm(x, x)

        with FlopCounterMode() as mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with FlopCounterMode() as mode:
            ts_foo(T(5, 5))
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(
                    grad_output, grad_output
                )

        a = T(5, 5, requires_grad=True)
        with FlopCounterMode() as mode:
            a = _CustomOp.apply(a)
            a.sum().backward()
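
        # Forward mm: 2 * 5**3 = 250; the custom backward runs two more (5, 5)
        # mms (the trailing add isn't counted) -> 250 + 2 * 250 = 750.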
        self.assertExpectedInline(get_total_flops(mode), """750""")

    def test_conv_backwards_as_decomposition(self):
        # Checks that a conv backward implemented purely out of conv forward
        # calls produces the same gradients as autograd's conv backward.
        class onlyConvs(torch.autograd.Function):
            @staticmethod
            def forward(inp, weight, transposed):
                if not transposed:
                    return F.conv1d(inp, weight)
                else:
                    return F.conv_transpose1d(inp, weight)

            @staticmethod
            def setup_context(ctx, inputs, output):
                inp, weight, transposed = inputs
                ctx.save_for_backward(inp, weight)
                ctx.transposed = transposed

            @staticmethod
            def backward(ctx, grad_out):
                inp, weight = ctx.saved_tensors
                if not ctx.transposed:
                    grad_inp = F.conv_transpose1d(grad_out, weight)
                    grad_weight = F.conv1d(inp, grad_out)
                    return grad_inp, grad_weight, None
                else:
                    grad_inp = F.conv1d(grad_out, weight)
                    grad_weight = F.conv1d(
                        grad_out.transpose(1, 0), inp.transpose(1, 0)
                    )
                    return grad_inp, grad_weight.transpose(1, 0), None

        from torch.func import grad

        x = torch.randn(2, 3, 16, dtype=torch.float64)
        weight = torch.randn(3, 4, 4, dtype=torch.float64)

        def boring_conv(x, weight, transposed):
            if not transposed:
                return F.conv1d(x, weight).pow(2).sum()
            else:
                return F.conv_transpose1d(x, weight).pow(2).sum()

        def only_convs(x, weight, transposed):
            return onlyConvs.apply(x, weight, transposed).pow(2).sum()

        boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True)
        fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True)

        self.assertEqual(boring_grads, fun_grads)

    def test_convs(self):
        def assert_equivalence(f, expected_forward=None):
            with FlopCounterMode() as mode:
                f()
            conv_forward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution
            ]
            conv_backward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution_backward
            ]

            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
            if expected_forward is not None:
                self.assertEqual(conv_forward_flops, expected_forward)
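
        # A conv's backward is itself roughly two convs (one for grad_input,
        # one for grad_weight), so with both grads required it should cost
        # exactly twice the forward convolution FLOPs.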
        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)

        for in_channels, out_channels, groups in [
            (1, 1, 1),
            (1, 3, 1),
            (3, 1, 1),
            (3, 7, 1),
            (2, 3, 2),
            (4, 3, 4),
        ]:
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
            transposed_weight = torch.randn(
                in_channels, out_channels, 2, 2, requires_grad=True
            )
            assert_equivalence(
                lambda: F.conv_transpose2d(x, transposed_weight).sum().backward()
            )

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        with FlopCounterMode(resnet18) as mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
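        # Counts are also attributed per module; layer1's backward convolutions
        # should cost exactly twice its forward convolutions.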
        layer1_conv_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution
        ]
        layer1_conv_back_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution_backward
        ]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

        with FlopCounterMode() as mode:
            for i in range(50):
                out = model(x)
                out.sum().backward()
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        mode = FlopCounterMode(
            custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5}
        )
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")
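
        # With _get_raw set, the custom formula receives the raw output tensor
        # (out_val) rather than just shapes, so it can report numel() = 4 * 5 = 20.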
        def count(*args, out_val):
            return out_val.numel()

        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """20""")

    def test_noop(self):
        with FlopCounterMode() as mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
        or not PLATFORM_SUPPORTS_CUDNN_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa(self):
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        def get_flops(
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
            backend,
            with_backward=False,
        ):
            query = torch.randn(
                batch_size, n_heads, seq_len_q, head_dim,
                device="cuda", dtype=dtype, requires_grad=True,
            )
            key = torch.randn(
                batch_size, n_heads, seq_len_k, head_dim,
                device="cuda", dtype=dtype, requires_grad=True,
            )
            value = torch.randn(
                batch_size, n_heads, seq_len_k, head_dim_v,
                device="cuda", dtype=dtype, requires_grad=True,
            )

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=True,
                    enable_mem_efficient=False,
                    enable_cudnn=False,
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True,
                    enable_math=False,
                    enable_mem_efficient=False,
                    enable_cudnn=False,
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=False,
                    enable_mem_efficient=True,
                    enable_cudnn=False,
                )
            elif backend == "cudnn":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=False,
                    enable_mem_efficient=False,
                    enable_cudnn=True,
                )

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0, is_causal=True
                )
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # seq_len_q == seq_len_k (self-attention); all four backends should agree.
        run_uniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_q,
            head_dim,
            head_dim_v,
            dtype,
        )

        flops = [
            run_uniform_flops(backend, with_backward=False)
            for backend in ["math", "flash", "mem_efficient", "cudnn"]
        ]
        flops_fw_math, flops_fw_flash, flops_fw_efficient, flops_fw_cudnn = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)
        self.assertEqual(flops_fw_math, flops_fw_cudnn)
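
        # Forward attention is two matmuls, q @ k.T and attn @ v:
        # 2 * (2 * B * H * S * S * D) = 4 * 4 * 8 * 128 * 128 * 64 = 134217728.
        # (The counter doesn't discount the causal mask.)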
        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [
            run_uniform_flops(backend, with_backward=True)
            for backend in ["math", "flash", "mem_efficient", "cudnn"]
        ]
        (
            flops_fw_bw_math,
            flops_fw_bw_flash,
            flops_fw_bw_efficient,
            flops_fw_bw_cudnn,
        ) = flops
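        # Math backward adds four more matmuls (2x the forward), giving 3x total;
        # flash/mem_efficient/cudnn also recompute q @ k.T during the backward,
        # adding another half-forward: 3.5x total.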
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_cudnn)

        # seq_len_q != seq_len_k (cross-attention); flash and cudnn don't cover
        # this non-uniform case here, so only math and mem_efficient are compared.
        run_nonuniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        )
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [
            run_nonuniform_flops(backend, with_backward=False)
            for backend in non_uniform_backends
        ]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)
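
        # Same formula with S_k = 256: 4 * 4 * 8 * 128 * 256 * 64 = 268435456.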
        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [
            run_nonuniform_flops(backend, with_backward=True)
            for backend in non_uniform_backends
        ]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        # 3x and 3.5x the forward count, as above.
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa_nested_tensor(self):
        def get_flops(q, k, v, backend, with_backward=False):
            mode = FlopCounterMode()

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=True,
                    enable_mem_efficient=False,
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True,
                    enable_math=False,
                    enable_mem_efficient=False,
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=False,
                    enable_mem_efficient=True,
                )

            with backend, mode:
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=0, is_causal=True
                )
                if with_backward:
                    if out.is_nested:
                        out.values().sum().backward()
                    else:
                        out.sum().backward()
            return int(get_total_flops(mode))

        def get_nested_inputs(
            batch_size,
            n_heads,
            max_seq_len_q,
            max_seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        ):
            q_lengths = torch.tensor(
                [
                    max_seq_len_q // 4,
                    max_seq_len_q // 4 * 2,
                    max_seq_len_q // 4 * 3,
                    max_seq_len_q // 4 * 4,
                ]
            )
            k_lengths = torch.tensor(
                [
                    max_seq_len_k // 4,
                    max_seq_len_k // 4 * 2,
                    max_seq_len_k // 4 * 3,
                    max_seq_len_k // 4 * 4,
                ]
            )
            q_offsets, k_offsets = (
                torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0))).cuda()
                for lengths in (q_lengths, k_lengths)
            )
            q_values = torch.randn(
                q_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            k_values = torch.randn(
                k_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            v_values = torch.randn(
                k_offsets[-1],
                head_dim_v * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )

            q = torch.nested.nested_tensor_from_jagged(q_values, q_offsets)
            k = torch.nested.nested_tensor_from_jagged(k_values, k_offsets)
            v = torch.nested.nested_tensor_from_jagged(v_values, k_offsets)

            q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, n_heads, head_dim_v).transpose(1, 2)

            return (q, k, v)

        def get_dense_flops(q, k, v, backend, with_backward=False):
            def split_tensor(x):
                return (
                    y.unsqueeze(0).transpose(1, 2).detach().requires_grad_(True)
                    for y in x.transpose(1, 2).unbind(0)
                )

            q_tensors = split_tensor(q)
            k_tensors = split_tensor(k)
            v_tensors = split_tensor(v)

            flops = 0
            for q_i, k_i, v_i in zip(q_tensors, k_tensors, v_tensors):
                flops += get_flops(
                    q_i, k_i, v_i, backend=backend, with_backward=with_backward
                )

            return flops

        uniform_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 128,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        # Same as uniform_config, but with differing q/k max sequence lengths.
        differing_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 256,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_nested_attention_fake_tensors(self):
        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
        max_seqlen = 40
        with FakeTensorMode() as fake_mode:
            fake_x = fake_mode.from_tensor(x)
            fake_offsets = fake_mode.from_tensor(offsets)

            with FlopCounterMode() as fake_flop_counter_mode:
                torch.ops.aten._flash_attention_forward(
                    fake_x,
                    fake_x,
                    fake_x,
                    fake_offsets,
                    fake_offsets,
                    max_seqlen,
                    max_seqlen,
                    0.0,
                    False,
                    False,
                )

        # Fake offsets carry no data, so the counter must assume every sequence
        # has max_seqlen tokens; the dense run below is sized to match that.
        dense_x = torch.randn(4, 40, 4, 16, dtype=torch.bfloat16, device="cuda").transpose(1, 2)

        with FlopCounterMode() as real_flop_counter_mode:
            torch.ops.aten._flash_attention_forward(
                dense_x,
                dense_x,
                dense_x,
                None,
                None,
                max_seqlen,
                max_seqlen,
                0.0,
                False,
                False,
            )

        self.assertEqual(
            int(get_total_flops(fake_flop_counter_mode)),
            int(get_total_flops(real_flop_counter_mode)),
        )

    def test_addmm_out(self):
        def f(x):
            y = torch.zeros(10, 10)
            return torch.mm(x, x, out=y)

        with FlopCounterMode() as mode:
            f(torch.randn(10, 10))
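
        # mm with an out= argument still counts: 2 * 10 * 10 * 10 = 2000.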
        self.assertExpectedInline(get_total_flops(mode), """2000""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        with FlopCounterMode() as mode:
            self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 1)
            self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 1)
            model(x).sum().backward()

        self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 0)
        self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 0)

    def test_pytrees(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        with FlopCounterMode() as mode:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
        self.assertExpectedInline(
            (mode.flop_counts["Mod"][torch.ops.aten.mm]), """12000"""
        )

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        with FlopCounterMode() as mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
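        # One (10, 10) mm forward (2000) plus two mms in its backward (4000)
        # -> 6000; Mod above runs two Foos, hence 12000.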
        self.assertExpectedInline(
            (mode.flop_counts["Mod2"][torch.ops.aten.mm]), """6000"""
        )

    def test_warning(self):
        mod = torch.nn.Linear(2, 2)
        with self.assertWarnsRegex(UserWarning, "not needed"):
            FlopCounterMode(mod)

    def test_custom_op(self):
        from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            # The op's body is immaterial to the formula; it just needs to
            # return a fresh tensor (custom ops may not alias their inputs).
            return x.clone()

        called = 0

        with self.assertRaisesRegex(ValueError, "expected each target to be OpOverloadPacket"):
            register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)

        @register_flop_formula(torch.ops.mylib.foo)
        def formula(*args, **kwargs):
            nonlocal called
            called += 1
            return 9001

        with FlopCounterMode(display=False) as mode:
            foo(torch.randn(3))

        self.assertEqual(called, 1)
        self.assertExpectedInline(get_total_flops(mode), """9001""")


if __name__ == "__main__":
    run_tests()