5
from typing import Dict, List, Optional, Set
8
import torch.utils.benchmark as benchmark
9
from torch._C._profiler import (
16
from torch.profiler import profile
17
from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs
22
Base class for all patterns, subclass this class and implement match()
23
to define custom patterns.
25
In subclass, define description and skip property.
28
def __init__(self, prof: profile, should_benchmark: bool = False):
30
self.should_benchmark = should_benchmark
31
self.name = "Please specify a name for pattern"
32
self.description = "Please specify a description for pattern"
34
assert prof.profiler is not None and prof.profiler.kineto_results is not None
35
self.event_tree = prof.profiler.kineto_results.experimental_event_tree()
36
self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
37
for event in self.event_tree:
38
self.tid_root.setdefault(event.start_tid, []).append(event)
44
def report(self, event: _ProfilerEvent):
46
f"{self.description}\n[Source Code Location] {source_code_location(event)}"
50
def eventTreeTraversal(self):
    """
    Yield every event of the profiled event tree.

    The events are produced by ``traverse_dfs`` over ``self.event_tree``.
    Subclasses may override this method to customize the traversal.
    """
    dfs_events = traverse_dfs(self.event_tree)
    yield from dfs_events
57
def summary(self, events: List[_ProfilerEvent]):
58
default_summary = f"{self.name}: {len(events)} events matched."
59
if self.should_benchmark:
62
self.benchmark_summary(events)
63
if hasattr(self, "benchmark")
66
return default_summary
68
def benchmark_summary(self, events: List[_ProfilerEvent]):
69
def format_time(time_ns: int):
70
unit_lst = ["ns", "us", "ms"]
73
return f"{time_ns:.2f} {unit}"
75
return f"{time_ns:.2f} s"
77
assert hasattr(self, "benchmark"), "Please implement benchmark()"
78
shapes_factor_map = self.benchmark(events)
79
original_time = sum(event.duration_time_ns for event in events)
81
shapes_factor_map[input_shapes(event)] * event.duration_time_ns
85
f"{self.name}: {len(events)} events matched. "
86
f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)"
89
def match(self, event: _ProfilerEvent):
    """
    Decide whether ``event`` matches this pattern.

    Subclasses should override this method and return True for matching
    events. This base implementation always raises.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
96
def matched_events(self):
100
for event in self.eventTreeTraversal():
101
if self.match(event):
102
matched_events.append(event)
103
return matched_events
105
def root_of(self, event: _ProfilerEvent):
110
def siblings_of(self, event: _ProfilerEvent):
112
children = event.parent.children
114
children = self.tid_root[event.start_tid]
115
index = children.index(event)
116
return children[:index], children[index + 1 :]
118
def next_of(self, event: _ProfilerEvent):
    """
    Return the sibling that directly follows ``event``.

    The siblings come from ``self.siblings_of(event)``; None is returned when
    no sibling follows ``event``.
    """
    _earlier, later = self.siblings_of(event)
    if not later:
        return None
    return later[0]
122
def prev_of(self, event: _ProfilerEvent):
    """
    Return the sibling that directly precedes ``event``.

    The siblings come from ``self.siblings_of(event)``; None is returned when
    no sibling precedes ``event``.
    """
    earlier, _later = self.siblings_of(event)
    if not earlier:
        return None
    return earlier[-1]
126
def go_up_until(self, event: _ProfilerEvent, predicate):
129
while event.parent and not predicate(event):
137
class NamePattern(Pattern):
138
def __init__(self, prof: profile, name: str, should_benchmark: bool = False):
139
super().__init__(prof, should_benchmark)
140
self.description = f"Matched Name Event: {name}"
143
def match(self, event: _ProfilerEvent):
144
return re.search(self.name, event.name) is not None
147
class ExtraCUDACopyPattern(Pattern):
149
This pattern identifies if we create a constant tensor on CPU and immediately move it to GPU.
150
example: torch.zeros((100, 100)).to("cuda")
153
build-in method |build-in method
155
aten::fill_/aten::zero_ | aten::_to_copy
158
We start at node aten::to, go parent events' previous events,
159
and check if we have an aten::fill_/aten::zero_ as we keep going down the tree.
160
We always select the last child in the children list when we go down the tree.
161
If at any step we failed, it is not a match.
164
def __init__(self, prof: profile, should_benchmark: bool = False):
165
super().__init__(prof, should_benchmark)
166
self.name = "Extra CUDA Copy Pattern"
167
self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
168
self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
178
return not self.prof.with_stack or not self.prof.record_shapes
180
def match(self, event):
182
if event.name != "aten::to":
185
if not event.children:
187
event = event.children[-1]
188
if event.name != "aten::_to_copy":
190
if not event.children:
192
event = event.children[-1]
193
if event.name != "aten::copy_":
196
dtypes = input_dtypes(event)
199
if dtypes[0] is None or dtypes[0] != dtypes[1]:
207
event = self.prev_of(event)
210
while event.children:
211
event = event.children[-1]
213
if event.name in self.init_ops:
215
return event.name in self.init_ops
218
def benchmark(self, events: List[_ProfilerEvent]):
219
shapes_factor_map = {input_shapes(event): 0.0 for event in events}
220
for shape in shapes_factor_map:
222
to_timer = benchmark.Timer(
223
stmt='torch.ones(size).to("cuda")', globals={"size": size}
225
de_timer = benchmark.Timer(
226
stmt='torch.ones(size, device="cuda")', globals={"size": size}
228
to_time = to_timer.timeit(10).mean
229
de_time = de_timer.timeit(10).mean
230
shapes_factor_map[shape] = de_time / to_time
231
return shapes_factor_map
234
class ForLoopIndexingPattern(Pattern):
236
This pattern identifies if we use a for loop to index a tensor that
239
tensor = torch.empty((100, 100))
244
aten::select | ... | aten::select | ... (Repeat)
247
We start at node aten::select, and we check if we can find this alternating pattern.
248
We also keep a dictionary to avoid duplicate match in the for loop.
251
def __init__(self, prof: profile, should_benchmark: bool = False):
252
super().__init__(prof, should_benchmark)
253
self.name = "For Loop Indexing Pattern"
254
self.description = "For loop indexing detected. Vectorization recommended."
255
self.visited: Set[int] = set()
257
def eventTreeTraversal(self):
    """
    Yield the events of ``self.event_tree`` via ``traverse_bfs``.

    BFS traversal order is needed here to avoid duplicate matches.
    """
    bfs_events = traverse_bfs(self.event_tree)
    yield from bfs_events
263
def match(self, event: _ProfilerEvent):
264
if event.name != "aten::select":
266
if event.id in self.visited:
269
_, next = self.siblings_of(event)
274
def same_ops(list1, list2):
275
if len(list1) != len(list2):
277
for op1, op2 in zip(list1, list2):
278
if op1.name != op2.name:
283
next_select_idx = index_of_first_match(next, lambda e: e.name == "aten::select")
284
if next_select_idx is None:
286
indexing_ops = [event] + next[:next_select_idx]
287
next = next[len(indexing_ops) - 1 :]
288
for i in range(0, len(next), len(indexing_ops)):
289
if same_ops(indexing_ops, next[i : i + len(indexing_ops)]):
291
self.visited.add(next[i].id)
294
return repeat_count >= 10
297
class FP32MatMulPattern(Pattern):
298
def __init__(self, prof: profile, should_benchmark: bool = False):
299
super().__init__(prof, should_benchmark)
300
self.name = "FP32 MatMul Pattern"
302
"You are currently using GPU that supports TF32. "
303
"Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
305
self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
309
if torch.version.hip is not None:
313
has_tf32 = all(int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
314
return has_tf32 is False or super().skip or not self.prof.record_shapes
316
def match(self, event: _ProfilerEvent):
318
if event.tag != _EventType.TorchOp:
320
assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
321
if event.name == "aten::mm":
322
if event.extra_fields.allow_tf32_cublas is False:
326
def report(self, event: _ProfilerEvent):
    """
    Return the report message for a matched event.

    The message is this pattern's description only; ``event`` is not used.
    """
    message = self.description
    return message
329
def benchmark(self, events: List[_ProfilerEvent]):
    """
    Estimate the TF32 speedup for each distinct input shape among ``events``.

    For every unique ``input_shapes(event)`` key, two float32 CUDA tensors are
    created from the first two shapes and ``torch.mm`` is timed twice: once
    with ``torch.backends.cuda.matmul.allow_tf32`` set to False, and once with
    a timer whose setup sets it to True.

    The TF32 switch is process-wide, so the value it had on entry is restored
    before returning (also when a timing raises); benchmarking must not
    silently change the caller's numerics.

    Args:
        events: matched events whose input shapes select what to benchmark.

    Returns:
        Dict mapping each input-shape tuple to ``tf32_time / fp32_time``.
    """
    shapes_factor_map = {input_shapes(event): 0.0 for event in events}
    # Remember the caller's setting; the timings below overwrite it.
    previous_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
    try:
        for shape in shapes_factor_map:
            matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
            matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32)
            fp32_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            tf32_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                setup="torch.backends.cuda.matmul.allow_tf32 = True",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            # Baseline runs with TF32 off; the TF32 timer's setup turns it on.
            torch.backends.cuda.matmul.allow_tf32 = False
            fp32_time = fp32_timer.timeit(10).mean
            tf32_time = tf32_timer.timeit(10).mean
            shapes_factor_map[shape] = tf32_time / fp32_time
    finally:
        torch.backends.cuda.matmul.allow_tf32 = previous_allow_tf32
    return shapes_factor_map
350
class OptimizerSingleTensorPattern(Pattern):
352
This pattern identifies if we are using the single-tensor version of an optimizer.
354
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
355
By adding foreach=True to enable multi-tensor optimizer, we can gain speedup when
356
the kernels are relatively small.
359
XXXXX: _single_tensor_<OPTIMIZER_NAME>
365
def __init__(self, prof: profile, should_benchmark: bool = False):
366
super().__init__(prof, should_benchmark)
367
self.name = "Optimizer Single Tensor Pattern"
368
self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
370
"Deteced optimizer running with single tensor implementation. "
371
"Please enable multi tensor implementation by passing 'foreach=True' into optimizer."
375
def match(self, event: _ProfilerEvent):
376
for optimizer in self.optimizers_with_foreach:
377
if event.name.endswith(f"_single_tensor_{optimizer}"):
382
class SynchronizedDataLoaderPattern(Pattern):
384
This pattern identifies if we are using num_workers=0 in DataLoader.
386
torch.utils.data.DataLoader(dataset, batch_size=batch_size)
387
Add num_workers=N to the arguments. N depends on system configuration.
390
dataloader.py(...): __iter__
391
dataloader.py(...): _get_iterator
392
NOT dataloader.py(...): check_worker_number_rationality
395
If we don't see check_worker_number_rationality call in the dataloader __iter__,
396
It is not an asynchronous dataloader.
400
def __init__(self, prof: profile, should_benchmark: bool = False):
401
super().__init__(prof, should_benchmark)
402
self.name = "Synchronized DataLoader Pattern"
404
"Detected DataLoader running with synchronized implementation. "
405
"Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
408
"https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
409
"#enable-async-data-loading-and-augmentation"
412
def match(self, event: _ProfilerEvent):
413
def is_dataloader_function(name: str, function_name: str):
414
return name.startswith(
415
os.path.join("torch", "utils", "data", "dataloader.py")
416
) and name.endswith(function_name)
423
except UnicodeDecodeError:
426
if not is_dataloader_function(event.name, "__iter__"):
428
if not event.children:
430
event = event.children[0]
431
if not is_dataloader_function(event.name, "_get_iterator"):
433
if not event.children:
435
event = event.children[0]
436
return not is_dataloader_function(event.name, "check_worker_number_rationality")
440
class GradNotSetToNonePattern(Pattern):
442
This pattern identifies if we are not setting grad to None in zero_grad.
444
optimizer.zero_grad()
445
By setting set_to_none=True, we can gain speedup
452
aten::zero_ is called on each parameter in the model.
453
We also want to make sure it is not called by aten::zeros.
459
def __init__(self, prof: profile, should_benchmark: bool = False):
460
super().__init__(prof, should_benchmark)
461
self.name = "Gradient Set To Zero Instead of None Pattern"
463
"Detected gradient set to zero instead of None. "
464
"Please add 'set_to_none=True' when calling zero_grad()."
467
"https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
468
"#disable-gradient-calculation-for-validation-or-inference"
471
def match(self, event: _ProfilerEvent):
472
if not event.name.endswith(": zero_grad"):
474
if not event.children:
477
for sub_event in traverse_dfs(event.children):
479
sub_event.name == "aten::zero_"
480
and sub_event.parent.name != "aten::zeros"
487
class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
489
This pattern identifies if we are enabling bias in Conv2d which is followed by BatchNorm2d.
490
Bias doesn't do anything when followed by batchnorm.
492
nn.Module: Conv2d | nn.Module: BatchNorm2d
494
aten::conv2d AND dtype of third argument is not null
495
The third argument is the bias
500
def __init__(self, prof: profile, should_benchmark: bool = False):
501
super().__init__(prof, should_benchmark)
502
self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
503
self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
505
"https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
506
"#disable-bias-for-convolutions-directly-followed-by-a-batch-norm"
511
return self.prof.record_shapes is False or super().skip
513
def match(self, event: _ProfilerEvent):
514
if event.name != "aten::conv2d":
516
if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] is None:
519
event = self.go_up_until(
520
event, lambda e: e.name.startswith("nn.Module: Conv2d")
524
event = self.next_of(event)
527
return event.name.startswith("nn.Module: BatchNorm2d")
530
class MatMulDimInFP16Pattern(Pattern):
531
def __init__(self, prof: profile, should_benchmark: bool = False):
532
super().__init__(prof, should_benchmark)
533
self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
534
self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension."
535
self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"
539
return not self.prof.with_stack or not self.prof.record_shapes
541
def match(self, event: _ProfilerEvent):
542
def mutiple_of(shapes, multiple):
543
return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:])
545
if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"):
547
if not input_dtypes(event):
549
arg_dtype = input_dtypes(event)[0]
550
if arg_dtype in (torch.bfloat16, torch.half) and not mutiple_of(
551
input_shapes(event), 8
556
def benchmark(self, events: List[_ProfilerEvent]):
    """
    Estimate the speedup from padding matmul dimensions to multiples of 8.

    For every unique ``input_shapes(event)`` key, ``torch.mm`` is timed on
    float16 CUDA tensors built from the first two shapes, and again on
    tensors whose every dimension is rounded up to the next multiple of 8.

    NOTE(review): the timed statement is always ``torch.mm`` on the first two
    input shapes -- confirm that is intended for events other than aten::mm.

    Args:
        events: matched events whose input shapes select what to benchmark.

    Returns:
        Dict mapping each input-shape tuple to
        ``aligned_time / not_aligned_time``.
    """
    mm_stmt = "torch.mm(matrixA, matrixB)"

    def round_up(dims, multiple):
        # Smallest multiple of `multiple` that is >= each dimension.
        rounded = []
        for dim in dims:
            rounded.append(multiple * math.ceil(dim / multiple))
        return rounded

    def cuda_half(dims):
        return torch.randn(dims, device="cuda", dtype=torch.float16)

    shapes_factor_map = dict.fromkeys((input_shapes(event) for event in events), 0.0)
    for shape in shapes_factor_map:
        # Tensors are created in the same order as before: A, B, padded A, padded B.
        unaligned_timer = benchmark.Timer(
            stmt=mm_stmt,
            globals={"matrixA": cuda_half(shape[0]), "matrixB": cuda_half(shape[1])},
        )
        aligned_timer = benchmark.Timer(
            stmt=mm_stmt,
            globals={
                "matrixA": cuda_half(round_up(shape[0], 8)),
                "matrixB": cuda_half(round_up(shape[1], 8)),
            },
        )
        unaligned_seconds = unaligned_timer.timeit(10).mean
        aligned_seconds = aligned_timer.timeit(10).mean
        shapes_factor_map[shape] = aligned_seconds / unaligned_seconds
    return shapes_factor_map
584
def source_code_location(event: Optional[_ProfilerEvent]):
586
if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
588
event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall)
590
if not event.extra_fields.caller.file_name.startswith("torch" + os.sep):
591
return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
593
return "No source code location found"
596
def input_shapes(event: _ProfilerEvent):
    """
    Return the sizes of each input of a TorchOp event as a tuple of tuples.

    An input without a ``sizes`` attribute contributes an empty tuple.
    ``event.extra_fields`` must be an ``_ExtraFields_TorchOp`` (asserted).
    """
    fields = event.extra_fields
    assert isinstance(fields, _ExtraFields_TorchOp)
    shapes = []
    for op_input in fields.inputs:
        shapes.append(tuple(getattr(op_input, "sizes", ())))
    return tuple(shapes)
601
def input_dtypes(event: _ProfilerEvent):
    """
    Return the dtype of each input of a TorchOp event as a tuple.

    An input without a ``dtype`` attribute contributes None.
    ``event.extra_fields`` must be an ``_ExtraFields_TorchOp`` (asserted).
    """
    fields = event.extra_fields
    assert isinstance(fields, _ExtraFields_TorchOp)
    dtypes = []
    for op_input in fields.inputs:
        dtypes.append(getattr(op_input, "dtype", None))
    return tuple(dtypes)
606
def report_all_anti_patterns(
608
should_benchmark: bool = False,
609
print_enable: bool = True,
610
json_report_dir: Optional[str] = None,
612
report_dict: Dict = {}
614
ExtraCUDACopyPattern(prof, should_benchmark),
616
FP32MatMulPattern(prof, should_benchmark),
617
OptimizerSingleTensorPattern(prof, should_benchmark),
618
SynchronizedDataLoaderPattern(prof, should_benchmark),
619
GradNotSetToNonePattern(prof, should_benchmark),
620
Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark),
621
MatMulDimInFP16Pattern(prof, should_benchmark),
625
message_list = [f"{'-'*40}TorchTidy Report{'-'*40}"]
626
message_list.append("Matched Events:")
628
for anti_pattern in anti_patterns:
629
matched_events = anti_pattern.matched_events()
630
if not matched_events:
632
summaries.append(anti_pattern.summary(matched_events))
633
for event in matched_events:
634
report_msg = anti_pattern.report(event)
635
if report_msg not in reported:
636
message_list.append(report_msg)
637
reported.add(report_msg)
638
src_location, line_no = source_code_location(event).split(":")
639
report_dict.setdefault(src_location, []).append(
641
"line_number": int(line_no),
642
"name": anti_pattern.name,
643
"url": anti_pattern.url,
644
"message": anti_pattern.description,
648
if json_report_dir is not None:
649
json_report_path = os.path.join(json_report_dir, "torchtidy_report.json")
650
if os.path.exists(json_report_path):
651
with open(json_report_path) as f:
652
exisiting_report = json.load(f)
653
exisiting_report.update(report_dict)
654
report_dict = exisiting_report
655
with open(json_report_path, "w") as f:
656
json.dump(report_dict, f, indent=4)
658
message_list.append("Summary:")
659
message_list += summaries
660
message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}")
662
print("\n".join(message_list))