# _pattern_matcher.py
import json
import math
import os
import re
from typing import Dict, List, Optional, Set

import torch
import torch.utils.benchmark as benchmark
from torch._C._profiler import (
    _EventType,
    _ExtraFields_PyCall,
    _ExtraFields_PyCCall,
    _ExtraFields_TorchOp,
    _ProfilerEvent,
)
from torch.profiler import profile
from torch.profiler._utils import index_of_first_match, traverse_bfs, traverse_dfs


class Pattern:
    """
    Base class for all patterns. Subclass this class and implement match()
    to define custom patterns.

    In a subclass, define the description and skip properties.
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        self.prof = prof
        self.should_benchmark = should_benchmark
        self.name = "Please specify a name for pattern"
        self.description = "Please specify a description for pattern"
        self.url = ""
        assert prof.profiler is not None and prof.profiler.kineto_results is not None
        self.event_tree = prof.profiler.kineto_results.experimental_event_tree()
        self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
        for event in self.event_tree:
            self.tid_root.setdefault(event.start_tid, []).append(event)

    @property
    def skip(self):
        return False

    def report(self, event: _ProfilerEvent):
        msg = (
            f"{self.description}\n[Source Code Location] {source_code_location(event)}"
        )
        return msg

    def eventTreeTraversal(self):
        """
        Traverse the event tree and yield all events.
        Override this method in a subclass to customize the traversal.
        """
        yield from traverse_dfs(self.event_tree)

    def summary(self, events: List[_ProfilerEvent]):
        default_summary = f"{self.name}: {len(events)} events matched."
        if self.should_benchmark:
            # If the pattern implements benchmark(), use the benchmark summary.
            return (
                self.benchmark_summary(events)
                if hasattr(self, "benchmark")  # type: ignore[attr-defined]
                else default_summary
            )
        return default_summary

    def benchmark_summary(self, events: List[_ProfilerEvent]):
        def format_time(time_ns: int):
            unit_lst = ["ns", "us", "ms"]
            for unit in unit_lst:
                if time_ns < 1000:
                    return f"{time_ns:.2f} {unit}"
                time_ns /= 1000  # true division, so sub-unit precision is kept
            return f"{time_ns:.2f} s"

        assert hasattr(self, "benchmark"), "Please implement benchmark()"
        shapes_factor_map = self.benchmark(events)  # type: ignore[attr-defined]
        original_time = sum(event.duration_time_ns for event in events)
        new_time = sum(
            shapes_factor_map[input_shapes(event)] * event.duration_time_ns
            for event in events
        )
        return (
            f"{self.name}: {len(events)} events matched. "
            f"Total Estimated Speedup: {format_time(original_time - new_time)} ({round(original_time/new_time, 2)}X)"
        )

    def match(self, event: _ProfilerEvent):
        """
        Return True if the event matches the pattern.
        This method should be overridden in a subclass.
        """
        raise NotImplementedError

    def matched_events(self):
        if self.skip:
            return []
        matched_events = []
        for event in self.eventTreeTraversal():
            if self.match(event):
                matched_events.append(event)
        return matched_events

    def root_of(self, event: _ProfilerEvent):
        while event.parent:
            event = event.parent
        return event

    def siblings_of(self, event: _ProfilerEvent):
        if event.parent:
            children = event.parent.children
        else:
            children = self.tid_root[event.start_tid]
        index = children.index(event)
        return children[:index], children[index + 1 :]

    def next_of(self, event: _ProfilerEvent):
        _, next_events = self.siblings_of(event)
        return next_events[0] if next_events else None

    def prev_of(self, event: _ProfilerEvent):
        prev_events, _ = self.siblings_of(event)
        return prev_events[-1] if prev_events else None

    def go_up_until(self, event: _ProfilerEvent, predicate):
        if not event:
            return None
        while event.parent and not predicate(event):
            event = event.parent
        return event


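# Illustrative sketch (editorial addition, not part of the original module):
# a minimal custom pattern. Per the base-class docstring, a subclass only
# needs a name, a description, and a match() predicate over profiler events.
# The class name and the matched op are hypothetical.
class _ExampleReluPattern(Pattern):
    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Example ReLU Pattern"
        self.description = "Found an aten::relu call (illustration only)."

    def match(self, event: _ProfilerEvent):
        # matched_events() calls this once per event in the traversal order.
        return event.name == "aten::relu"
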

# Patterns


class NamePattern(Pattern):
    def __init__(self, prof: profile, name: str, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.description = f"Matched Name Event: {name}"
        self.name = name

    def match(self, event: _ProfilerEvent):
        return re.search(self.name, event.name) is not None


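# Usage sketch (editorial addition): NamePattern runs re.search against every
# event name, so any regular expression works. "prof" is a completed profile
# collected with the profiler; the regex here is hypothetical.
def _example_name_pattern(prof: profile):
    return NamePattern(prof, r"aten::conv2d").matched_events()
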

class ExtraCUDACopyPattern(Pattern):
    """
    This pattern identifies if we create a constant tensor on CPU and immediately move it to GPU.
    example: torch.zeros((100, 100)).to("cuda")

    Pattern:
    built-in method                 |built-in method
        ...                         |    aten::to
            aten::fill_/aten::zero_ |        aten::_to_copy

    Algorithm:
    We start at the aten::to node, step to the parent event's previous sibling,
    and check if we find an aten::fill_/aten::zero_ as we keep going down the tree.
    We always select the last child in the children list when we go down the tree.
    If any step fails, it is not a match.
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Extra CUDA Copy Pattern"
        self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU."
        self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device"
        self.init_ops = {
            "aten::fill_",
            "aten::zero_",
            "aten::normal_",
            "aten::uniform_",
        }

    @property
    def skip(self):
        return not self.prof.with_stack or not self.prof.record_shapes

    def match(self, event):
        # TODO: We should also check tensor identities
        if event.name != "aten::to":
            return False
        to_event = event
        if not event.children:
            return False
        event = event.children[-1]
        if event.name != "aten::_to_copy":
            return False
        if not event.children:
            return False
        event = event.children[-1]
        if event.name != "aten::copy_":
            return False
        # aten::copy_ should have the same dtype for its first two arguments
        dtypes = input_dtypes(event)
        if len(dtypes) < 2:
            return False
        if dtypes[0] is None or dtypes[0] != dtypes[1]:
            return False
        event = to_event
        # Up one level
        event = event.parent
        if event is None:
            return False
        # Check if we have an aten::fill_ in the previous leaf
        event = self.prev_of(event)
        if event is None:
            return False
        while event.children:
            event = event.children[-1]
            # aten::zero_ is a special optimization case where fill_ is not called
            if event.name in self.init_ops:
                return True
        return event.name in self.init_ops
        # TODO: Check if tensor is reused

    def benchmark(self, events: List[_ProfilerEvent]):
        shapes_factor_map = {input_shapes(event): 0.0 for event in events}
        for shape in shapes_factor_map:
            size = shape[0]
            to_timer = benchmark.Timer(
                stmt='torch.ones(size).to("cuda")', globals={"size": size}
            )
            de_timer = benchmark.Timer(
                stmt='torch.ones(size, device="cuda")', globals={"size": size}
            )
            to_time = to_timer.timeit(10).mean
            de_time = de_timer.timeit(10).mean
            shapes_factor_map[shape] = de_time / to_time
        return shapes_factor_map


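# Illustrative sketch (editorial addition): the anti-pattern this class flags,
# next to the fix its description recommends.
def _example_extra_cuda_copy():
    # Anti-pattern: fill a CPU tensor, then pay for an extra host-to-device copy.
    slow = torch.zeros((100, 100)).to("cuda")
    # Fix: allocate and fill directly on the GPU.
    fast = torch.zeros((100, 100), device="cuda")
    return slow, fast
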

class ForLoopIndexingPattern(Pattern):
    """
    This pattern identifies if we use a for loop to index a tensor in a way that
    can be vectorized.
    example:
    tensor = torch.empty((100, 100))
    for i in range(100):
        tensor[i] = i

    Pattern:
    aten::select | ... | aten::select | ... (Repeat)

    Algorithm:
    We start at an aten::select node and check if we can find this alternating pattern.
    We also keep a set of visited ids to avoid duplicate matches within the for loop.
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "For Loop Indexing Pattern"
        self.description = "For loop indexing detected. Vectorization recommended."
        self.visited: Set[int] = set()

    def eventTreeTraversal(self):
        """
        We need to use BFS traversal order to avoid duplicate matches.
        """
        yield from traverse_bfs(self.event_tree)

    def match(self, event: _ProfilerEvent):
        if event.name != "aten::select":
            return False
        if event.id in self.visited:
            return False
        repeat_count = 1
        _, next = self.siblings_of(event)
        if len(next) <= 1:
            return False

        # Custom event list matching
        def same_ops(list1, list2):
            if len(list1) != len(list2):
                return False
            for op1, op2 in zip(list1, list2):
                if op1.name != op2.name:
                    return False
            return True

        # Record the ops between two aten::select calls
        next_select_idx = index_of_first_match(next, lambda e: e.name == "aten::select")
        if next_select_idx is None:
            return False
        indexing_ops = [event] + next[:next_select_idx]
        next = next[len(indexing_ops) - 1 :]
        for i in range(0, len(next), len(indexing_ops)):
            if same_ops(indexing_ops, next[i : i + len(indexing_ops)]):
                repeat_count += 1
                self.visited.add(next[i].id)
            else:
                break
        return repeat_count >= 10


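# Illustrative sketch (editorial addition): the per-row loop this class flags,
# next to a vectorized equivalent that avoids one aten::select per iteration.
def _example_vectorized_indexing():
    # Anti-pattern: 100 small indexing ops, one per loop iteration.
    looped = torch.empty((100, 100))
    for i in range(100):
        looped[i] = i
    # Fix: build the same tensor with a single broadcasted op.
    vectorized = torch.arange(100, dtype=torch.float32).unsqueeze(1).expand(100, 100)
    return looped, vectorized
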

class FP32MatMulPattern(Pattern):
    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "FP32 MatMul Pattern"
        self.description = (
            "You are currently using a GPU that supports TF32. "
            "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
        )
        self.url = "https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"

    @property
    def skip(self):
        if torch.version.hip is not None:
            has_tf32 = False
        else:
            # Anything below sm_80 predates Ampere and does not support TF32
            has_tf32 = all(int(arch[3:]) >= 80 for arch in torch.cuda.get_arch_list())
        return has_tf32 is False or super().skip or not self.prof.record_shapes

    def match(self, event: _ProfilerEvent):
        # If we saw this pattern once, we don't need to match it again
        if event.tag != _EventType.TorchOp:
            return False
        assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
        if event.name == "aten::mm":
            if event.extra_fields.allow_tf32_cublas is False:
                return True
        return False

    def report(self, event: _ProfilerEvent):
        return self.description

    def benchmark(self, events: List[_ProfilerEvent]):
        shapes_factor_map = {input_shapes(event): 0.0 for event in events}
        for shape in shapes_factor_map:
            matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float32)
            matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float32)
            fp32_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            tf32_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                setup="torch.backends.cuda.matmul.allow_tf32 = True",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            torch.backends.cuda.matmul.allow_tf32 = False
            fp32_time = fp32_timer.timeit(10).mean
            tf32_time = tf32_timer.timeit(10).mean
            shapes_factor_map[shape] = tf32_time / fp32_time
        return shapes_factor_map


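# Usage sketch (editorial addition): opting in to TF32, as the description
# above recommends. The cuDNN flag is included for completeness.
def _example_enable_tf32():
    torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for matmuls
    torch.backends.cudnn.allow_tf32 = True  # TF32 for cuDNN convolutions
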

class OptimizerSingleTensorPattern(Pattern):
    """
    This pattern identifies if we are using the single-tensor version of an optimizer.
    example:
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    By adding foreach=True to enable the multi-tensor optimizer, we can gain speedup when
    the kernels are relatively small.

    Pattern:
    XXXXX: _single_tensor_<OPTIMIZER_NAME>

    Algorithm:
    String match
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Optimizer Single Tensor Pattern"
        self.optimizers_with_foreach = ["adam", "sgd", "adamw"]
        self.description = (
            "Detected optimizer running with single tensor implementation. "
            "Please enable multi tensor implementation by passing 'foreach=True' into optimizer."
        )
        self.url = ""

    def match(self, event: _ProfilerEvent):
        for optimizer in self.optimizers_with_foreach:
            if event.name.endswith(f"_single_tensor_{optimizer}"):
                return True
        return False


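# Usage sketch (editorial addition): the foreach=True switch the description
# above recommends. "model" is a hypothetical nn.Module.
def _example_multi_tensor_optimizer(model):
    # Runs the _multi_tensor_sgd implementation instead of _single_tensor_sgd.
    return torch.optim.SGD(model.parameters(), lr=0.1, foreach=True)
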

class SynchronizedDataLoaderPattern(Pattern):
    """
    This pattern identifies if we are using num_workers=0 in DataLoader.
    example:
    torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    Add num_workers=N to the arguments. N depends on system configuration.

    Pattern:
    dataloader.py(...): __iter__
        dataloader.py(...): _get_iterator
            NOT dataloader.py(...): check_worker_number_rationality

    Algorithm:
    If we don't see a check_worker_number_rationality call in the dataloader's __iter__,
    it is not an asynchronous dataloader.

    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Synchronized DataLoader Pattern"
        self.description = (
            "Detected DataLoader running with synchronized implementation. "
            "Please enable asynchronous dataloading by setting num_workers > 0 when initializing DataLoader."
        )
        self.url = (
            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
            "#enable-async-data-loading-and-augmentation"
        )

    def match(self, event: _ProfilerEvent):
        def is_dataloader_function(name: str, function_name: str):
            return name.startswith(
                os.path.join("torch", "utils", "data", "dataloader.py")
            ) and name.endswith(function_name)

        # TODO: fixme! Due to lifetime issues of the function name, this field might
        # actually point to an already freed string when the event is a PyCall.
        # Just silently skip this to unblock testing.
        try:
            event.name
        except UnicodeDecodeError:
            return False

        if not is_dataloader_function(event.name, "__iter__"):
            return False
        if not event.children:
            return False
        event = event.children[0]
        if not is_dataloader_function(event.name, "_get_iterator"):
            return False
        if not event.children:
            return False
        event = event.children[0]
        return not is_dataloader_function(event.name, "check_worker_number_rationality")
        # TODO: We should also check if the loader is the bottleneck.


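# Usage sketch (editorial addition): moving loading off the main process.
# "dataset" is hypothetical, and the right num_workers depends on the system.
def _example_async_dataloader(dataset, batch_size):
    from torch.utils.data import DataLoader

    return DataLoader(dataset, batch_size=batch_size, num_workers=4)
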

class GradNotSetToNonePattern(Pattern):
    """
    This pattern identifies if we are not setting grad to None in zero_grad.
    example:
    optimizer.zero_grad()
    By setting set_to_none=True, we can gain speedup.

    Pattern:
    XXXXX: _zero_grad
        NOT aten::zeros
            aten::zero_

    aten::zero_ is called on each parameter in the model.
    We also want to make sure it is not called by aten::zeros.

    Algorithm:
    String match
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Gradient Set To Zero Instead of None Pattern"
        self.description = (
            "Detected gradient set to zero instead of None. "
            "Please add 'set_to_none=True' when calling zero_grad()."
        )
        self.url = (
            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
            "#disable-gradient-calculation-for-validation-or-inference"
        )

    def match(self, event: _ProfilerEvent):
        if not event.name.endswith(": zero_grad"):
            return False
        if not event.children:
            return False

        for sub_event in traverse_dfs(event.children):
            if (
                sub_event.name == "aten::zero_"
                and sub_event.parent.name != "aten::zeros"
            ):
                return True
        # TODO: We should also check if the optimizer's numerical behavior will change.
        return False


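# Usage sketch (editorial addition): clearing gradients by dropping them
# rather than zero-filling every parameter tensor with aten::zero_.
def _example_zero_grad(optimizer):
    optimizer.zero_grad(set_to_none=True)
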

class Conv2dBiasFollowedByBatchNorm2dPattern(Pattern):
    """
    This pattern identifies if we are enabling bias in a Conv2d that is followed by BatchNorm2d.
    Bias doesn't do anything when followed by batchnorm.
    Pattern:
    nn.Module: Conv2d            | nn.Module: BatchNorm2d
        ...
            aten::conv2d AND dtype of the third argument is not null
    The third argument is the bias
    Algorithm:
    String match
    """

    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Enabling Bias in Conv2d Followed By BatchNorm Pattern"
        self.description = "Detected bias enabled in Conv2d that is followed by BatchNorm2d. Please set 'bias=False' in Conv2d."
        self.url = (
            "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html"
            "#disable-bias-for-convolutions-directly-followed-by-a-batch-norm"
        )

    @property
    def skip(self):
        return self.prof.record_shapes is False or super().skip

    def match(self, event: _ProfilerEvent):
        if event.name != "aten::conv2d":
            return False
        if len(input_dtypes(event)) < 3 or input_dtypes(event)[2] is None:
            return False
        # This means bias=True
        event = self.go_up_until(
            event, lambda e: e.name.startswith("nn.Module: Conv2d")
        )
        if not event:
            return False
        event = self.next_of(event)
        if not event:
            return False
        return event.name.startswith("nn.Module: BatchNorm2d")


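# Usage sketch (editorial addition): dropping the redundant conv bias; the
# following BatchNorm2d has its own learned shift, so the bias adds nothing.
def _example_conv_bn_block():
    return torch.nn.Sequential(
        torch.nn.Conv2d(3, 64, kernel_size=3, bias=False),
        torch.nn.BatchNorm2d(64),
    )
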

class MatMulDimInFP16Pattern(Pattern):
    def __init__(self, prof: profile, should_benchmark: bool = False):
        super().__init__(prof, should_benchmark)
        self.name = "Matrix Multiplication Dimension Not Aligned Pattern"
        self.description = "Detected matmul with dimension not aligned. Please use matmul with aligned dimension."
        self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-mixed-precision-and-amp"

    @property
    def skip(self):
        return not self.prof.with_stack or not self.prof.record_shapes

    def match(self, event: _ProfilerEvent):
        def multiple_of(shapes, multiple):
            return all(dim % multiple == 0 for shape in shapes for dim in shape[-2:])

        if event.name not in ("aten::mm", "aten::bmm", "aten::addmm"):
            return False
        if not input_dtypes(event):
            return False
        arg_dtype = input_dtypes(event)[0]
        if arg_dtype in (torch.bfloat16, torch.half) and not multiple_of(
            input_shapes(event), 8
        ):
            return True
        return False

    def benchmark(self, events: List[_ProfilerEvent]):
        def closest_multiple(shapes, multiple):
            return [multiple * math.ceil(shape / multiple) for shape in shapes]

        shapes_factor_map = {input_shapes(event): 0.0 for event in events}
        for shape in shapes_factor_map:
            matrixA = torch.randn(shape[0], device="cuda", dtype=torch.float16)
            matrixB = torch.randn(shape[1], device="cuda", dtype=torch.float16)
            not_aligned_dim_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            matrixA = torch.randn(
                closest_multiple(shape[0], 8), device="cuda", dtype=torch.float16
            )
            matrixB = torch.randn(
                closest_multiple(shape[1], 8), device="cuda", dtype=torch.float16
            )
            aligned_dim_timer = benchmark.Timer(
                stmt="torch.mm(matrixA, matrixB)",
                globals={"matrixA": matrixA, "matrixB": matrixB},
            )
            not_aligned_dim_time = not_aligned_dim_timer.timeit(10).mean
            aligned_dim_time = aligned_dim_timer.timeit(10).mean
            shapes_factor_map[shape] = aligned_dim_time / not_aligned_dim_time
        return shapes_factor_map


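# Illustrative sketch (editorial addition): padding matmul operands up to a
# multiple of 8, mirroring closest_multiple() in the benchmark above. Only
# the k and n dimensions are padded here for brevity.
def _example_pad_fp16_matmul(a, b, multiple=8):
    # a: (m, k), b: (k, n), both fp16 on GPU.
    m, k = a.shape
    n = b.shape[1]
    pad_k = -k % multiple
    pad_n = -n % multiple
    a = torch.nn.functional.pad(a, (0, pad_k))  # zero-pad k
    b = torch.nn.functional.pad(b, (0, pad_n, 0, pad_k))  # zero-pad n and k
    # Zero padding does not change the original m x n block of the product.
    return torch.mm(a, b)[:m, :n]
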

def source_code_location(event: Optional[_ProfilerEvent]):
    while event:
        if event.tag == _EventType.PyCall or event.tag == _EventType.PyCCall:
            assert isinstance(
                event.extra_fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall)
            )
            if not event.extra_fields.caller.file_name.startswith("torch" + os.sep):
                return f"{event.extra_fields.caller.file_name}:{event.extra_fields.caller.line_number}"
        event = event.parent
    return "No source code location found"


def input_shapes(event: _ProfilerEvent):
    assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
    return tuple(tuple(getattr(i, "sizes", ())) for i in event.extra_fields.inputs)


def input_dtypes(event: _ProfilerEvent):
    assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
    return tuple(getattr(i, "dtype", None) for i in event.extra_fields.inputs)


def report_all_anti_patterns(
    prof,
    should_benchmark: bool = False,
    print_enable: bool = True,
    json_report_dir: Optional[str] = None,
):
    report_dict: Dict = {}
    anti_patterns = [
        ExtraCUDACopyPattern(prof, should_benchmark),
        # ForLoopIndexingPattern(prof, should_benchmark),
        FP32MatMulPattern(prof, should_benchmark),
        OptimizerSingleTensorPattern(prof, should_benchmark),
        SynchronizedDataLoaderPattern(prof, should_benchmark),
        GradNotSetToNonePattern(prof, should_benchmark),
        Conv2dBiasFollowedByBatchNorm2dPattern(prof, should_benchmark),
        MatMulDimInFP16Pattern(prof, should_benchmark),
    ]
    reported = set()
    summaries = []
    message_list = [f"{'-'*40}TorchTidy Report{'-'*40}"]
    message_list.append("Matched Events:")

    for anti_pattern in anti_patterns:
        matched_events = anti_pattern.matched_events()
        if not matched_events:
            continue
        summaries.append(anti_pattern.summary(matched_events))
        for event in matched_events:
            report_msg = anti_pattern.report(event)
            if report_msg not in reported:
                message_list.append(report_msg)
                reported.add(report_msg)
                src_location, line_no = source_code_location(event).split(":")
                report_dict.setdefault(src_location, []).append(
                    {
                        "line_number": int(line_no),
                        "name": anti_pattern.name,
                        "url": anti_pattern.url,
                        "message": anti_pattern.description,
                    }
                )

    if json_report_dir is not None:
        json_report_path = os.path.join(json_report_dir, "torchtidy_report.json")
        if os.path.exists(json_report_path):
            with open(json_report_path) as f:
                existing_report = json.load(f)
                existing_report.update(report_dict)
                report_dict = existing_report
        with open(json_report_path, "w") as f:
            json.dump(report_dict, f, indent=4)

    message_list.append("Summary:")
    message_list += summaries
    message_list.append(f"{'-'*40}TorchTidy Report{'-'*40}")
    if print_enable:
        print("\n".join(message_list))


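# Usage sketch (editorial addition): running the report over a profiled
# region. with_stack and record_shapes are required by several patterns'
# skip properties; "model" and "inputs" are hypothetical.
def _example_torchtidy_run(model, inputs):
    with profile(with_stack=True, record_shapes=True) as prof:
        model(inputs)
    report_all_anti_patterns(prof, print_enable=True)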
