pytorch / flop_counter.py (559 lines · 19.8 KB)
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from typing import List, Any, Dict, Optional, Union, NamedTuple
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.hooks import RemovableHandle
from torch._decomp import register_decomposition
from math import prod
from functools import wraps


__all__ = ["FlopCounterMode", "register_flop_formula"]

aten = torch.ops.aten

def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    return i

flop_registry: Dict[Any, Any] = {}

def shape_wrapper(f):
    @wraps(f)
    def nf(*args, out=None, **kwargs):
        args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
        return f(*args, out_shape=out_shape, **kwargs)
    return nf

def register_flop_formula(targets, get_raw=False):
    def register_fun(flop_formula):
        if not get_raw:
            flop_formula = shape_wrapper(flop_formula)
        register_decomposition(targets, registry=flop_registry, unsafe=True)(flop_formula)
        return flop_formula

    return register_fun
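# Usage sketch (not part of the original file): a custom FLOP formula can be
# attached to an ATen op with the decorator above. `aten.my_op` is hypothetical;
# by default the decorated function receives shapes (via shape_wrapper), not
# tensors.
#
#     @register_flop_formula(aten.my_op)
#     def my_op_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
#         return prod(a_shape)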

@register_flop_formula(aten.mm)
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for matmul."""
    # Inputs should be a list of length 2.
    # Inputs contain the shapes of the two matrices.
    m, k = a_shape
    k2, n = b_shape
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    return m * n * 2 * k
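# Worked example for mm_flop (illustrative, not from the original file): a
# (64, 128) @ (128, 256) matmul gives m * n * 2 * k = 64 * 256 * 2 * 128
# = 4,194,304 FLOPs.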

@register_flop_formula(aten.addmm)
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for addmm."""
    return mm_flop(a_shape, b_shape)

@register_flop_formula(aten.bmm)
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for the bmm operation."""
    # Inputs should be a list of length 2.
    # Inputs contain the shapes of the two tensors.
    b, m, k = a_shape
    b2, k2, n = b_shape
    assert b == b2
    assert k == k2
    # NB(chilli): Should be 2 * k - 1 technically for FLOPs.
    flop = b * m * n * 2 * k
    return flop

@register_flop_formula(aten.baddbmm)
def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Count flops for the baddbmm operation."""
    # Inputs should be a list of length 3.
    # Inputs contain the shapes of the three tensors.
    return bmm_flop(a_shape, b_shape)


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> int:
    """Count flops for convolution.

    Note that only multiplications are counted; computations for the bias are
    ignored. Flops for a transposed convolution are calculated as
    flops = prod(x_shape[2:]) * prod(w_shape) * batch_size * 2.
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): Whether the convolution is transposed.
    Returns:
        int: the number of flops
    """

    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    c_out, c_in, *filter_size = w_shape

    """
    General idea here is that for a regular conv, for each point in the output
    spatial dimension we convolve the filter with something (hence
    `prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by
    1. batch_size, 2. the cross product of input and weight channels.

    For the transpose, it's not each point in the *output* spatial dimension but
    each point in the *input* spatial dimension.
    """
    # NB(chilli): I don't think this properly accounts for padding :think:
    # NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
    flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
    return flop
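# Worked example for conv_flop_count (illustrative, not from the original file):
# a 3x3 conv with x_shape = (8, 3, 32, 32), w_shape = (16, 3, 3, 3) and
# out_shape = (8, 16, 32, 32) gives
# prod((32, 32)) * prod((3, 3)) * 8 * 16 * 3 * 2 = 7,077,888 FLOPs.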

@register_flop_formula([aten.convolution, aten._convolution])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
    """Count flops for convolution."""
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


@register_flop_formula(aten.convolution_backward)
def conv_backward_flop(
        grad_out_shape,
        x_shape,
        w_shape,
        _bias,
        _stride,
        _padding,
        _dilation,
        transposed,
        _output_padding,
        _groups,
        output_mask,
        out_shape) -> int:

    def t(shape):
        return [shape[1], shape[0]] + list(shape[2:])
    flop_count = 0

    """
    Let's say we have a regular 1D conv
    {A, B, C} [inp]
    {i, j} [weight]
    => (conv)
    {Ai + Bj, Bi + Cj} [out]

    And as a reminder, the transposed conv of the above is
    => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]

    For the backwards of conv, we now have
    {D, E} [grad_out]
    {A, B, C} [inp]
    {i, j} [weight]

    # grad_inp as conv_transpose(grad_out, weight)
    Let's first compute grad_inp. To do so, we can simply look at all the
    multiplications that each element of inp is involved in. For example, A is
    only involved in the first element of the output (and thus only depends upon
    D in grad_out), and C is only involved in the last element of the output
    (and thus only depends upon E in grad_out)

    {Di, Dj + Ei, Ej} [grad_inp]

    Note that this corresponds to the below conv_transpose. This gives us the
    output_mask[0] branch, which is grad_inp.

    {D, E} [inp (grad_out)]
    {i, j} [weight]
    => (conv_transpose)
    {Di, Dj + Ei, Ej} [out (grad_inp)]

    I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
    weight) as an exercise for the reader.

    # grad_weight as conv(inp, grad_out)
    To compute grad_weight, we again look at the terms in the output, which as
    a reminder is:
    => {Ai + Bj, Bi + Cj} [out]
    => {D, E} [grad_out]
    If we manually compute the gradient for the weights, we see it's
    {AD + BE, BD + CE} [grad_weight]

    This corresponds to the below conv
    {A, B, C} [inp]
    {D, E} [weight (grad_out)]
    => (conv)
    {AD + BE, BD + CE} [out (grad_weight)]

    # grad_weight of transposed conv as conv(grad_out, inp)
    As a reminder, the terms of the output of a transposed conv are:
    => {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
    => {D, E, F, G} [grad_out]

    Manually computing the gradient for the weights, we see it's
    {AD + BE + CF, AE + BF + CG} [grad_weight]

    This corresponds to the below conv
    {D, E, F, G} [inp (grad_out)]
    {A, B, C} [weight (inp)]
    => (conv)
    {AD + BE + CF, AE + BF + CG} [out (grad_weight)]

    For the full backwards formula, there are also some details involving
    transpose of the batch/channel dimensions and groups, but I skip those for
    the sake of brevity (and they're pretty similar to matmul backwards)

    Check [conv backwards decomposition as conv forwards]
    """
    # grad_inp as conv_transpose(grad_out, weight)
    if output_mask[0]:
        grad_input_shape = get_shape(out_shape[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)

    if output_mask[1]:
        grad_weight_shape = get_shape(out_shape[1])
        if transposed:
            # grad_weight of transposed conv as conv(grad_out, inp)
            flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False)
        else:
            # grad_weight as conv(inp, grad_out)
            flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False)

    return flop_count
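# Note (not in the original file): for a regular convolution with both
# output_mask[0] and output_mask[1] set, each branch above costs the same
# number of flops as the forward conv, so the backward pass totals roughly
# twice the forward count.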

def sdpa_flop_count(query_shape, key_shape, value_shape):
    """
    Count flops for self-attention.

    NB: We can assume that value_shape == key_shape
    """
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3
    total_flops = 0
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
    # scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
    return total_flops
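# Worked example for sdpa_flop_count (illustrative, not from the original file):
# with query/key/value shapes (1, 12, 1024, 64), each bmm above costs
# 12 * 1024 * 1024 * 2 * 64 = 1,610,612,736 FLOPs, so the forward pass counts
# 2 * 1,610,612,736 ≈ 3.22 GFLOPs.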


@register_flop_formula([aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention])
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for self-attention."""
    # NB: We aren't accounting for causal attention here
    return sdpa_flop_count(query_shape, key_shape, value_shape)


def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
    b, h, s_q, d_q = query_shape
    _b2, _h2, s_k, _d2 = key_shape
    _b3, _h3, _s3, d_v = value_shape
    _b4, _h4, _s4, _d4 = grad_out_shape
    assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
    assert d_v == _d4 and s_k == _s3 and s_q == _s4
    total_flops = 0
    # Step 1: We recompute the scores matrix.
    # q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))

    # Step 2: We propagate the gradients through the scores @ v operation.
    # gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
    total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
    # scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
    total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))

    # Step 3: We propagate the gradients through the q @ k operation.
    # gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
    total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
    # q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
    total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
    return total_flops
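# Note (not in the original file): the backward pass above is five batched
# matmuls (one to recompute the scores plus four gradient matmuls) versus the
# two in sdpa_flop_count, so it is counted at 2.5x the forward cost.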


@register_flop_formula([aten._scaled_dot_product_efficient_attention_backward, aten._scaled_dot_product_flash_attention_backward])
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Count flops for self-attention backward."""
    return sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape)

flop_registry = {
    aten.mm: mm_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.baddbmm: baddbmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    aten._scaled_dot_product_efficient_attention: sdpa_flop,
    aten._scaled_dot_product_flash_attention: sdpa_flop,
    aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
    aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
}
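# Note (not in the original file): this dict is the module-level default that
# FlopCounterMode consults; per-instance additions go through the
# `custom_mapping` argument of FlopCounterMode.__init__, which is merged on top
# of it.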

def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


# Define the suffixes for different orders of magnitude
suffixes = ["", "K", "M", "B", "T"]
# Thanks BingChat!
def get_suffix_str(number):
    # Find the index of the appropriate suffix based on the number of digits
    # with some additional overflow.
    # i.e. 1.01B should be displayed as 1001M, not 1.001B
    index = max(0, min(len(suffixes) - 1, (len(str(number)) - 2) // 3))
    return suffixes[index]

def convert_num_with_suffix(number, suffix):
    index = suffixes.index(suffix)
    # Divide the number by 1000^index and format it to three decimal places
    value = f"{number / 1000 ** index:.3f}"
    # Return the value and the suffix as a string
    return value + suffixes[index]

def convert_to_percent_str(num, denom):
    if denom == 0:
        return "0%"
    return f"{num / denom:.2%}"

def _pytreeify_preserve_structure(f):
    @wraps(f)
    def nf(args):
        flat_args, spec = tree_flatten(args)
        out = f(*flat_args)
        return tree_unflatten(out, spec)

    return nf

class FlopCounterMode(TorchDispatchMode):
    """
    ``FlopCounterMode`` is a context manager that counts the number of flops within its context.

    It does this using a ``TorchDispatchMode``.

    It also supports hierarchical output by passing a module (or list of
    modules) to FlopCounterMode on construction. If you do not need hierarchical
    output, you do not need to use it with a module.

    Example usage

    .. code-block:: python

        mod = ...
        inp = ...
        flop_counter = FlopCounterMode(mod)
        with flop_counter:
            mod(inp).sum().backward()

    """

    def __init__(
            self,
            mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
            depth: int = 2,
            display: bool = True,
            custom_mapping: Optional[Dict[Any, Any]] = None):
        self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
        self.depth = depth
        self.parents = ["Global"]
        self.in_backward = False
        self.display = display
        if custom_mapping is None:
            custom_mapping = {}
        if isinstance(mods, torch.nn.Module):
            mods = [mods]
        self.mods = mods
        # Keys will include the modules in `mods` and their submodules
        self._module_to_forward_hook_handles: Dict[nn.Module, _ForwardHookHandles] = {}
        self.flop_registry = {
            **flop_registry,
            **{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
        }

    def _register_forward_hooks(self):
        if self.mods is None:
            return
        for mod in self.mods:
            prefix = type(mod).__name__
            for name, module in dict(mod.named_modules()).items():
                if name == "":
                    name = prefix
                else:
                    name = ".".join([prefix, name])

                forward_pre_hook_handle = module.register_forward_pre_hook(self._enter_module(name))
                forward_hook_handle = module.register_forward_hook(self._exit_module(name))
                self._module_to_forward_hook_handles[module] = _ForwardHookHandles(
                    forward_pre_hook_handle, forward_hook_handle
                )

    def _deregister_forward_hooks(self):
        for forward_hook_handles in self._module_to_forward_hook_handles.values():
            forward_hook_handles[0].remove()
            forward_hook_handles[1].remove()
        self._module_to_forward_hook_handles.clear()

    def _enter_module(self, name):
        def f(module, inputs):
            out = _pytreeify_preserve_structure(self._create_pre_module(name))(inputs)
            return out

        return f

    def _exit_module(self, name):
        def f(module, inputs, outputs):
            outputs = _pytreeify_preserve_structure(self._create_post_module(name))(outputs)
            return outputs
        return f

    def _create_post_module(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                assert self.parents[-1] == name, f"{self.parents[-1]} is not {name}"
                self.parents.pop()
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.in_backward = True
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def _create_pre_module(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                if self.in_backward:
                    self.parents = ["Global"]
                    self.in_backward = False
                self.parents.append(name)
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert self.parents[-1] == name
                self.parents.pop()
                return grad_outs

        return PopState.apply
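    # Note (not in the original file): the two autograd.Functions above bracket
    # each module call. The pre-hook's PopState pushes the module name onto
    # self.parents during the forward pass and pops it in its backward; the
    # post-hook's PushState pops it in the forward and pushes it back in its
    # backward. This keeps the parent stack in sync with module nesting in both
    # the forward and backward passes.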

    def get_total_flops(self) -> int:
        return sum(self.flop_counts['Global'].values())

    def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
        """Return the flop counts as a dictionary of dictionaries.

        The outer dictionary is keyed by module name, and the inner dictionary
        is keyed by operation name.

        Returns:
            Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
        """
        return {k: dict(v) for k, v in self.flop_counts.items()}

    def get_table(self, depth=None):
        if depth is None:
            depth = self.depth
        if depth is None:
            depth = 999999

        import tabulate
        tabulate.PRESERVE_WHITESPACE = True
        header = ["Module", "FLOP", "% Total"]
        values = []
        global_flops = self.get_total_flops()
        global_suffix = get_suffix_str(global_flops)
        is_global_subsumed = False

        def process_mod(mod_name, depth):
            nonlocal is_global_subsumed

            total_flops = sum(self.flop_counts[mod_name].values())

            is_global_subsumed |= total_flops >= global_flops

            padding = " " * depth
            values = []
            values.append([
                padding + mod_name,
                convert_num_with_suffix(total_flops, global_suffix),
                convert_to_percent_str(total_flops, global_flops)
            ])
            for k, v in self.flop_counts[mod_name].items():
                values.append([
                    padding + " - " + str(k),
                    convert_num_with_suffix(v, global_suffix),
                    convert_to_percent_str(v, global_flops)
                ])
            return values

        for mod in self.flop_counts.keys():
            if mod == 'Global':
                continue
            mod_depth = mod.count(".") + 1
            if mod_depth > depth:
                continue

            cur_values = process_mod(mod, mod_depth - 1)
            values.extend(cur_values)

        # We do a bit of messing around here to only output the "Global" value
        # if there are any FLOPs in there that aren't already fully contained by
        # a module.
        if 'Global' in self.flop_counts and not is_global_subsumed:
            for idx, value in enumerate(values):
                values[idx][0] = " " + values[idx][0]

            values = process_mod('Global', 0) + values

        if len(values) == 0:
            values = [["Global", "0", "0%"]]

        return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))

    def __enter__(self):
        self.flop_counts.clear()
        self._register_forward_hooks()
        super().__enter__()
        return self

    def __exit__(self, *args):
        if self.display:
            print(self.get_table(self.depth))
        self._deregister_forward_hooks()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in self.flop_registry:
            flop_count_func = self.flop_registry[func_packet]
            flop_count = flop_count_func(*args, **kwargs, out=out)  # type: ignore[operator]
            if len(set(self.parents)) != len(self.parents):
                print(
                    "The module hierarchy tracking seems to be messed up. "
                    "Please file a bug or just run the flop counter without "
                    "tracking the module hierarchy (i.e. `with FlopCounterMode():`)"
                )
            for par in set(self.parents):
                self.flop_counts[par][func_packet] += flop_count

        return out

class _ForwardHookHandles(NamedTuple):
    forward_pre_hook_handle: RemovableHandle
    forward_hook_handle: RemovableHandle
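
# Usage sketch (not part of the original file): counting FLOPs for a small
# model. The module and shapes here are illustrative.
#
#     import torch
#     import torch.nn as nn
#     from torch.utils.flop_counter import FlopCounterMode
#
#     mod = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
#     inp = torch.randn(32, 64)
#     flop_counter = FlopCounterMode(mod, depth=2, display=True)
#     with flop_counter:
#         mod(inp).sum().backward()
#     total = flop_counter.get_total_flops()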