# Owner(s): ["module: unknown"]

import functools
import unittest

import torch
import torch.nn.functional as F
import torch.utils.flop_counter
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    PLATFORM_SUPPORTS_CUDNN_ATTENTION
)
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
    skipIfRocm,
)

try:
    from torchvision import models as torchvision_models

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

HAS_CUDA = torch.cuda.is_available()


def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)


def get_total_flops(mode):
    return str(sum(v for _, v in mode.flop_counts["Global"].items()))


def T(*shape, requires_grad=False):
    return torch.randn(*shape, requires_grad=requires_grad)


@unittest.skipIf(
    TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now"
)
class TestFlopCounter(TestCase):
    def test_flop_counter_variety(self):
        mod = torch.nn.Linear(9, 10)
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))

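        # mm: 4 * 6 * 2 * 5 = 240; addmm: 240; matmul: 5 * 7 * 2 * 6 = 420;
        # einsum: 6 * 8 * 2 * 7 = 672; Linear(9, 10) on (8, 9): 8 * 10 * 2 * 9 = 1440
        # total: 240 + 240 + 420 + 672 + 1440 = 3012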
        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        with FlopCounterMode() as mode:
            torch.mm(T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))

        # three addmm calls, each 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))

        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)

        # out_image_size = 2 * 5 * 5
        # kernel_size = 4 * 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)

        # out_image_size = 2 * 5
        # kernel_size = 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()

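        # forward: mm 4 * 6 * 2 * 5 = 240, bmm 7 * 4 * 7 * 2 * 6 = 2352
        # backward only computes grads for the inputs requiring grad: 2352 + 240
        # total: 2 * 2592 = 5184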
        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_backward_reset(self):
        with FlopCounterMode() as mode:
            a = T(4, 5, requires_grad=True)
            a.mm(a.t()).sum().backward()
            a.mm(a.t()).sum().backward()

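        # each call: forward mm (4, 5) x (5, 4) = 160, backward two mms = 320
        # two calls: 2 * 480 = 960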
        self.assertExpectedInline(get_total_flops(mode), """960""")

    def test_torchscript(self):
        def foo(x):
            return torch.mm(x, x)

        with FlopCounterMode() as mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with mode:
            ts_foo(T(5, 5))
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(
                    grad_output, grad_output
                )

        a = T(5, 5, requires_grad=True)
        with FlopCounterMode() as mode:
            a = _CustomOp.apply(a)
            a.sum().backward()

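        # forward: one 5x5 mm = 250; custom backward runs two 5x5 mms = 500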
        self.assertExpectedInline(get_total_flops(mode), """750""")

    def test_conv_backwards_as_decomposition(self):
        # [conv backwards decomposition as conv forwards]

        class onlyConvs(torch.autograd.Function):
            @staticmethod
            def forward(inp, weight, transposed):
                if not transposed:
                    return F.conv1d(inp, weight)
                else:
                    return F.conv_transpose1d(inp, weight)

            @staticmethod
            def setup_context(ctx, inputs, output):
                inp, weight, transposed = inputs
                ctx.save_for_backward(inp, weight)
                ctx.transposed = transposed

            @staticmethod
            def backward(ctx, grad_out):
                inp, weight = ctx.saved_tensors
                if not ctx.transposed:
                    grad_inp = F.conv_transpose1d(grad_out, weight)
                    grad_weight = F.conv1d(inp, grad_out)
                    return grad_inp, grad_weight, None
                else:
                    grad_inp = F.conv1d(grad_out, weight)
                    grad_weight = F.conv1d(
                        grad_out.transpose(1, 0), inp.transpose(1, 0)
                    )
                    return grad_inp, grad_weight.transpose(1, 0), None

        from torch.func import grad

        x = torch.randn(2, 3, 16, dtype=torch.float64)
        weight = torch.randn(3, 4, 4, dtype=torch.float64)

        def boring_conv(x, weight, transposed):
            if not transposed:
                return F.conv1d(x, weight).pow(2).sum()
            else:
                return F.conv_transpose1d(x, weight).pow(2).sum()

        def only_convs(x, weight, transposed):
            return onlyConvs.apply(x, weight, transposed).pow(2).sum()

        boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True)
        fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True)

        self.assertEqual(boring_grads, fun_grads)

    def test_convs(self):
        def assert_equivalence(f, expected_forward=None):
            with FlopCounterMode() as mode:
                f()
            conv_forward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution
            ]
            conv_backward_flops = mode.get_flop_counts()["Global"][
                torch.ops.aten.convolution_backward
            ]

            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
            if expected_forward is not None:
                self.assertEqual(conv_forward_flops, expected_forward)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)

        for in_channels, out_channels, groups in [
            (1, 1, 1),
            (1, 3, 1),
            (3, 1, 1),
            (3, 7, 1),
            (2, 4, 2),
            (4, 2, 2),
        ]:
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
            transposed_weight = torch.randn(
                in_channels, out_channels, 2, 2, requires_grad=True
            )
            assert_equivalence(
                lambda: F.conv_transpose2d(x, transposed_weight).sum().backward()
            )

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        with FlopCounterMode(resnet18) as mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
        layer1_conv_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution
        ]
        layer1_conv_back_flops = mode.flop_counts["ResNet.layer1"][
            torch.ops.aten.convolution_backward
        ]
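        # convolution_backward is counted as 2x the forward convolution FLOPs
        # (grad_input + grad_weight), matching the check in test_convs above.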
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

        with FlopCounterMode() as mode:
            for i in range(50):
                out = model(x)
                out.sum().backward()
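        # rough breakdown: forward 2 * (2 * 2) * 4 * 8 * (30 * 2) = 15360 per
        # iteration; x has requires_grad=False, so backward only adds the
        # grad_weight conv (another 15360), giving 50 * 30720 = 1536000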
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        mode = FlopCounterMode(
            custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5}
        )
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

        def count(*args, out_val):
            return out_val.numel()

        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

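        # with _get_raw, count receives the actual output tensor: (4, 5).numel() = 20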
        self.assertExpectedInline(get_total_flops(mode), """20""")

    def test_noop(self):
        with FlopCounterMode() as mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
        or not PLATFORM_SUPPORTS_CUDNN_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa(self):
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
            backend,
            with_backward=False,
        ):
            query = torch.randn(
                batch_size,
                n_heads,
                seq_len_q,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            key = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )
            value = torch.randn(
                batch_size,
                n_heads,
                seq_len_k,
                head_dim_v,
                device="cuda",
                dtype=dtype,
                requires_grad=True,
            )

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=True,
                    enable_mem_efficient=False,
                    enable_cudnn=False,
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True,
                    enable_math=False,
                    enable_mem_efficient=False,
                    enable_cudnn=False,
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=False,
                    enable_mem_efficient=True,
                    enable_cudnn=False,
                )
            elif backend == "cudnn":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=False,
                    enable_mem_efficient=False,
                    enable_cudnn=True,
                )

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(
                    query, key, value, dropout_p=0, is_causal=True
                )
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_q,
            head_dim,
            head_dim,
            dtype,
        )

        flops = [
            run_uniform_flops(backend, with_backward=False)
            for backend in ["math", "flash", "mem_efficient", "cudnn"]
        ]
        flops_fw_math, flops_fw_flash, flops_fw_efficient, flops_fw_cudnn = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)
        self.assertEqual(flops_fw_math, flops_fw_cudnn)

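        # forward: q @ k.T plus attn @ v, each 4 * 8 * 2 * 128 * 128 * 64 = 67108864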
        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [
            run_uniform_flops(backend, with_backward=True)
            for backend in ["math", "flash", "mem_efficient", "cudnn"]
        ]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient, flops_fw_bw_cudnn = flops
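        # math backward is counted as 2x forward (3x total); flash/mem_efficient/cudnn
        # also recompute the attention scores in backward, so they count 2.5x (3.5x total)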
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_cudnn)

        run_nonuniform_flops = functools.partial(
            get_flops,
            batch_size,
            n_heads,
            seq_len_q,
            seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        )
        # Flash does not support non-uniform attention, i.e. seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [
            run_nonuniform_flops(backend, with_backward=False)
            for backend in non_uniform_backends
        ]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [
            run_nonuniform_flops(backend, with_backward=True)
            for backend in non_uniform_backends
        ]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION
        or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_sdpa_nested_tensor(self):
        def get_flops(q, k, v, backend, with_backward=False):
            mode = FlopCounterMode()

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=True,
                    enable_mem_efficient=False,
                    enable_cudnn=False,
                )
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=True,
                    enable_math=False,
                    enable_mem_efficient=False,
                    enable_cudnn=False,
                )
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(
                    enable_flash=False,
                    enable_math=False,
                    enable_mem_efficient=True,
                    enable_cudnn=False,
                )

            with backend, mode:
                out = F.scaled_dot_product_attention(
                    q, k, v, dropout_p=0, is_causal=True
                )
                if with_backward:
                    if out.is_nested:
                        out.values().sum().backward()
                    else:
                        out.sum().backward()

            return int(get_total_flops(mode))

        def get_nested_inputs(
            batch_size,
            n_heads,
            max_seq_len_q,
            max_seq_len_k,
            head_dim,
            head_dim_v,
            dtype,
        ):
            q_lengths = torch.tensor(
                [
                    max_seq_len_q // 4,
                    max_seq_len_q // 4 * 2,
                    max_seq_len_q // 4 * 3,
                    max_seq_len_q // 4 * 4,
                ]
            )
            k_lengths = torch.tensor(
                [
                    max_seq_len_k // 4,
                    max_seq_len_k // 4 * 2,
                    max_seq_len_k // 4 * 3,
                    max_seq_len_k // 4 * 4,
                ]
            )
            q_offsets, k_offsets = (
                torch.cat((torch.tensor([0]), torch.cumsum(lengths, dim=0))).cuda()
                for lengths in (q_lengths, k_lengths)
            )
            q_values = torch.randn(
                q_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            k_values = torch.randn(
                k_offsets[-1],
                head_dim * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )
            v_values = torch.randn(
                k_offsets[-1],
                head_dim_v * n_heads,
                dtype=dtype,
                requires_grad=True,
                device="cuda",
            )

            q = torch.nested.nested_tensor_from_jagged(q_values, q_offsets)
            k = torch.nested.nested_tensor_from_jagged(k_values, k_offsets)
            v = torch.nested.nested_tensor_from_jagged(v_values, k_offsets)

            q = q.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            k = k.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
            v = v.view(batch_size, -1, n_heads, head_dim_v).transpose(1, 2)

            return q, k, v

        def get_dense_flops(q, k, v, backend, with_backward=False):
            def split_tensor(x):
                return (
                    y.unsqueeze(0).transpose(1, 2).detach().requires_grad_(True)
                    for y in x.transpose(1, 2).unbind(0)
                )

            q_tensors = split_tensor(q)
            k_tensors = split_tensor(k)
            v_tensors = split_tensor(v)

            flops = 0
            for q_i, k_i, v_i in zip(q_tensors, k_tensors, v_tensors):
                flops += get_flops(
                    q_i, k_i, v_i, backend=backend, with_backward=with_backward
                )

            return flops

        uniform_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 128,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        # max_seq_len_q != max_seq_len_k doesn't work for flash attention with dense tensors.
        differing_config = {
            "batch_size": 4,
            "n_heads": 8,
            "max_seq_len_q": 128,
            "max_seq_len_k": 256,
            "head_dim": 64,
            "head_dim_v": 64,
            "dtype": torch.float16,
        }

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=False,
            ),
        )

        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="flash",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**uniform_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )
        self.assertEqual(
            get_dense_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
            get_flops(
                *get_nested_inputs(**differing_config),
                backend="mem_efficient",
                with_backward=True,
            ),
        )

    @skipIfRocm  # Nested tensor
    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
    )
    def test_nested_attention_fake_tensors(self):
        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
        max_seqlen = 40
        with FakeTensorMode() as fake_mode:
            fake_x = fake_mode.from_tensor(x)
            fake_offsets = fake_mode.from_tensor(offsets)

            with FlopCounterMode() as fake_flop_counter_mode:
                torch.ops.aten._flash_attention_forward(
                    fake_x,
                    fake_x,
                    fake_x,
                    fake_offsets,
                    fake_offsets,
                    max_seqlen,
                    max_seqlen,
                    0.0,
                    False,
                    False,
                )

        dense_x = torch.randn(
            4, 40, 4, 16, dtype=torch.bfloat16, device="cuda"
        ).transpose(1, 2)

        with FlopCounterMode() as real_flop_counter_mode:
            torch.ops.aten._flash_attention_forward(
                dense_x,
                dense_x,
                dense_x,
                None,
                None,
                max_seqlen,
                max_seqlen,
                0.0,
                False,
                False,
            )

        self.assertEqual(
            int(get_total_flops(fake_flop_counter_mode)),
            int(get_total_flops(real_flop_counter_mode)),
        )

    def test_addmm_out(self):
        def f(x):
            y = torch.zeros(10, 10)
            return torch.mm(x, x, out=y)

        with FlopCounterMode() as mode:
            f(torch.randn(10, 10))

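        # mm with out=: 10 * 10 * 2 * 10 = 2000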
        self.assertExpectedInline(get_total_flops(mode), """2000""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        with FlopCounterMode() as mode:
            self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 1)
            self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 1)
            model(x).sum().backward()

        self.assertEqual(len(torch.nn.modules.module._global_forward_pre_hooks), 0)
        self.assertEqual(len(torch.nn.modules.module._global_forward_hooks), 0)

    def test_pytrees(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x["a"].relu_()
                return {"a": torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        with FlopCounterMode() as mode:
            mod({"a": torch.randn(10, 10, requires_grad=True).clone()})[
                "a"
            ].sum().backward()
        self.assertExpectedInline(
            (mode.flop_counts["Mod"][torch.ops.aten.mm]), """12000"""
        )

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        with FlopCounterMode() as mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
        self.assertExpectedInline(
            (mode.flop_counts["Mod2"][torch.ops.aten.mm]), """6000"""
        )

    def test_warning(self):
        mod = torch.nn.Linear(2, 2)
        with self.assertWarnsRegex(UserWarning, "not needed"):
            FlopCounterMode(mod)

    def test_custom_op(self):
        from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

        @torch.library.custom_op("mylib::foo", mutates_args=())
        def foo(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        called = 0

        with self.assertRaisesRegex(ValueError, "expected each target to be OpOverloadPacket"):
            register_flop_formula(torch.ops.mylib.foo.default)(lambda x: x)

        @register_flop_formula(torch.ops.mylib.foo)
        def formula(*args, **kwargs):
            nonlocal called
            called += 1
            return 9001

        x = torch.randn(3)
        with FlopCounterMode(display=False) as mode:
            y = foo(x)

        self.assertEqual(called, 1)
        self.assertExpectedInline(get_total_flops(mode), """9001""")


if __name__ == "__main__":
    run_tests()