# Owner(s): ["module: unknown"]

import torch
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_TORCHDYNAMO
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
import torch.utils.flop_counter
import torch.nn.functional as F
import unittest
import functools

try:
    from torchvision import models as torchvision_models
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

HAS_CUDA = torch.cuda.is_available()

def FlopCounterMode(*args, **kwargs):
    return torch.utils.flop_counter.FlopCounterMode(*args, **kwargs, display=False)

def get_total_flops(mode):
    return str(sum(mode.flop_counts["Global"].values()))

def T(*shape, requires_grad=False):
    return torch.randn(*shape, requires_grad=requires_grad)

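# NB: torch.utils.flop_counter counts one multiply-accumulate as 2 FLOPs, so an
# (M, K) @ (K, N) matmul is counted as 2 * M * K * N FLOPs. The expected values
# below all follow from that convention.
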
@unittest.skipIf(TEST_WITH_TORCHDYNAMO, "torchdynamo doesn't work with __torch_dispatch__ right now")
class TestFlopCounter(TestCase):
    def test_flop_counter_variety(self):
        mode = FlopCounterMode()
        mod = torch.nn.Linear(9, 10)
        with mode:
            torch.mm(T(4, 5), T(5, 6))
            torch.addmm(T(4, 6), T(4, 5), T(5, 6), beta=0.5, alpha=0.5)
            torch.matmul(T(5, 6), T(6, 7))
            torch.einsum("ab,bc->ac", T(6, 7), T(7, 8))
            mod(T(8, 9))

        self.assertExpectedInline(get_total_flops(mode), """3012""")

    def test_op(self):
        mode = FlopCounterMode()
        with mode:
            torch.mm(T(4, 5), T(5, 6))
        # 4 * 6 * 2 * 5 = 240
        self.assertExpectedInline(get_total_flops(mode), """240""")

        with mode:
            torch.bmm(T(3, 4, 5), T(3, 5, 6))
        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.addmm(T(4, 6), T(4, 5), T(5, 6))
            torch.addmm(T(4, 1), T(4, 5), T(5, 6))
            torch.addmm(T(6), T(4, 5), T(5, 6))

        # 4 * 6 * 2 * 5 = 240 per addmm; three calls = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.baddbmm(T(3, 4, 6), T(3, 4, 5), T(3, 5, 6))

        # 3 * 4 * 6 * 2 * 5 = 720
        self.assertExpectedInline(get_total_flops(mode), """720""")

        with mode:
            torch.conv2d(T(2, 3, 6, 6), T(6, 3, 4, 4), padding=1)

        # out_image_size = 2 * 5 * 5
        # kernel_size = 4 * 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """28800""")

        with mode:
            torch.conv1d(T(2, 3, 6), T(6, 3, 4), padding=1)

        # out_image_size = 2 * 5
        # kernel_size = 4
        # c_out = 6
        # c_in = 3
        # out_image_size * kernel_size * c_out * 2 * c_in

        # NB: I don't think this properly accounts for padding?
        self.assertExpectedInline(get_total_flops(mode), """1440""")

    def test_backward(self):
        mode = FlopCounterMode()
        with mode:
            a = T(4, 5, requires_grad=True)
            a = torch.mm(a, T(5, 6))
            a = a.unsqueeze(0).expand(7, 4, 6)
            a = torch.bmm(a, T(7, 6, 7))
            a.sum().backward()

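        # mm forward: 2 * 4 * 5 * 6 = 240; bmm forward: 7 * (2 * 4 * 6 * 7) = 2352.
        # Only `a` requires grad, so the backward adds one bmm (2352) and one mm (240):
        # 240 + 2352 + 2352 + 240 = 5184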
        self.assertExpectedInline(get_total_flops(mode), """5184""")

    def test_torchscript(self):
        def foo(x):
            return torch.mm(x, x)
        mode = FlopCounterMode()
        with mode:
            foo(T(5, 5))
        unscripted_flops = get_total_flops(mode)
        ts_foo = torch.jit.script(foo)
        with mode:
            ts_foo(T(5, 5))
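        # Scripted code still dispatches the same aten ops underneath, so the
        # counts should match the eager run exactly.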
        self.assertEqual(unscripted_flops, get_total_flops(mode))

    def test_autograd_op(self):
        class _CustomOp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input: torch.Tensor) -> torch.Tensor:
                return torch.mm(input, input)

            @staticmethod
            def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
                return torch.mm(grad_output, grad_output) + torch.mm(grad_output, grad_output)

        a = T(5, 5, requires_grad=True)
        mode = FlopCounterMode()
        with mode:
            a = _CustomOp.apply(a)
            a.sum().backward()

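        # forward: one 5x5 mm = 2 * 5^3 = 250; backward: two mms = 500; total 750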
        self.assertExpectedInline(get_total_flops(mode), """750""")

    def test_conv_backwards_as_decomposition(self):
        # [conv backwards decomposition as conv forwards]
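        # This checks numerically that the gradients of a conv can themselves be
        # computed with conv/conv_transpose calls, which is the justification for
        # pricing convolution_backward in units of forward convolutions.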

        class onlyConvs(torch.autograd.Function):
            @staticmethod
            def forward(inp, weight, transposed):
                if not transposed:
                    return F.conv1d(inp, weight)
                else:
                    return F.conv_transpose1d(inp, weight)

            @staticmethod
            def setup_context(ctx, inputs, output):
                inp, weight, transposed = inputs
                ctx.save_for_backward(inp, weight)
                ctx.transposed = transposed

            @staticmethod
            def backward(ctx, grad_out):
                inp, weight = ctx.saved_tensors
                if not ctx.transposed:
                    grad_inp = F.conv_transpose1d(grad_out, weight)
                    grad_weight = F.conv1d(inp, grad_out)
                    return grad_inp, grad_weight, None
                else:
                    grad_inp = F.conv1d(grad_out, weight)
                    grad_weight = F.conv1d(grad_out.transpose(1, 0), inp.transpose(1, 0))
                    return grad_inp, grad_weight.transpose(1, 0), None

        from torch.func import grad
        x = torch.randn(2, 3, 16, dtype=torch.float64)
        weight = torch.randn(3, 4, 4, dtype=torch.float64)

        def boring_conv(x, weight, transposed):
            if not transposed:
                return F.conv1d(x, weight).pow(2).sum()
            else:
                return F.conv_transpose1d(x, weight).pow(2).sum()

        def only_convs(x, weight, transposed):
            return onlyConvs.apply(x, weight, transposed).pow(2).sum()

        boring_grads = grad(boring_conv, argnums=(0, 1))(x, weight, True)
        fun_grads = grad(only_convs, argnums=(0, 1))(x, weight, True)

        self.assertEqual(boring_grads, fun_grads)

    def test_convs(self):
        def assert_equivalence(f, expected_forward=None):
            mode = FlopCounterMode()
            with mode:
                f()
            conv_forward_flops = mode.get_flop_counts()['Global'][torch.ops.aten.convolution]
            conv_backward_flops = mode.get_flop_counts()['Global'][torch.ops.aten.convolution_backward]

            self.assertEqual(conv_forward_flops * 2, conv_backward_flops)
            if expected_forward is not None:
                self.assertEqual(conv_forward_flops, expected_forward)

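        # conv backward computes grad_input and grad_weight, each roughly the
        # cost of one forward conv, hence the 2x check above.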
        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 2, 2, requires_grad=True)
        assert_equivalence(lambda: F.conv_transpose2d(x, weight).sum().backward(), 32)

        x = torch.rand(1, 1, 2, 2, requires_grad=True)
        weight = torch.randn(1, 1, 1, 1, requires_grad=True)
        assert_equivalence(lambda: F.conv2d(x, weight).sum().backward(), 8)

        for in_channels, out_channels, groups in [
            (1, 1, 1),
            (1, 3, 1),
            (3, 1, 1),
            (3, 7, 1),
            (2, 4, 2),
            (4, 2, 2),
        ]:
            x = torch.rand(1, in_channels, 4, 4, requires_grad=True)
            weight = torch.randn(out_channels, in_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv2d(x, weight).sum().backward())
            transposed_weight = torch.randn(in_channels, out_channels, 2, 2, requires_grad=True)
            assert_equivalence(lambda: F.conv_transpose2d(x, transposed_weight).sum().backward())

    @skipIfNoTorchVision
    def test_module(self):
        resnet18 = torchvision_models.resnet18()
        mode = FlopCounterMode(resnet18)
        with mode:
            a = T(1, 3, 224, 224, requires_grad=True)
            resnet18(a).sum().backward()

        self.assertExpectedInline(get_total_flops(mode), """10884440064""")
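        # convolution_backward is counted at 2x the forward convolution FLOPs,
        # hence the exact 2x ratio between the two layer1 counts below.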
        layer1_conv_flops = mode.flop_counts['ResNet.layer1'][torch.ops.aten.convolution]
        layer1_conv_back_flops = mode.flop_counts['ResNet.layer1'][torch.ops.aten.convolution_backward]
        self.assertExpectedInline(str(layer1_conv_flops), """924844032""")
        self.assertExpectedInline(str(layer1_conv_back_flops), """1849688064""")

    def test_conv_transpose_loop(self):
        x = torch.rand(1, 4, 30, 2)
        model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2)

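        # The counter accumulates across every call made inside the context,
        # so all 50 forward/backward iterations below are summed together.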
        mode = FlopCounterMode(model)
        with mode:
            for _ in range(50):
                out = model(x)
                out.sum().backward()
        self.assertExpectedInline(str(mode.get_total_flops()), """1536000""")

    def test_custom(self):
        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: lambda *args, out_shape: 5})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """5""")

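        # A handler with `_get_raw = True` is passed the actual output tensors
        # rather than just their shapes, so `out.numel()` works below:
        # a (4, 5) result has 20 elements.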
        def count(*args, out):
            return out.numel()
        count._get_raw = True

        mode = FlopCounterMode(custom_mapping={torch.ops.aten.add: count})
        with mode:
            a = T(4, 5)
            a + a

        self.assertExpectedInline(get_total_flops(mode), """20""")

    def test_noop(self):
        mode = FlopCounterMode()
        with mode:
            T(4, 5).cos()

    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
                     "Does not support all SDPA backends (pre-SM80 hardware on CUDA)")
    def test_sdpa(self):
        batch_size = 4
        n_heads = 8
        seq_len_q = 128
        seq_len_k = 256
        head_dim = 64
        head_dim_v = 64
        dtype = torch.float16

        torch.manual_seed(0)

        def get_flops(batch_size, n_heads, seq_len_q, seq_len_k, head_dim, head_dim_v, dtype, backend, with_backward=False):
            query = torch.randn(batch_size, n_heads, seq_len_q, head_dim, device='cuda', dtype=dtype, requires_grad=True)
            key = torch.randn(batch_size, n_heads, seq_len_k, head_dim, device='cuda', dtype=dtype, requires_grad=True)
            value = torch.randn(batch_size, n_heads, seq_len_k, head_dim_v, device='cuda', dtype=dtype, requires_grad=True)

            if backend == "math":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False)
            elif backend == "flash":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
            elif backend == "mem_efficient":
                backend = torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)

            mode = FlopCounterMode()
            with backend, mode:
                out = F.scaled_dot_product_attention(query, key, value, dropout_p=0, is_causal=True)
                if with_backward:
                    out.sum().backward()
            return int(get_total_flops(mode))

        # Sets seq_len_q == seq_len_k and dim_q == dim_v
        run_uniform_flops = functools.partial(get_flops, batch_size, n_heads, seq_len_q, seq_len_q, head_dim, head_dim, dtype)

        flops = [run_uniform_flops(backend, with_backward=False) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_math, flops_fw_flash, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_flash)
        self.assertEqual(flops_fw_math, flops_fw_efficient)

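        # Attention forward is two matmuls (q @ k.T and attn @ v), each counted
        # at 2 * batch_size * n_heads * seq_len^2 * head_dim FLOPs:
        # 4 * 4 * 8 * 128 * 128 * 64 = 134217728
        # (the counter does not discount the causal mask)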
        self.assertExpectedInline(str(flops_fw_math), """134217728""")

        flops = [run_uniform_flops(backend, with_backward=True) for backend in ["math", "flash", "mem_efficient"]]
        flops_fw_bw_math, flops_fw_bw_flash, flops_fw_bw_efficient = flops
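        # math backward is counted at 2x the forward (total 3x); flash and
        # mem_efficient backward at 2.5x the forward (total 3.5x = 7/2)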
        self.assertEqual(flops_fw_math * 3, flops_fw_bw_math)
        self.assertEqual(flops_fw_math * 7 // 2, flops_fw_bw_flash)
        self.assertEqual(flops_fw_bw_flash, flops_fw_bw_efficient)

        run_nonuniform_flops = functools.partial(get_flops, batch_size, n_heads, seq_len_q, seq_len_k, head_dim, head_dim_v, dtype)
        # Flash does not support non-uniform attention, i.e. seq_len_q != seq_len_k or dim_q != dim_v
        non_uniform_backends = ["math", "mem_efficient"]
        flops = [run_nonuniform_flops(backend, with_backward=False) for backend in non_uniform_backends]
        flops_fw_math, flops_fw_efficient = flops
        self.assertEqual(flops_fw_math, flops_fw_efficient)

        self.assertExpectedInline(str(flops_fw_math), """268435456""")

        flops = [run_nonuniform_flops(backend, with_backward=True) for backend in non_uniform_backends]
        flops_fw_bw_math, flops_fw_bw_efficient = flops
        self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
        self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")

    def test_hook_registration(self):
        model = torch.nn.Linear(100, 100)
        x = torch.randn(3, 100)

        flop_counter = FlopCounterMode(model)
        with flop_counter:
            self.assertEqual(len(model._forward_pre_hooks), 1)
            self.assertEqual(len(model._forward_hooks), 1)
            model(x).sum().backward()

        self.assertEqual(len(model._forward_pre_hooks), 0)
        self.assertEqual(len(model._forward_hooks), 0)

    def test_pytrees(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                x = x['a'].relu_()
                return {'a': torch.mm(x, x)}

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = Foo()
                self.b = Foo()

            def forward(self, x):
                return self.b(self.a(x))

        mod = Mod()
        mode = FlopCounterMode(mod)
        with mode:
            mod({'a': torch.randn(10, 10, requires_grad=True).clone()})['a'].sum().backward()
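        # Each Foo does one 10x10 mm forward (2000 FLOPs) and two mms in its
        # backward (4000); two Foos = 2 * 6000 = 12000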
        self.assertExpectedInline(str(mode.flop_counts['Mod'][torch.ops.aten.mm]), """12000""")

        class Mod2(torch.nn.Module):
            def forward(self, x):
                return (torch.mm(x, x),)

        mod = Mod2()
        mode = FlopCounterMode(mod)
        with mode:
            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
        self.assertExpectedInline(str(mode.flop_counts['Mod2'][torch.ops.aten.mm]), """6000""")


if __name__ == '__main__':
    run_tests()