# Owner(s): ["module: unknown"]

import functools
import unittest

import torch
import torch.nn.functional as F
import torch.utils.flop_counter
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_TORCHDYNAMO
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION

try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

HAS_CUDA = torch.cuda.is_available()

def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)

def get_total_flops(mode):
    return str(sum(v for _, v in mode.flop_counts["Global"].items()))

def T(*shape, requires_grad=False):
    return torch.randn(*shape, requires_grad=requires_grad)
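
# NB: a matmul of (M, K) @ (K, N) is counted as M * N * 2 * K FLOPs throughout
# (one multiply and one add per reduction step of each output element).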

@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now")
class TestFlopCounter(TestCase):
    def test_flop_counter_variety(self):
        mode = FlopCounterMode()
        mod = torch.nn.Linear(9, 10)
        with mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))
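        # Breakdown (the (8, 9) input to mod is pinned down by the asserted
        # total: 3012 - (240 + 240 + 420 + 672) = 1440 = 8 * 10 * 2 * 9):
        #   mm:     4 * 6 * 2 * 5  = 240
        #   addmm:  4 * 6 * 2 * 5  = 240  (alpha/beta scaling is not counted)
        #   matmul: 5 * 7 * 2 * 6  = 420
        #   einsum: 6 * 8 * 2 * 7  = 672
        #   linear: 8 * 10 * 2 * 9 = 1440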
        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        mode = FlopCounterMode()
        with mode:
            torch.mm(T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))
        # Each addmm counts 4 * 6 * 2 * 5 = 240 regardless of how the bias
        # broadcasts (the bias add itself is not counted): 3 * 240 = 720.
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)
        # out_image_size = 2 * 5 * 5
        # kernel_size = 4 * 4
        # out_image_size * kernel_size * c_out * 2 * c_in
        # NB: I don't think this properly accounts for padding?
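        # Plugging in: (2 * 5 * 5) * (4 * 4) * 6 * 2 * 3 = 28800.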
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)
        # out_image_size = 2 * 5
        # kernel_size = 4
        # out_image_size * kernel_size * c_out * 2 * c_in
        # NB: I don't think this properly accounts for padding?
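        # Plugging in: (2 * 5) * 4 * 6 * 2 * 3 = 1440.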
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        mode = FlopCounterMode()
        with mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()
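        # Forward: mm = 4 * 6 * 2 * 5 = 240, bmm = 7 * 4 * 7 * 2 * 6 = 2352.
        # Only the left operands require grad, so backward costs one matmul of
        # the same size per op: 2 * (240 + 2352) = 5184.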
        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_torchscript(self):
        def foo(x):
            return torch.mm(x, x)

        mode = FlopCounterMode()
        with mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with mode:
            ts_foo(T(5, 5))
        # Scripting must not change what gets counted (the input shape itself
        # is arbitrary; only scripted == unscripted is checked).
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(grad_output, grad_output)

        a = T(5, 5, requires_grad=True)
        mode = FlopCounterMode()
        with mode:
            a = _CustomOp.apply(a)
            a.sum().backward()
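        # One mm in forward (5 * 5 * 2 * 5 = 250) plus two mms in the custom
        # backward (2 * 250 = 500): 750 total.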
        self.assertExpectedInline(get_total_flops(mode), """750""")

    def test_conv_backwards_as_decomposition(self):
        # [conv backwards decomposition as conv forwards]
        # Check that expressing conv1d's backward purely in terms of conv1d /
        # conv_transpose1d forwards produces the same gradients as autograd.
        class onlyConvs(torch.autograd.Function):
            @staticmethod
            def forward(inp, weight, transposed):
                if not transposed:
                    return F.conv1d(inp, weight)
                else:
                    return F.conv_transpose1d(inp, weight)

            @staticmethod
            def setup_context(ctx, inputs, output):
                inp, weight, transposed = inputs
                ctx.save_for_backward(inp, weight)
                ctx.transposed = transposed

            @staticmethod
            def backward(ctx, grad_out):
                inp, weight = ctx.saved_tensors
                if not ctx.transposed:
                    grad_inp = F.conv_transpose1d(grad_out, weight)
                    grad_weight = F.conv1d(inp, grad_out)
                    return grad_inp, grad_weight, None
                else:
                    grad_inp = F.conv1d(grad_out, weight)
                    grad_weight = F.conv1d(grad_out.transpose(1, 0), inp.transpose(1, 0))
                    return grad_inp, grad_weight.transpose(1, 0), None

        from torch.func import grad
        x = torch.randn(2, 3, 16, dtype=torch.float64)
        weight = torch.randn(3, 4, 4, dtype=torch.float64)

        def boring_conv(x, weight, transposed):
            if not transposed:
                return F.conv1d(x, weight).pow(2).sum()
            else:
                return F.conv_transpose1d(x, weight).pow(2).sum()

        def only_convs(x, weight, transposed):
            return onlyConvs.apply(x, weight, transposed).pow(2).sum()

        boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True)
        fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True)

        self.assertEqual(boring_grads, fun_grads)

    def test_convs(self):
        def assert_equivalence(f, expected_forward=None):
            mode = FlopCounterMode()
            with mode:
                f()
            conv_forward_flops = mode.get_flop_counts()['Global'][torch.ops.aten.convolution]
            conv_backward_flops = mode.get_flop_counts()['Global'][torch.ops.aten.convolution_backward]

            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
            if expected_forward is not None:
                self.assertEqual(conv_forward_flops, expected_forward)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)

        # Sweep a few channel configurations; the tuples below are
        # illustrative stand-ins (groups is carried along but not exercised by
        # the shapes used here, so it is kept at 1).
        for in_channels, out_channels, groups in [
            (1, 1, 1),
            (1, 4, 1),
            (4, 1, 1),
            (4, 4, 1),
        ]:
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
            transposed_weight = torch.randn(in_channels, out_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv_transpose2d(x, transposed_weight).sum().backward())

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        mode = FlopCounterMode(resnet18)
        with mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
        layer1_conv_flops = mode.flop_counts['ResNet.layer1'][torch.ops.aten.convolution]
        layer1_conv_back_flops = mode.flop_counts['ResNet.layer1'][torch.ops.aten.convolution_backward]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

        mode = FlopCounterMode(model)
        with mode:
            # Counts accumulate across iterations within a single `with` block.
            for i in range(50):
                out = model(x)
                out.sum().backward()
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

        def count(*args, out):
            return out.numel()
        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """20""")
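        # NB: a custom handler receives shape information by default (hence the
        # `out_shape` kwarg in the lambda above); marking it with
        # `_get_raw = True` makes the mode pass the raw tensors instead, which
        # is what lets `count` call `out.numel()` (4 * 5 = 20).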

    def test_noop(self):
        # Running an op with no registered FLOP formula should simply count
        # nothing rather than error out.
        mode = FlopCounterMode()
        with mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
                     "Does not support all SDPA backends (pre-SM80 hardware on CUDA)")
    def test_sdpa(self):
        # Sizes chosen so the uniform forward comes out to exactly 2 ** 27
        # FLOPs (see the breakdown after get_flops).
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        def get_flops(batch_size, n_heads, seq_len_q, seq_len_k, head_dim, head_dim_v, dtype, backend, with_backward=False):
            query = torch.randn(batch_size, n_heads, seq_len_q, head_dim, device='cuda', dtype=dtype, requires_grad=True)
            key = torch.randn(batch_size, n_heads, seq_len_k, head_dim, device='cuda', dtype=dtype, requires_grad=True)
            value = torch.randn(batch_size, n_heads, seq_len_k, head_dim_v, device='cuda', dtype=dtype, requires_grad=True)

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=0, is_causal=True)
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))
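        # The counter charges SDPA's two matmuls (q @ k.T and attn @ v):
        #   q @ k.T:  batch * heads * seq_len_q * seq_len_k * 2 * head_dim
        #   attn @ v: batch * heads * seq_len_q * head_dim_v * 2 * seq_len_k
        # Uniform case: 2 * (4 * 8 * 128 * 128 * 2 * 64) = 2 ** 27 = 134217728.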

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(get_flops, batch_size, n_heads, seq_len_q, seq_len_q, head_dim, head_dim, dtype)

        flops = [run_uniform_flops(backend, with_backward=False) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [run_uniform_flops(backend, with_backward=True) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)
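        # Math backward costs 2x the forward, so fw + bw = 3x. Flash and
        # mem-efficient recompute the attention matrix in backward, adding half
        # a forward on top: fw + bw = 3.5x, hence the * 7 // 2 above.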

        run_nonuniform_flops = functools.partial(get_flops, batch_size, n_heads, seq_len_q, seq_len_k, head_dim, head_dim_v, dtype)
        # Flash does not support non-uniform attention, i.e. seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [run_nonuniform_flops(backend, with_backward=False) for backend in non_uniform_backends]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [run_nonuniform_flops(backend, with_backward=True) for backend in non_uniform_backends]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        # 3x and 3.5x the non-uniform forward count, as above.
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        flop_counter = FlopCounterMode(model)
        with flop_counter:
            # While the mode is active, the counting hooks are installed...
            self.assertEqual(len(model._forward_pre_hooks), 1)
            self.assertEqual(len(model._forward_hooks), 1)
            model(x).sum().backward()

        # ...and on exit they must be removed again.
        self.assertEqual(len(model._forward_pre_hooks), 0)
        self.assertEqual(len(model._forward_hooks), 0)

    def test_pytrees(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x['a']
                return {'a': torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        mode = FlopCounterMode(mod)
        with mode:
            mod({'a': torch.randn(10, 10, requires_grad=True).clone()})['a'].sum().backward()
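        # Each Foo runs one mm on (10, 10): 10 * 10 * 2 * 10 = 2000 forward
        # plus two same-sized mms in backward = 4000. Two Foos under 'Mod':
        # 2 * (2000 + 4000) = 12000.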
        self.assertExpectedInline(str(mode.flop_counts['Mod'][torch.ops.aten.mm]), """12000""")

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        mode = FlopCounterMode(mod)
        with mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
        self.assertExpectedInline(str(mode.flop_counts['Mod2'][torch.ops.aten.mm]), """6000""")

if __name__ == '__main__':
    run_tests()