3
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
4
from typing import List, Any, Dict, Optional, Union, NamedTuple
5
from collections import defaultdict
6
from torch.utils._python_dispatch import TorchDispatchMode
7
from torch.utils.hooks import RemovableHandle
8
from torch._decomp import register_decomposition
10
from functools import wraps
14
__all__ = ["FlopCounterMode", "register_flop_formula"]
19
if isinstance(i, torch.Tensor):
23
flop_registry: Dict[Any, Any] = {}
27
def nf(*args, out=None, **kwargs):
28
args, kwargs, out_shape = tree_map(get_shape, (args, kwargs, out))
29
return f(*args, out_shape=out_shape, **kwargs)
32
def register_flop_formula(targets, get_raw=False):
33
def register_fun(flop_formula):
35
flop_formula = shape_wrapper(flop_formula)
36
register_decomposition(targets, registry=flop_registry, unsafe=True)(flop_formula)
41
@register_flop_formula(aten.mm)
42
def mm_flop(a_shape, b_shape, *args, out_shape=None, **kwargs) -> int:
43
"""Count flops for matmul."""
44
# Inputs should be a list of length 2.
45
# Inputs contains the shapes of two matrices.
49
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
52
@register_flop_formula(aten.addmm)
def addmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Return the FLOP count for ``aten.addmm``.

    Only the matrix product of the ``a`` and ``b`` operands is counted, by
    delegating to ``mm_flop``; ``self_shape`` is accepted but not used, so
    the addition of ``self`` contributes nothing to the count.
    """
    product_flops = mm_flop(a_shape, b_shape)
    return product_flops
57
@register_flop_formula(aten.bmm)
58
def bmm_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int:
59
"""Count flops for the bmm operation."""
60
# Inputs should be a list of length 2.
61
# Inputs contains the shapes of two tensor.
66
# NB(chilli): Should be 2 * k - 1 technically for FLOPs.
67
flop = b * m * n * 2 * k
70
@register_flop_formula(aten.baddbmm)
def baddbmm_flop(self_shape, a_shape, b_shape, out_shape=None, **kwargs) -> int:
    """Return the FLOP count for ``aten.baddbmm``.

    Receives the shapes of the three tensor inputs. Only the batched matrix
    product of ``a`` and ``b`` is counted (via ``bmm_flop``); ``self_shape``
    is not used.
    """
    batched_product_flops = bmm_flop(a_shape, b_shape)
    return batched_product_flops
82
transposed: bool = False,
84
"""Count flops for convolution.
86
Note that each multiply-accumulate is counted as two FLOPs (one
87
multiply and one add). Computation for the bias is ignored.
88
Flops for a transposed convolution are calculated as
89
flops = (x_shape[2:] * prod(w_shape) * batch_size).
91
x_shape (list(int)): The input shape before convolution.
92
w_shape (list(int)): The filter shape.
93
out_shape (list(int)): The output shape after convolution.
94
transposed (bool): is the convolution transposed
96
int: the number of flops
99
batch_size = x_shape[0]
100
conv_shape = (x_shape if transposed else out_shape)[2:]
101
c_out, c_in, *filter_size = w_shape
104
General idea here is that for a regular conv, for each point in the output
105
spatial dimension we convolve the filter with something (hence
106
`prod(conv_shape) * prod(filter_size)` ops). Then, this gets multiplied by
107
1. batch_size, 2. the cross product of input and weight channels.
109
For the transpose, it's not each point in the *output* spatial dimension but
110
each point in the *input* spatial dimension.
112
# NB(chilli): I don't think this properly accounts for padding :think:
113
# NB(chilli): Should be 2 * c_in - 1 technically for FLOPs.
114
flop = prod(conv_shape) * prod(filter_size) * batch_size * c_out * c_in * 2
117
@register_flop_formula([aten.convolution, aten._convolution])
def conv_flop(
    x_shape,
    w_shape,
    _bias,
    _stride,
    _padding,
    _dilation,
    transposed,
    *args,
    out_shape=None,
    **kwargs,
) -> int:
    """Return the FLOP count for a (possibly transposed) convolution.

    Only the input shape, weight shape, output shape and the ``transposed``
    flag are forwarded to ``conv_flop_count``; bias, stride, padding,
    dilation and any further arguments are accepted but ignored.
    """
    return conv_flop_count(
        x_shape,
        w_shape,
        out_shape,
        transposed=transposed,
    )
123
@register_flop_formula(aten.convolution_backward)
124
def conv_backward_flop(
139
return [shape[1], shape[0]] + list(shape[2:])
143
Let's say we have a regular 1D conv
147
{Ai + Bj, Bi + Cj} [out]
149
And as a reminder, the transposed conv of the above is
150
=> {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
152
For the backwards of conv, we now have
157
# grad_inp as conv_transpose(grad_out, weight)
158
Let's first compute grad_inp. To do so, we can simply look at all the
159
multiplications that each element of inp is involved in. For example, A is
160
only involved in the first element of the output (and thus only depends upon
161
D in grad_out), and C is only involved in the last element of the output
162
(and thus only depends upon E in grad_out)
164
{Di, Dj + Ei, Ej} [grad_inp]
166
Note that this corresponds to the below conv_transpose. This gives us the
167
output_mask[0] branch, which is grad_inp.
169
{D, E} [inp (grad_out)]
172
{Di, Dj + Ei, Ej} [out (grad_inp)]
174
I leave the fact that grad_inp for a transposed conv is just conv(grad_out,
175
weight) as an exercise for the reader.
177
# grad_weight as conv(inp, grad_out)
178
To compute grad_weight, we again look at the terms in the output, which as
180
=> {Ai + Bj, Bi + Cj} [out]
182
If we manually compute the gradient for the weights, we see it's
183
{AD + BE, BD + CE} [grad_weight]
185
This corresponds to the below conv
187
{D, E} [weight (grad_out)]
189
{AD + BE, BD + CE} [out (grad_weight)]
191
# grad_weight of transposed conv as conv(grad_out, inp)
192
As a reminder, the terms of the output of a transposed conv are:
193
=> {Ai, Aj + Bi, Bj + Ci, Cj} [transposed conv out]
194
=> {D, E, F, G} [grad_out]
196
Manually computing the gradient for the weights, we see it's
197
{AD + BE + CF, AE + BF + CG} [grad_weight]
199
This corresponds to the below conv
200
{D, E, F, G} [inp (grad_out)]
201
{A, B, C} [weight (inp)]
203
{AD + BE + CF, AE + BF + CG} [out (grad_weight)]
205
For the full backwards formula, there are also some details involving
206
transpose of the batch/channel dimensions and groups, but I skip those for
207
the sake of brevity (and they're pretty similar to matmul backwards)
209
Check [conv backwards decomposition as conv forwards]
211
# grad_inp as conv_transpose(grad_out, weight)
213
grad_input_shape = get_shape(out_shape[0])
214
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not transposed)
217
grad_weight_shape = get_shape(out_shape[1])
219
# grad_weight of transposed conv as conv(grad_out, inp)
220
flop_count += conv_flop_count(t(grad_out_shape), t(x_shape), t(grad_weight_shape), transposed=False)
222
# grad_weight as conv(inp, grad_out)
223
flop_count += conv_flop_count(t(x_shape), t(grad_out_shape), t(grad_weight_shape), transposed=False)
227
def sdpa_flop_count(query_shape, key_shape, value_shape):
229
Count flops for self-attention.
231
NB: We can assume that value_shape == key_shape
233
b, h, s_q, d_q = query_shape
234
_b2, _h2, s_k, _d2 = key_shape
235
_b3, _h3, _s3, d_v = value_shape
236
assert b == _b2 == _b3 and h == _h2 == _h3 and d_q == _d2 and s_k == _s3 and d_q == _d2
238
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
239
total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
240
# scores: [b, h, s_q, s_k] @ v: [b, h, s_k, d_v] -> out: [b, h, s_q, d_v]
241
total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_v))
245
@register_flop_formula([
    aten._scaled_dot_product_efficient_attention,
    aten._scaled_dot_product_flash_attention,
])
def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Return the FLOP count for a scaled-dot-product attention forward.

    Only the query, key and value shapes are used (via ``sdpa_flop_count``);
    all remaining arguments are ignored.
    """
    # NB: causal attention is not accounted for here.
    return sdpa_flop_count(
        query_shape=query_shape,
        key_shape=key_shape,
        value_shape=value_shape,
    )
252
def sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape):
254
b, h, s_q, d_q = query_shape
255
_b2, _h2, s_k, _d2 = key_shape
256
_b3, _h3, _s3, d_v = value_shape
257
_b4, _h4, _s4, _d4 = grad_out_shape
258
assert b == _b2 == _b3 == _b4 and h == _h2 == _h3 == _h4 and d_q == _d2
259
assert d_v == _d4 and s_k == _s3 and s_q == _s4
261
# Step 1: We recompute the scores matrix.
262
# q: [b, h, s_q, d_q] @ k: [b, h, d_q, s_k] -> scores: [b, h, s_q, s_k]
263
total_flops += bmm_flop((b * h, s_q, d_q), (b * h, d_q, s_k))
265
# Step 2: We propagate the gradients through the score @ v operation.
266
# gradOut: [b, h, s_q, d_v] @ v: [b, h, d_v, s_k] -> gradScores: [b, h, s_q, s_k]
267
total_flops += bmm_flop((b * h, s_q, d_v), (b * h, d_v, s_k))
268
# scores: [b, h, s_k, s_q] @ gradOut: [b, h, s_q, d_v] -> gradV: [b, h, s_k, d_v]
269
total_flops += bmm_flop((b * h, s_k, s_q), (b * h, s_q, d_v))
271
# Step 3: We propagate the gradients through the q @ k operation.
272
# gradScores: [b, h, s_q, s_k] @ k: [b, h, s_k, d_q] -> gradQ: [b, h, s_q, d_q]
273
total_flops += bmm_flop((b * h, s_q, s_k), (b * h, s_k, d_q))
274
# q: [b, h, d_q, s_q] @ gradScores: [b, h, s_q, s_k] -> gradK: [b, h, d_q, s_k]
275
total_flops += bmm_flop((b * h, d_q, s_q), (b * h, s_q, s_k))
279
@register_flop_formula([
    aten._scaled_dot_product_efficient_attention_backward,
    aten._scaled_dot_product_flash_attention_backward,
])
def sdpa_backward_flop(grad_out_shape, query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs) -> int:
    """Return the FLOP count for a scaled-dot-product attention backward.

    Only the output-gradient, query, key and value shapes are used (via
    ``sdpa_backward_flop_count``); all remaining arguments are ignored.
    """
    return sdpa_backward_flop_count(
        grad_out_shape=grad_out_shape,
        query_shape=query_shape,
        key_shape=key_shape,
        value_shape=value_shape,
    )
286
aten.addmm: addmm_flop,
288
aten.baddbmm: baddbmm_flop,
289
aten.convolution: conv_flop,
290
aten._convolution: conv_flop,
291
aten.convolution_backward: conv_backward_flop,
292
aten._scaled_dot_product_efficient_attention: sdpa_flop,
293
aten._scaled_dot_product_flash_attention: sdpa_flop,
294
aten._scaled_dot_product_efficient_attention_backward: sdpa_backward_flop,
295
aten._scaled_dot_product_flash_attention_backward: sdpa_backward_flop,
298
def normalize_tuple(x):
299
if not isinstance(x, tuple):
304
# Display suffixes for successive powers of 1000 (10^0, 10^3, 10^6, 10^9, 10^12).
suffixes = ["", "K", "M", "B", "T"]


def get_suffix_str(number):
    """Pick the display suffix for ``number`` from ``suffixes``.

    The suffix index is derived from the number of characters in
    ``str(number)`` and offset by one, so each suffix tier keeps up to four
    digits before the decimal point before moving to the next suffix. For
    example, 1_001_000_000 (10 digits) is displayed as ``1001M`` rather than
    ``1.001B``. Values with at most four digits get no suffix, and anything
    at or above 10**13 saturates at ``"T"``.

    Args:
        number: The value to be displayed. Presumably a non-negative integer
            FLOP count; note that a leading ``-`` or a decimal point would
            also be counted as characters.

    Returns:
        str: One of the entries of ``suffixes``.
    """
    num_chars = len(str(number))
    # ``- 2`` instead of the usual ``- 1`` gives each tier one extra digit of
    # "overflow"; clamp the result to the valid index range of ``suffixes``.
    index = max(0, min(len(suffixes) - 1, (num_chars - 2) // 3))
    return suffixes[index]


def convert_num_with_suffix(number, suffix):
    """Format ``number`` scaled to ``suffix`` with three decimal places.

    Args:
        number: The value to format.
        suffix: One of the entries of ``suffixes``; ``number`` is divided by
            ``1000 ** suffixes.index(suffix)``.

    Returns:
        str: The scaled value followed by the suffix, e.g.
        ``convert_num_with_suffix(1_500, "K") == "1.500K"``.

    Raises:
        ValueError: If ``suffix`` is not one of ``suffixes``.
    """
    index = suffixes.index(suffix)
    # Divide the number by 1000^index and format it to three decimal places.
    value = f"{number / 1000 ** index:.3f}"
    # Return the value and the suffix as a string.
    return value + suffixes[index]
321
def convert_to_percent_str(num, denom):
324
return f"{num / denom:.2%}"
326
def _pytreeify_preserve_structure(f):
329
flat_args, spec = tree_flatten(args)
331
return tree_unflatten(out, spec)
336
class FlopCounterMode(TorchDispatchMode):
338
``FlopCounterMode`` is a context manager that counts the number of flops within its context.
340
It does this using a ``TorchDispatchMode``.
342
It also supports hierarchical output by passing a module (or list of
343
modules) to FlopCounterMode on construction. If you do not need hierarchical
344
output, you do not need to use it with a module.
348
.. code-block:: python
351
flop_counter = FlopCounterMode(mod)
359
mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None,
361
display: bool = True,
362
custom_mapping: Optional[Dict[Any, Any]] = None):
363
self.flop_counts: Dict[str, Dict[Any, int]] = defaultdict(lambda: defaultdict(int))
365
self.parents = ["Global"]
366
self.in_backward = False
367
self.display = display
368
if custom_mapping is None:
370
if isinstance(mods, torch.nn.Module):
373
# Keys will include the modules in `mods` and their submodules
374
self._module_to_forward_hook_handles: Dict[nn.Module, _ForwardHookHandles] = {}
375
self.flop_registry = {
377
**{k: v if getattr(v, "_get_raw", False) else shape_wrapper(v) for k, v in custom_mapping.items()}
380
def _register_forward_hooks(self):
381
if self.mods is None:
383
for mod in self.mods:
384
prefix = type(mod).__name__
385
for name, module in dict(mod.named_modules()).items():
389
name = ".".join([prefix, name])
391
forward_pre_hook_handle = module.register_forward_pre_hook(self._enter_module(name))
392
forward_hook_handle = module.register_forward_hook(self._exit_module(name))
393
self._module_to_forward_hook_handles[module] = _ForwardHookHandles(
394
forward_pre_hook_handle, forward_hook_handle
397
def _deregister_forward_hooks(self):
398
for forward_hook_handles in self._module_to_forward_hook_handles.values():
399
forward_hook_handles[0].remove()
400
forward_hook_handles[1].remove()
401
self._module_to_forward_hook_handles.clear()
403
def _enter_module(self, name):
404
def f(module, inputs):
405
out = _pytreeify_preserve_structure(self._create_pre_module(name))(inputs)
410
def _exit_module(self, name):
411
def f(module, inputs, outputs):
412
outputs = _pytreeify_preserve_structure(self._create_post_module(name))(outputs)
416
def _create_post_module(self, name):
417
class PushState(torch.autograd.Function):
419
def forward(ctx, *args):
420
assert self.parents[-1] == name, f"{self.parents[-1]} is not {name}"
422
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
426
def backward(ctx, *grad_outs):
427
self.in_backward = True
428
self.parents.append(name)
431
return PushState.apply
433
def _create_pre_module(self, name):
434
class PopState(torch.autograd.Function):
436
def forward(ctx, *args):
438
self.parents = ["Global"]
439
self.in_backward = True
440
self.parents.append(name)
441
args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
445
def backward(ctx, *grad_outs):
446
assert self.parents[-1] == name
450
return PopState.apply
452
def get_total_flops(self) -> int:
    """Return the sum of all FLOP counts recorded under the ``'Global'`` key."""
    global_counts = self.flop_counts['Global']
    return sum(global_counts.values())
455
def get_flop_counts(self) -> Dict[str, Dict[Any, int]]:
456
"""Return the flop counts as a dictionary of dictionaries.
459
dictionary is keyed by module name, and the inner dictionary is keyed by
463
Dict[str, Dict[Any, int]]: The flop counts as a dictionary.
465
return {k: dict(v) for k, v in self.flop_counts.items()}
467
def get_table(self, depth=None):
474
tabulate.PRESERVE_WHITESPACE = True
475
header = ["Module", "FLOP", "% Total"]
477
global_flops = self.get_total_flops()
478
global_suffix = get_suffix_str(global_flops)
479
is_global_subsumed = False
481
def process_mod(mod_name, depth):
482
nonlocal is_global_subsumed
484
total_flops = sum(self.flop_counts[mod_name].values())
486
is_global_subsumed |= total_flops >= global_flops
488
padding = " " * depth
492
convert_num_with_suffix(total_flops, global_suffix),
493
convert_to_percent_str(total_flops, global_flops)
495
for k, v in self.flop_counts[mod_name].items():
497
padding + " - " + str(k),
498
convert_num_with_suffix(v, global_suffix),
499
convert_to_percent_str(v, global_flops)
503
for mod in self.flop_counts.keys():
506
mod_depth = mod.count(".") + 1
507
if mod_depth > depth:
510
cur_values = process_mod(mod, mod_depth - 1)
511
values.extend(cur_values)
513
# We do a bit of messing around here to only output the "Global" value
514
# if there are any FLOPs in there that aren't already fully contained by
516
if 'Global' in self.flop_counts and not is_global_subsumed:
517
for idx, value in enumerate(values):
518
values[idx][0] = " " + values[idx][0]
520
values = process_mod('Global', 0) + values
523
values = [["Global", "0", "0%"]]
525
return tabulate.tabulate(values, headers=header, colalign=("left", "right", "right"))
528
self.flop_counts.clear()
529
self._register_forward_hooks()
533
def __exit__(self, *args):
535
print(self.get_table(self.depth))
536
self._deregister_forward_hooks()
537
super().__exit__(*args)
539
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
540
kwargs = kwargs if kwargs else {}
541
out = func(*args, **kwargs)
542
func_packet = func._overloadpacket
543
if func_packet in self.flop_registry:
544
flop_count_func = self.flop_registry[func_packet]
545
flop_count = flop_count_func(*args, **kwargs, out=out) # type: ignore[operator]
546
if len(set(self.parents)) != len(self.parents):
548
"The module hierarchy tracking seems to be messed up."
549
"Please file a bug or just run the flop counter without"
550
"tracking the module hierarchy (i.e. `with FlopCounterMode():`)"
552
for par in set(self.parents):
553
self.flop_counts[par][func_packet] += flop_count
557
class _ForwardHookHandles(NamedTuple):
558
forward_pre_hook_handle: RemovableHandle
559
forward_hook_handle: RemovableHandle