pytorch

Форк
0
4784 строки · 168.3 Кб
1
#!/usr/bin/env python3
2

3
from __future__ import annotations
4

5
import abc
6
import argparse
7
import collections
8
import contextlib
9
import copy
10
import csv
11
import dataclasses
12
import functools
13
import importlib
14
import itertools
15
import logging
16
import os
17
import shutil
18
import signal
19
import subprocess
20
import sys
21
import time
22
import weakref
23
from contextlib import contextmanager
24
from pathlib import Path
25
from typing import (
26
    Any,
27
    Callable,
28
    Generator,
29
    List,
30
    Mapping,
31
    NamedTuple,
32
    Optional,
33
    Sequence,
34
    Tuple,
35
    Type,
36
    TYPE_CHECKING,
37
)
38
from typing_extensions import Self
39
from unittest.mock import MagicMock
40

41
import numpy as np
42
import pandas as pd
43
import psutil
44
import yaml
45
from scipy.stats import gmean, ttest_ind
46
from tqdm.auto import tqdm, trange
47

48
import torch
49
import torch._dynamo
50
import torch._dynamo.utils
51
import torch._export
52
import torch.distributed
53
import torch.multiprocessing as mp
54
from torch._C import _has_cuda as HAS_CUDA, _has_xpu as HAS_XPU
55
from torch._dynamo.profiler import fx_insert_profiling, Profiler
56
from torch._dynamo.testing import (
57
    dummy_fx_compile,
58
    format_speedup,
59
    reset_rng_state,
60
    same,
61
)
62

63

64
try:
65
    from torch._dynamo.utils import (
66
        clone_inputs,
67
        graph_break_reasons,
68
        maybe_enable_compiled_autograd,
69
    )
70
    from torch._inductor.utils import fresh_inductor_cache
71
except ImportError:
72
    from _dynamo.utils import (
73
        clone_inputs,
74
        graph_break_reasons,
75
        maybe_enable_compiled_autograd,
76
    )
77

78
import torch._functorch.config
79
from torch._functorch.aot_autograd import set_model_name
80
from torch._inductor import config as inductor_config, metrics
81
from torch._subclasses.fake_tensor import FakeTensorMode
82
from torch.utils import _pytree as pytree
83
from torch.utils._pytree import tree_map, tree_map_only
84

85

86
try:
87
    import torch_xla
88
    import torch_xla.core.xla_model as xm
89

90
    # This is to woraround the backward issue https://github.com/pytorch/xla/issues/4174
91
    torch_xla._XLAC._init_computation_client()
92
except ImportError:
93
    # ignore the error if torch_xla is not installed
94
    pass
95

96

97
if TYPE_CHECKING:
98
    from torch.onnx._internal.fx import diagnostics
99

100

101
log = logging.getLogger(__name__)

# TF32 matmuls are the precision mode these benchmarks mainly care about.
torch.backends.cuda.matmul.allow_tf32 = True

# Keep torch.profiler (Kineto) from flooding the output.
os.environ["KINETO_LOG_LEVEL"] = "5"

# Module-level state describing the benchmark currently being run; functions
# in this module read these names (e.g. when writing CSV rows).
current_name = current_device = current_onnx_compiler = ""
current_batch_size = output_filename = None
disable_output = False

MAX_DOWNLOAD_ATTEMPTS = 5
117

118

119
class CI(NamedTuple):
    """One CI benchmark configuration (backend, mode, shapes, device)."""

    # Compiler backend, e.g. "aot_eager" or "inductor".
    backend: str
    # Training-mode flag.
    training: bool
    # Dynamic-shapes flag; off unless requested.
    dynamic: bool = False
    # Target device name.
    device: str = "cuda"
124

125

126
CI_SKIP_OPTIMIZER = {
    # TIMM
    "convmixer_768_32",  # accuracy
    "hrnet_w18",  # Stack issue in fx
    "pnasnet5large",  # Stack issue in fx
    # HF
    "MobileBertForMaskedLM",  # Stack issue in fx
    "MobileBertForQuestionAnswering",  # Stack issue in fx
    "PegasusForConditionalGeneration",  # OOM
}

# Optional internal additions; the module is absent in open-source checkouts.
try:
    from .fb.common import INTERNAL_CI_SKIP_DYNAMIC_BATCH_ONLY
except ImportError:
    INTERNAL_CI_SKIP_DYNAMIC_BATCH_ONLY = set()

CI_SKIP_DYNAMIC_BATCH_ONLY = {
    "sam",
    # See https://github.com/mindee/doctr/blob/f2114758d529ed8d3d0030581638f0520b6b98d8/doctr/models/detection/core.py#L89
    # It iterates over the batch, which is dynamic, and dynamo chokes
    # We should be able to graphbreak there.
    "doctr_det_predictor",
    "dlrm",
    "pyhpc_isoneutral_mixing",
    "pyhpc_equation_of_state",
    "pyhpc_turbulent_kinetic_energy",
    "detectron2_fcos_r_50_fpn",
    "detectron2_fasterrcnn_r_101_c4",
    "detectron2_fasterrcnn_r_101_dc5",
    "detectron2_fasterrcnn_r_101_fpn",
    "detectron2_fasterrcnn_r_50_c4",
    "detectron2_fasterrcnn_r_50_dc5",
    "detectron2_fasterrcnn_r_50_fpn",
    "hf_T5_generate",
    "Reformer",
}.union(INTERNAL_CI_SKIP_DYNAMIC_BATCH_ONLY)

# These models currently fail accuracy with eager Adam optimizer
# so we use SGD when running the full benchmarks
# https://github.com/pytorch/pytorch/issues/115966
BENCHMARK_USE_SGD = {
    # TorchBench
    "BERT_pytorch",
    "LearningToPaint",
    "alexnet",
    "dcgan",
    "demucs",
    "densenet121",
    "dlrm",
    "fastNLP_Bert",
    "mobilenet_v2",
    "phlippe_densenet",
    "phlippe_resnet",
    "pytorch_stargan",
    "resnet18",
    "shufflenet_v2_x1_0",
    "speech_transformer",
    "squeezenet1_1",
    "stable_diffusion_text_encoder",
    "timm_efficientdet",
    "timm_nfnet",
    "timm_regnet",
    "timm_vision_transformer",
    "timm_vovnet",
    "vgg16",
    "hf_T5",  # Fails dynamic https://github.com/pytorch/pytorch/issues/115968
    # HF
    "AlbertForMaskedLM",
    "BartForCausalLM",
    "BartForConditionalGeneration",
    "BlenderbotSmallForCausalLM",
    "BlenderbotSmallForConditionalGeneration",
    "DebertaV2ForQuestionAnswering",  # eager OOM
    "ElectraForCausalLM",
    "M2M100ForConditionalGeneration",
    "MBartForCausalLM",
    "MBartForConditionalGeneration",
    "OPTForCausalLM",
    "PLBartForCausalLM",
    "PLBartForConditionalGeneration",
    "PegasusForCausalLM",
    "Speech2Text2ForCausalLM",
    "TrOCRForCausalLM",
    "XGLMForCausalLM",
    # TIMM
    "adv_inception_v3",
    "botnet26t_256",
    "cait_m36_384",  # OOM
    "coat_lite_mini",
    "convit_base",
    "dpn107",
    "fbnetv3_b",
    "gernet_l",
    "lcnet_050",
    "mixnet_l",
    "res2net101_26w_4s",
    "res2net50_14w_8s",
    "res2next50",
    "resnest101e",
    "sebotnet33ts_256",
    "swsl_resnext101_32x16d",
    "tf_efficientnet_b0",
    "ghostnet_100",
    "gmixer_24_224",
    "tinynet_a",
}

# These models OOM in CI
# due to the extra memory of Adam optimizer states,
# so we fall back to SGD in CI
CI_USE_SGD = {
    "torchrec_dlrm",
    "demucs",
    "detectron2_fasterrcnn_r_101_c4",
    "detectron2_fasterrcnn_r_101_dc5",
    "detectron2_fasterrcnn_r_101_fpn",
    "detectron2_fasterrcnn_r_50_c4",
    "detectron2_fasterrcnn_r_50_dc5",
    "detectron2_fasterrcnn_r_50_fpn",
    "detectron2_maskrcnn_r_101_c4",
    "detectron2_maskrcnn_r_101_fpn",
    "detectron2_maskrcnn_r_50_c4",
    "detectron2_maskrcnn_r_50_fpn",
    "hf_T5_base",
    "hf_clip",
    "llama_v2_7b_16h",
    "mobilenet_v2_quantized_qat",
    "phi_1_5",
    "resnet50_quantized_qat",
    "BlenderbotForCausalLM",
    "cait_m36_384",
    "DALLE2_pytorch",
    "moco",
    "timm_efficientdet",
    "ghostnet_100",
    "regnety_002",
    "poolformer_m36",
    "inception_v3",
    "tinynet_a",
    "selecsls42b",
    "mobilevit_s",
    "pytorch_CycleGAN_and_pix2pix",
    "vision_maskrcnn",
    "resmlp_12_224",
    "dlrm",
    "resnet50",
    "dm_nfnet_f0",
    "pit_b_224",
    "tf_mixnet_l",
}


DO_NOT_CAST_INPUTS = {"stable_diffusion"}


# Maps a benchmark model name to a list of status codes. For any listed entry, we'll
# capture TORCH_COMPILE_DEBUG logs in CI runs and preserve them (i.e., for upload) if
# the result status matches one listed.
CI_PRESERVE_COMPILE_DEBUG = {
    # For example:
    # "mnasnet1_0": ["fail_accuracy"],
}
287

288

289
@functools.lru_cache(maxsize=1)
def load_yaml_file(filename):
    """Load a YAML config stored next to this file, turning lists into sets.

    If an ``fb/<filename>`` override exists alongside this file, its top-level
    keys are merged over the public config with a shallow ``dict.update``.

    Dict values are converted recursively; list values (including nested
    lists, which are flattened) become sets; other values are unchanged.

    The result is cached for the most recent ``filename``, so repeated calls
    return the same object and callers should not mutate it.
    """
    base_dir = os.path.dirname(__file__)

    with open(os.path.join(base_dir, filename), encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as no keys.
        data = yaml.safe_load(f) or {}

    internal_file_path = os.path.join(base_dir, "fb", filename)
    if os.path.exists(internal_file_path):
        with open(internal_file_path, encoding="utf-8") as f:
            internal_data = yaml.safe_load(f)
        # An empty override file would otherwise crash with dict.update(None).
        if internal_data:
            data.update(internal_data)

    def flatten(lst):
        # Yield the non-list items of arbitrarily nested lists, in order.
        for item in lst:
            if isinstance(item, list):
                yield from flatten(item)
            else:
                yield item

    def maybe_list_to_set(obj):
        if isinstance(obj, dict):
            return {k: maybe_list_to_set(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return set(flatten(obj))
        return obj

    return maybe_list_to_set(data)
317

318

319
def model_specified_by_path(path_and_class_str):
    """Tell whether an ``--only`` value uses the ``path:...,class:...`` form.

    Any value containing a colon is treated as such a spec.
    """
    has_separator = ":" in path_and_class_str
    return has_separator
321

322

323
def load_model_from_path(path_and_class_str):
    """Instantiate a model class defined in a Python file given on the CLI.

    ``path_and_class_str`` is a comma-separated list of ``key:value`` pairs that
    must include ``path`` (an absolute path to the source file) and ``class``
    (the name of a ``torch.nn.Module`` subclass defined there), e.g.
    ``path:/abs/model.py,class:MyModel``. Only the first colon of each pair
    separates key from value.

    Returns:
        ``(model, inputs)``: a new instance of the class, constructed with no
        arguments, and the value of its ``get_example_inputs()`` method.

    Raises:
        RuntimeError: if a pair has no colon, ``path`` or ``class`` is missing,
            the path is not absolute, or no module loader exists for the file.
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded.
    import importlib.util

    usage_error = "Invalid --only arguments. Check help message for the correct format"

    configs = {}
    for kvstr in path_and_class_str.split(","):
        k, sep, v = kvstr.partition(":")
        if not sep:
            raise RuntimeError(usage_error)
        configs[k] = v

    for name in ["path", "class"]:
        if name not in configs:
            raise RuntimeError(usage_error)

    path = configs["path"]
    class_name = configs["class"]

    if path[:1] != "/":
        raise RuntimeError(
            "Use absolute path since dynamo may change the current working directory which makes using relative path tricky"
        )

    spec = importlib.util.spec_from_file_location("module_name", path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Cannot load a Python module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    model_class = getattr(module, class_name)
    assert issubclass(model_class, torch.nn.Module)
    model = model_class()
    assert hasattr(model, "get_example_inputs")
    inputs = model.get_example_inputs()
    return model, inputs
353

354

355
def output_csv(filename, headers, row):
    """Append ``row`` to the CSV file ``filename``, keeping a header row first.

    The whole file is read and rewritten on every call.
    - If the file exists and ``headers`` is longer than its current first row,
      the first row is replaced by ``headers``; an earlier failed run may have
      written a shorter header. Otherwise the existing first row is kept.
    - Floats in ``row`` are written with six decimals.
    - Every row shorter than the header is padded with ``"0"`` cells.

    Does nothing when the module-level ``disable_output`` flag is set.
    """
    if disable_output:
        return
    if os.path.exists(filename):
        # newline="" lets the csv module handle line endings itself, so quoted
        # fields containing "\r\n" survive the read/rewrite round trip.
        with open(filename, newline="") as fd:
            lines = list(csv.reader(fd)) or [[]]
        if headers and len(headers) > len(lines[0]):
            # if prior results failed the header might not be filled in yet
            lines[0] = headers
        else:
            headers = lines[0]
    else:
        lines = [headers]
    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
    with open(filename, "w", newline="") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for line in lines:
            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))
374

375

376
def nothing(f):
    """Identity decorator: hand back ``f`` untouched."""
    unchanged = f
    return unchanged
378

379

380
@functools.lru_cache(None)
381
def patch_torch_manual_seed():
382
    """Make torch manual seed deterministic. Helps with accuracy testing."""
383

384
    def deterministic_torch_manual_seed(*args, **kwargs):
385
        from torch._C import default_generator
386

387
        seed = 1337
388
        if HAS_CUDA:
389
            import torch.cuda
390

391
            if not torch.cuda._is_in_bad_fork():
392
                torch.cuda.manual_seed_all(seed)
393
        if HAS_XPU:
394
            import torch.xpu
395

396
            if not torch.xpu._is_in_bad_fork():
397
                torch.xpu.manual_seed_all(seed)
398
        return default_generator.manual_seed(seed)
399

400
    torch.manual_seed = deterministic_torch_manual_seed
401

402

403
def empty_gpu_cache(device):
    """Release the caching allocator's unused memory on ``device``.

    Called to avoid OOM in a subsequent run. Only ``"cuda"`` and ``"xpu"`` are
    supported; any other device just logs a warning.
    """
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()
    else:
        log.warning(
            "Trying to call the empty_gpu_cache for device: %s, which is not in list [cuda, xpu]",
            device,
        )
419

420

421
def synchronize():
    """Device-synchronization hook used by ``timed``; intentionally a no-op."""
    return None
423

424

425
def summarize_graph_break(filename):
    """Write a sorted, de-duplicated copy of the graph-break log for ``filename``.

    Rows are sorted by ``reason`` and keep only the first row per reason. All
    rows whose reason mentions ``_multi_tensor_sgd`` (whose text differs between
    occurrences) are collapsed into a single row. This is best effort: distinct
    breaks can be lost to de-duplication. Nothing happens if the log is missing.
    """
    # NOTE(review): str.rstrip(".csv") strips any trailing '.', 'c', 's' or 'v'
    # characters, not the ".csv" suffix (e.g. "x_graph_breaks.csv" becomes
    # "x_graph_break"). The quirk is kept so the path still matches whatever
    # code writes the log; confirm before switching to a real suffix strip.
    graph_break_log = f"{filename.rstrip('.csv')}_graph_breaks.csv"
    if not os.path.exists(graph_break_log):
        return

    breaks = pd.read_csv(graph_break_log)
    breaks = breaks.sort_values("reason").drop_duplicates(subset="reason")

    is_sgd = breaks["reason"].str.contains("_multi_tensor_sgd")
    sgd_rows = breaks.loc[is_sgd]
    if len(sgd_rows):
        # Replace every SGD row by the first one.
        breaks = pd.concat(
            [breaks[~is_sgd], pd.DataFrame([sgd_rows.iloc[0]])], axis=0
        )
    breaks.to_csv(f"{graph_break_log.rstrip('.csv')}_deduped.csv", index=False)
447

448

449
def print_summary(filename, print_dataframe=False):
    """Print summary tables for a results CSV, then dedupe its graph breaks.

    A missing or empty ``filename`` is a no-op. If the CSV has a ``tag``
    column, one table is printed per tag, except the ``"0.0000"`` tag, which
    shows up for failed runs.
    """
    if not filename or not os.path.exists(filename):
        return
    results = pd.read_csv(filename)
    if "tag" not in results.columns:
        print_summary_table(results, print_dataframe=print_dataframe)
    else:
        for tag in results.tag.unique():
            if tag == "0.0000":
                # This happens for failed runs
                continue
            print(f"\nSummary for tag={tag}:")
            print_summary_table(
                results[results.tag == tag], print_dataframe=print_dataframe
            )
    summarize_graph_break(filename)
462

463

464
def print_summary_table(data, print_dataframe=False):
    """Print one aggregate line per metric column of the DataFrame ``data``.

    - Identifier columns (dev, name, batch_size, tag) are skipped.
    - Percentage and count columns print their mean.
    - ``compilation_latency`` and ``compression_ratio`` print their mean.
    - ``accuracy`` prints the fraction of ``"pass"`` entries.
    - Every other column prints its geometric and arithmetic mean.
    - Columns that fail to aggregate (e.g. non-numeric values) are skipped.

    Args:
        data: results table, as loaded from a benchmark CSV.
        print_dataframe: also print the full table, after widening pandas'
            display limits.
    """
    if print_dataframe:
        pd.options.display.max_rows = 1000
        pd.options.display.max_columns = 1000
        pd.options.display.width = 2000
        print(data)
    # default=0 keeps a table without columns from raising in max().
    width = max(map(len, data.columns), default=0)
    for col in data.columns:
        try:
            if col in ("dev", "name", "batch_size", "tag"):
                continue
            elif col in ("pct_ops", "pct_time"):
                print(col.ljust(width), f"{data[col].mean():.3%}")
            elif col in ("graphs", "graph_calls", "captured_ops", "total_ops"):
                print(col.ljust(width), f"{data[col].mean():.3f}")
            # Compare with ==: `col in ("x")` is a substring test on a plain
            # string and would also match unrelated columns such as "lat".
            elif col == "compilation_latency":
                print(col.ljust(width), f"mean={data[col].mean():.3f} seconds")
            elif col == "compression_ratio":
                print(col.ljust(width), f"mean={data[col].mean():.3f}x")
            elif col == "accuracy":
                pass_rate = (data[col] == "pass").mean()
                print(col.ljust(width), f"pass_rate={100*pass_rate:.2f}%")
            else:
                cdata = data[col]
                print(
                    col.ljust(width),
                    f"gmean={gmean(cdata):.2f}x mean={cdata.mean():.3f}x",
                )
        except Exception:
            # Best effort: a column that cannot be aggregated is skipped rather
            # than aborting the whole summary.
            continue
494

495

496
def tensor_is_on_xla(tensors):
497
    def visit(x: torch.Tensor):
498
        nonlocal result
499
        if x.device.type == "xla":
500
            result = True
501

502
    result = False
503
    tree_map_only(torch.Tensor, visit, tensors)
504
    return result
505

506

507
def timed(
    model,
    model_iter_fn,
    example_inputs,
    times=1,
    return_result=False,
    collect_outputs=False,
):
    """Call ``model_iter_fn(model, example_inputs, ...)`` ``times`` times and time it.

    The RNG state is reset before each iteration, outside the timed region.
    When the inputs are on an XLA device:
    - pending work is flushed with ``xm.mark_step()`` before the loop and
      after every iteration;
    - the final ``xm.wait_device_ops()`` is included in the measured time.

    Returns:
        The total time in seconds (from ``time.perf_counter``). When
        ``return_result`` is True, returns ``(time, result)`` instead, where
        ``result`` is the value returned by the last iteration (``None`` if
        ``times`` is 0).
    """
    use_xla = tensor_is_on_xla(example_inputs)
    synchronize()

    if use_xla:
        xm.mark_step()
        xm.wait_device_ops()

    time_total = 0
    # Defined up front so return_result works even when times == 0.
    result = None
    for _ in range(times):
        # Reset the seed for every iteration; kept out of the timed region.
        reset_rng_state(use_xla)
        t_iter_begin = time.perf_counter()
        result = model_iter_fn(model, example_inputs, collect_outputs=collect_outputs)

        # instead of calling sync on result_list, we should call mark_step.
        # In training case, result_list may be empty, but we want to
        # send all the pending graphs for compilation.
        if use_xla:
            # For the model running on regular torchxla (baseline), we need the
            # mark step to send the accumulated graph for compilation.
            #
            # For the model running with dynamo/torchxla bridge, in training case,
            # we need the mark step to send the optimizer graph out for
            # compilation.
            xm.mark_step()
        t_iter_end = time.perf_counter()
        time_total += t_iter_end - t_iter_begin

    t_0 = time.perf_counter()
    if use_xla:
        xm.wait_device_ops()
    synchronize()
    t_1 = time.perf_counter()
    time_total += t_1 - t_0
    return (time_total, result) if return_result else time_total
552

553

554
def _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]:
555
    # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
556
    # and consumed like `model(**example_inputs)`.
557
    # For other benchmarks, example_inputs are formatted as tuple and consumed
558
    # like `model(*example_inputs)`.
559
    if isinstance(example_inputs, dict):
560
        return (), example_inputs
561
    else:
562
        return tuple(example_inputs), {}
563

564

565
def _register_dataclass_output_as_pytree(example_outputs) -> None:
566
    # NOTE(angelayi): For huggingface benchmark, some example outputs are
567
    # formatted as a dataclass which pytree cannot consume. So we want
568
    # to register the pytree implementation here
569
    example_outputs_flat = pytree.tree_leaves(example_outputs)
570
    output_dataclass_types = [
571
        type(out) for out in example_outputs_flat if dataclasses.is_dataclass(type(out))
572
    ]
573
    for output_type in output_dataclass_types:
574
        from torch._export.utils import register_dataclass_as_pytree_node
575

576
        register_dataclass_as_pytree_node(
577
            output_type,
578
            serialized_type_name=f"{output_type.__module__}.{output_type.__name__}",
579
        )
580

581

582
class Stats:
    """Accumulates ``torch._dynamo.utils.counters`` across benchmark runs."""

    # counter category -> Counter of cumulative event counts
    totals = collections.defaultdict(collections.Counter)

    @classmethod
    def reset_counters(cls):
        """Fold the live dynamo counters into ``totals``, then clear them.

        Returns:
            ``(ok, total)`` frame counts from the counters just cleared.
        """
        live = torch._dynamo.utils.counters
        for category, counter in live.items():
            cls.totals[category].update(counter)
        frames = live["frames"]
        ok, total = frames["ok"], frames["total"]
        live.clear()
        return ok, total

    @classmethod
    def print_summary(cls):
        """Print the 50 most common events of each category, sorted by category."""
        for category in sorted(cls.totals):
            top = cls.totals[category].most_common(50)
            body = "\n  ".join(str(entry) for entry in top)
            print(f"STATS {category}\n  {body}")

    @classmethod
    def aot_summary(cls):
        """Return ``[total, ok]`` accumulated for the ``aot_autograd`` category."""
        aot = cls.totals["aot_autograd"]
        return [aot["total"], aot["ok"]]
603

604

605
def coverage_experiment(args, model_iter_fn, model, example_inputs):
    """Record TorchDynamo operator/model coverage statistics from a profiler.

    Runs ``model_iter_fn`` once through ``torch._dynamo.run`` inside the
    profiler's context. It then appends one row to the module-level
    ``output_filename`` CSV: device, model name and batch size, followed by the
    profiler's coverage columns. Mainly a correctness check, not a timing one.

    Returns:
        The profiler's coverage result.
    """
    columns = (
        "dev",
        "name",
        "batch_size",
        "graphs",
        "graph_calls",
        "captured_ops",
        "total_ops",
        "pct_ops",
        "pct_time",
    )
    profiler = Profiler()
    run_fn = torch._dynamo.run(model_iter_fn)
    with profiler.prof:
        run_fn(model, example_inputs)
    coverage_result = profiler.results()
    identity = [current_device, current_name, current_batch_size]
    output_csv(output_filename, columns, identity + coverage_result.tocsv())
    return coverage_result
639

640

641
def speedup_experiment_fx2trt(args, model_iter_fn, model, example_inputs):
    """Speedup over eager for the fx2trt (TensorRT) inference backend.

    The TRT backend consumes FX graphs produced by torch._dynamo. The
    measurement itself is exactly :func:`speedup_experiment`, which writes to
    the module-level ``output_filename``.
    """
    result = speedup_experiment(args, model_iter_fn, model, example_inputs)
    return result
648

649

650
def recompile_profiler_experiment(args, model_iter_fn, model, example_inputs):
    """Run the model once under a ``CompilerProfiler`` backend and count guard failures.

    Appends ``[model name, profiler report]`` to the ``output_filename`` CSV.

    Returns:
        A one-element list with the number of guard failures the profiler
        recorded.
    """
    profiler = torch._dynamo.utils.CompilerProfiler()
    compiled_iter_fn = torch._dynamo.optimize(profiler, nopython=args.nopython)(
        model_iter_fn
    )
    compiled_iter_fn(model, example_inputs)
    output_csv(
        output_filename,
        ["model", "profiler report"],
        [current_name, profiler.report()],
    )
    guard_failures = profiler.get_metrics()["guard_failures"]
    return [len(guard_failures)]
662

663

664
def randomize_input(inputs):
    """Return ``inputs`` with floating-point tensors replaced by random ones.

    Containers:
    - Lists and tuples are rebuilt with the same type.
    - Dicts are shallow-copied (keeping their type) with each value randomized;
      HuggingFace benchmarks pass their inputs as dicts.

    Tensors:
    - float16/bfloat16/float32/float64 tensors become ``torch.randn_like``
      samples; each one is counted in ``counters["randomize_input"]["times"]``.
    - int64 tensors are returned unchanged.

    Raises:
        RuntimeError: for any other tensor dtype or input type.
    """
    if isinstance(inputs, (list, tuple)):
        return type(inputs)([randomize_input(x) for x in inputs])
    if isinstance(inputs, dict):
        randomized = copy.copy(inputs)
        for key, value in inputs.items():
            randomized[key] = randomize_input(value)
        return randomized
    if isinstance(inputs, torch.Tensor):
        if inputs.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
            torch._dynamo.utils.counters["randomize_input"]["times"] += 1
            return torch.randn_like(inputs)
        if inputs.dtype == torch.int64:
            # Note: we can not simply tune integer tensors as follows
            #   `return torch.randint_like(inputs, high=inputs.max().item())`
            # This may break some invariants between tensors.
            # E.g. in embedding lookup case, one tensor is the length
            # and another is an indices tensor.
            return inputs
        raise RuntimeError(
            f"randomize_input need support tensor of type {inputs.dtype}"
        )
    raise RuntimeError(
        f"randomize_input can not handle input of type {type(inputs)}"
    )
686

687

688
def maybe_mark_step(args):
    """Call ``xm.mark_step()`` only when ``args.trace_on_xla`` is set."""
    if not args.trace_on_xla:
        return
    xm.mark_step()
691

692

693
def latency_experiment(args, model_iter_fn, model, example_inputs, mark, **kwargs):
    """Time ``model_iter_fn`` on ``model`` for ``args.repeat`` repetitions.

    Each repetition:
    - optionally uses freshly randomized inputs;
    - calls :func:`timed` with ``args.iterations_per_run`` iterations;
    - runs inside ``maybe_enable_compiled_autograd`` configured from ``args``.

    When ``args.export_profiler_trace`` is set, the loop runs under the
    profiler, each repetition is labelled ``mark``, and a Chrome trace is
    written under ``torch._dynamo.config.base_dir``.

    Returns:
        A float64 array of shape ``(args.repeat,)``, one timing per repetition.
    """
    from torch._inductor.utils import maybe_profile

    def profile_region(prof, label):
        # Label the enclosed region in the trace, but only while profiling.
        if prof:
            return torch.profiler.record_function(label)
        return contextlib.nullcontext()

    timings = np.zeros((args.repeat,), np.float64)
    # if we randomize the input, we should also check the result is correct
    randomize = args.randomize_input
    iterations = args.iterations_per_run

    with maybe_profile(args.export_profiler_trace) as p:
        for rep in trange(args.repeat, desc="running benchmark"):
            if randomize:
                inputs = randomize_input(copy.deepcopy(example_inputs))
            else:
                inputs = example_inputs
            # On XLA, mark_step materializes the randomized inputs now, so the
            # first timed call does not also pay for computing them.
            maybe_mark_step(args)

            with profile_region(p, mark), maybe_enable_compiled_autograd(
                args.compiled_autograd,
                fullgraph=args.nopython,
                dynamic=args.dynamic_shapes,
            ):
                timings[rep], _ = timed(
                    model,
                    model_iter_fn,
                    inputs,
                    return_result=True,
                    times=iterations,
                    collect_outputs=args.collect_outputs,
                )

    if args.export_profiler_trace:
        trace_name = args.profiler_trace_name + "_" + model.name
        if hasattr(args, "rank"):
            trace_name += f"_rank_{args.rank}"
        p.export_chrome_trace(
            os.path.join(torch._dynamo.config.base_dir, trace_name + ".json")
        )
    return timings
752

753

754
def latency_experiment_summary(args, model, timings, **kwargs):
    """Write the results of a timing run to CSV and return a one-line summary.

    Speedup and latency:
    - ``timings`` must have at least two columns. Their per-column medians give
      ``speedup = median[0] / median[1]``.
    - Column 0 is presumably the reference run and column 1 the run under test,
      as in ``speedup_experiment``.
    - ``abs_latency`` is ``median[1] * 1000`` (milliseconds if the timings are
      in seconds, which is what ``timed`` returns).

    Output:
    - The main row goes to the module-level ``output_filename``.
    - The aggregated dynamo compile times go to
      ``<output_filename minus .csv>_compilation_metrics.csv``.

    Optional ``kwargs``:
    - ``tag``
    - ``compilation_latency``, ``compression_ratio``, ``eager_peak_mem`` and
      ``dynamo_peak_mem`` (all four must be present together)
    - ``cache_lookup_latency``
    - ``dynamo_stats`` (a mapping of extra columns)

    Returns:
        ``"<speedup>x"``. If ``args.baseline`` names a CSV containing this
        model, returns ``"<baseline>x -> <speedup>x [<ratio>x]"`` instead.
    """
    median = np.median(timings, axis=0)
    speedup = median[0] / median[1]
    if args.dump_raw_metrics:
        np.save(
            f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy",
            timings,
        )

    first_headers = ["dev", "name", "batch_size"]
    first_fields = [current_device, current_name, current_batch_size]
    if "tag" in kwargs:
        first_headers.append("tag")
        first_fields.append(kwargs["tag"])
    headers = first_headers + ["speedup", "abs_latency"]
    row = first_fields + [float(speedup), median[1] * 1000]
    msg = f"{speedup:.3f}x"
    if args.baseline:
        headers.extend(
            [
                "baseline",
                "speedup_vs_baseline",
            ]
        )
        df = pd.read_csv(args.baseline)
        try:
            baseline_speedup = df[df["name"] == current_name]["speedup"].item()
            row.extend([baseline_speedup, speedup / baseline_speedup])
            msg = f"{baseline_speedup:.3f}x -> {speedup:.3f}x [{speedup / baseline_speedup:.3f}x]"
        except (KeyError, ValueError, ZeroDivisionError):
            # ValueError: Series.item() found zero or several rows for this
            # model, i.e. it is missing from (or duplicated in) the baseline.
            row.extend(
                [
                    0.0,
                    0.0,
                ]
            )
    if "compilation_latency" in kwargs:
        headers += [
            "compilation_latency",
            "compression_ratio",
            "eager_peak_mem",
            "dynamo_peak_mem",
        ]
        row.append(kwargs["compilation_latency"])
        row.append(kwargs["compression_ratio"])
        row.append(kwargs["eager_peak_mem"])
        row.append(kwargs["dynamo_peak_mem"])

    if "cache_lookup_latency" in kwargs:
        headers.append("cache_lookup_latency")
        row.append(kwargs["cache_lookup_latency"])

    if "dynamo_stats" in kwargs:
        for k, v in kwargs["dynamo_stats"].items():
            headers.append(k)
            row.append(v)
    output_csv(
        output_filename,
        headers,
        row,
    )
    metric_headers, metric_data = torch._dynamo.utils.compile_times(
        repr="csv", aggregate=True
    )
    assert (
        output_filename.find(".csv") > 0
    ), f"expected output_filename to be a .csv, but got {output_filename}"
    output_csv(
        output_filename[:-4] + "_compilation_metrics.csv",
        first_headers + metric_headers,
        first_fields + metric_data,
    )
    return msg
825

826

827
def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
    """Measure the speedup of the optimized ``model_iter_fn`` over eager.

    The optimized function is ``torch._dynamo.run(model_iter_fn)``, or the
    AOTInductor export when ``args.export_aot_inductor`` is set. Each of
    ``args.repeat`` repetitions times the eager function (column 0) and then
    the optimized one (column 1) on the same, optionally randomized, inputs.
    Interleaving the two runs counters frequency scaling and load changes.

    When ``args.export_profiler_trace`` is set, the runs are profiled and a
    Chrome trace is written under ``torch._dynamo.config.base_dir``.

    Result rows are written by :func:`latency_experiment_summary` (to the
    module-level ``output_filename`` and its ``_compilation_metrics.csv``
    companion), and ``kwargs`` are forwarded to it.

    Returns:
        The summary message produced by :func:`latency_experiment_summary`.
    """
    from torch._inductor.utils import maybe_profile

    def profile_region(prof, label):
        # Label the enclosed region in the trace, but only while profiling.
        if prof:
            return torch.profiler.record_function(label)
        return contextlib.nullcontext()

    timings = np.zeros((args.repeat, 2), np.float64)
    # if we randomize the input, we should also check the result is correct
    should_randomize_input = args.randomize_input
    times = args.iterations_per_run

    # Use higher tolerance for XLA since XLA causes numerical instability when
    # graph size changes
    tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4
    torch._dynamo.config.repro_tolerance = tolerance

    with maybe_profile(args.export_profiler_trace) as p:
        if args.export_aot_inductor:
            frozen_model_iter_fn = export_aot_inductor(
                model, example_inputs, args.devices[0]
            )
        else:
            frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)

        for rep in trange(args.repeat, desc="running benchmark"):
            inputs = (
                randomize_input(copy.deepcopy(example_inputs))
                if should_randomize_input
                else example_inputs
            )
            # On XLA, mark_step materializes the randomized inputs now, so the
            # first timed call does not also pay for computing them.
            maybe_mark_step(args)

            # interleave the runs to handle frequency scaling and load changes
            with profile_region(p, "expected"):
                timings[rep, 0], _ = timed(
                    model,
                    model_iter_fn,
                    inputs,
                    return_result=True,
                    times=times,
                    collect_outputs=args.collect_outputs,
                )

            # call mark_step between the 2 calls to make the comparison fair.
            maybe_mark_step(args)

            with profile_region(p, "actual"), maybe_enable_compiled_autograd(
                args.compiled_autograd,
                fullgraph=args.nopython,
                dynamic=args.dynamic_shapes,
            ):
                timings[rep, 1], _ = timed(
                    model,
                    frozen_model_iter_fn,
                    inputs,
                    return_result=True,
                    times=times,
                    collect_outputs=args.collect_outputs,
                )

    if args.export_profiler_trace:
        name = args.profiler_trace_name + "_" + model.name
        if hasattr(args, "rank"):
            name += f"_rank_{args.rank}"
        name += ".json"
        name = os.path.join(torch._dynamo.config.base_dir, name)
        p.export_chrome_trace(name)

    # The CSV rows and summary message are identical to those of the latency
    # experiments, so reuse that writer rather than duplicating it.
    return latency_experiment_summary(args, model, timings, **kwargs)
985

986

987
def speedup_experiment_ds(args, model_iter_fn, model, example_inputs):
    """
    Run dynamic shapes benchmarks.

    ``example_inputs`` is a list of input examples (one per shape) for a
    dynamic-shape compatible model. Every repetition resets dynamo, warms the
    compiled function up on the first example only, then times eager and
    compiled runs back to back for each example. Recompilation for new shapes
    is therefore part of what gets measured.

    Appends one row (mean / median / variance of the per-example speedups) to
    ``output_filename`` and returns a human-readable summary that groups the
    speedups by the shape of each example's first input.
    """
    repeat = args.repeat
    num_examples = len(example_inputs)
    # Last axis: index 0 holds the eager time, index 1 the compiled time.
    timings = np.zeros((repeat, num_examples, 2), np.float64)

    if repeat > 5:
        print(
            f"\ndynamic shapes experiments are slow, consider setting --repeat less than {repeat}\n"
        )

    warmup_iters = 4
    for rep in range(repeat):
        # Start each rep fresh, e.g. only warmup on example 0
        torch._dynamo.reset()
        compiled_iter_fn = optimize_ctx(model_iter_fn)
        for _ in range(warmup_iters):
            compiled_iter_fn(model, example_inputs[0])

        for example_idx, example in enumerate(example_inputs):
            # Interleave eager/compiled runs so frequency scaling and load
            # changes affect both sides equally.
            timings[rep, example_idx, 0] = timed(
                model, model_iter_fn, example, return_result=False
            )
            # Unlike the regular speedup_experiment, recompilation is allowed here.
            timings[rep, example_idx, 1] = timed(
                model, compiled_iter_fn, example, return_result=False
            )

    per_example_medians = np.median(timings, axis=0)
    speedups = list(per_example_medians[:, 0] / per_example_medians[:, 1])
    mean_speedup = np.mean(speedups)
    median_speedup = np.median(speedups)
    var_speedup = np.var(speedups)

    # TODO this x[0] is not going to work in general but bert only has 1 input
    shapes = [example[0].shape for example in example_inputs]
    by_shape = {shape: [] for shape in sorted(set(shapes))}
    for shape, value in zip(shapes, speedups):
        by_shape[shape].append(value)

    shape_lines = [
        f"{shape}: " + ", ".join(f"{value: .3g}" for value in values)
        for shape, values in by_shape.items()
    ]
    summary = (
        f"mean: {mean_speedup:.3f}, median: {median_speedup:.3f}, var: {var_speedup:.3f}"
        + "\nSpeedups by shape: "
        + "\n".join(shape_lines)
    )

    output_csv(
        output_filename,
        ("dev", "name", "batch_size", "speedup mean", "speedup median", "speedup var"),
        [
            current_device,
            current_name,
            current_batch_size,
            mean_speedup,
            median_speedup,
            var_speedup,
        ],
    )
    return summary
1060

1061

1062
@contextlib.contextmanager
def override_synchronize_with_onnx_iobinding(iobinding):
    """Temporarily redirect the module-level ``synchronize`` hook to an ORT I/O binding.

    While the context is active and ``iobinding`` is not None, calling
    ``synchronize()`` runs ``iobinding.synchronize_inputs()`` followed by
    ``iobinding.synchronize_outputs()``. With ``iobinding=None`` the hook is left
    alone. The previous hook is always restored on exit, including on error.
    """
    global synchronize
    saved_synchronize = synchronize

    def _sync_bound_buffers():
        iobinding.synchronize_inputs()
        iobinding.synchronize_outputs()

    try:
        if iobinding is not None:
            synchronize = _sync_bound_buffers
        yield
    finally:
        synchronize = saved_synchronize
1077

1078

1079
def speedup_experiment_onnx(
    args,
    model_iter_fn,
    onnx_model: OnnxModel,
    model,
    example_inputs,
    **kwargs,
):
    """
    Measure speedups of an ONNX Runtime session over eager.

    Per repetition, eager and ORT are timed on the same (optionally randomized)
    inputs. When neither ``current_device`` nor the ORT session is CPU, the ORT
    run goes through an I/O binding. That binding is built from the inputs and
    from the eager outputs of the same repetition, outside the timed region, and
    ``synchronize`` is redirected to it while timing. Otherwise the inputs are
    pre-converted to ONNX inputs so only the ORT call itself is timed.

    Writes to ./{output_filename}, which should be
        `Path(self.output_dir) / f"{self.compiler}_{suite}_{self.dtype}_{self.mode}_{self.device}_{self.testing}.csv".
    It also writes aggregated dynamo compile times to the matching
    ``*_compilation_metrics.csv`` file. Returns the formatted speedup string.

    TODO(bowbao): Record export time and export peak memory usage.
    """
    repeat = args.repeat
    randomize = args.randomize_input
    times = args.iterations_per_run
    # Column 0: eager latency, column 1: ONNX Runtime latency.
    timings = np.zeros((repeat, 2), np.float64)

    def _next_inputs():
        if randomize:
            return randomize_input(copy.deepcopy(example_inputs))
        return example_inputs

    def _run_eager(inputs):
        return timed(
            model,
            model_iter_fn,
            inputs,
            return_result=True,
            times=times,
            collect_outputs=args.collect_outputs,
        )

    def _run_onnx(inputs, reference_output):
        if current_device == "cpu" or onnx_model.is_cpu():
            # Pre-adapt pt inputs to onnx inputs; output adaptation is skipped,
            # since output comparison is not part of the perf measurement.
            onnx_inputs = onnx_model.adapt_pt_inputs_to_onnx(inputs)

            def ort_iter_fn(model, inputs, collect_outputs=True):
                return onnx_model.run_with_onnx_inputs(onnx_inputs)

            iobinding = None
        else:
            # Create the I/O binding up front so its setup is not timed.
            iobinding, bound_outputs = onnx_model.create_iobinding(
                inputs, reference_output
            )

            def ort_iter_fn(model, inputs, collect_outputs=True):
                onnx_model.run_with_iobinding(iobinding, bound_outputs)
                if collect_outputs:
                    return bound_outputs

        with override_synchronize_with_onnx_iobinding(iobinding):
            return timed(
                model,
                ort_iter_fn,
                inputs,
                return_result=True,
                times=times,
                collect_outputs=args.collect_outputs,
            )

    # ONNX warm-up: two ORT runs against a single eager reference run.
    warmup_inputs = _next_inputs()
    _, reference = _run_eager(warmup_inputs)
    for _ in range(2):
        _run_onnx(warmup_inputs, reference)

    for rep in range(repeat):
        inputs = _next_inputs()
        # With more than one cuda device, eager runs on the first and ORT on the
        # second; make the matching one current so torch.cuda.synchronize()
        # waits on the right device.
        if torch.cuda.device_count() > 1:
            torch.cuda.set_device(0)
        timings[rep, 0], reference = _run_eager(inputs)
        if torch.cuda.device_count() > 1:
            torch.cuda.set_device(1)
        timings[rep, 1], _ = _run_onnx(inputs, reference)

    pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
    eager_median, onnx_median = np.median(timings, axis=0)
    speedup = eager_median / onnx_median
    if args.dump_raw_metrics:
        np.save(
            f"{output_filename[:-4]}-raw_timings-{current_name}-{current_device}.npy",
            timings,
        )

    headers = ["dev", "name", "batch_size", "speedup", "abs_latency"]
    row = [
        current_device,
        current_name,
        current_batch_size,
        float(speedup),
        onnx_median * 1000,
    ]
    if "compilation_latency" in kwargs:
        headers += ["compilation_latency", "compression_ratio"]
        row += [kwargs["compilation_latency"], kwargs["compression_ratio"]]

    output_csv(output_filename, headers, row)

    metric_headers, metric_data = torch._dynamo.utils.compile_times(
        repr="csv", aggregate=True
    )
    assert (
        output_filename.find(".csv") > 0
    ), f"expected output_filename to be a .csv, but got {output_filename}"
    output_csv(
        output_filename[:-4] + "_compilation_metrics.csv",
        ["dev", "name", "batch_size"] + metric_headers,
        [current_device, current_name, current_batch_size] + metric_data,
    )
    return format_speedup(speedup, pvalue, is_correct=True)
1224

1225

1226
def overhead_experiment(*args, model_iter_fn):
    """
    Measure overheads of TorchDynamo by running with no backend (only
    eager+FX), and reporting speedup/slowdown over eager.

    ``args`` are the positional arguments of ``speedup_experiment`` without
    ``model_iter_fn``, i.e. ``(args, model, example_inputs, ...)``.
    ``model_iter_fn`` is re-inserted as the second positional argument. That is
    the position where ``speedup_experiment`` (like ``speedup_experiment_ds``)
    expects it. The previous code appended it last, which shifted
    ``model``/``example_inputs`` into the wrong parameters.

    Results are written by ``speedup_experiment`` to ``output_filename``.
    """
    if not args:
        raise TypeError(
            "overhead_experiment() requires the benchmark args as its first positional argument"
        )
    bench_args, *rest = args
    return speedup_experiment(bench_args, model_iter_fn, *rest)
1234

1235

1236
def print_fx(gm, example_inputs):
    """Debug backend: print ``gm.graph`` and hand ``gm`` back unchanged.

    ``example_inputs`` is part of the backend signature and is ignored.
    """
    graph = gm.graph
    print(graph)
    return gm
1239

1240

1241
def print_aten_ops(gm, example_inputs):
    """Debug backend: wrap ``gm`` with ``aot_module``.

    The forward and backward compilers print the graph they receive and return
    it unchanged. ``example_inputs`` is ignored.
    """
    from functorch.compile import aot_module

    def _print_and_passthrough(traced_gm, _example_inputs):
        print(traced_gm.graph)
        return traced_gm

    return aot_module(
        gm,
        fw_compiler=_print_and_passthrough,
        bw_compiler=_print_and_passthrough,
    )
1249

1250

1251
def baselines(models, model_iter_fn, example_inputs, args):
    """
    Common measurement code across all baseline experiments.

    ``models`` is an iterable of ``(name, model)`` pairs; the first pair is the
    reference. Each later non-None model is first checked against the
    reference output with ``same``. If it disagrees or raises, it is replaced
    by ``(name, None)``. The remaining models are timed ``args.repeat`` times. A
    failed timing silently keeps the 1e10 sentinel. Per-candidate speedups over
    the reference (0.0 for dropped candidates) are written to
    ``output_filename`` and returned as a space-separated formatted string.
    """
    entries = list(models)

    # Correctness pass: drop candidates whose output differs from entry 0.
    reference_output = None
    for position, (label, candidate) in enumerate(entries):
        if position == 0:
            reference_output = model_iter_fn(candidate, example_inputs)
            continue
        if candidate is None:
            continue
        try:
            if same(reference_output, model_iter_fn(candidate, example_inputs)):
                continue
            print(label, "is INCORRECT")
        except Exception:
            log.exception("error checking %s", label)
        entries[position] = (label, None)

    timings = np.full((args.repeat, len(entries)), 1.0e10, dtype=np.float64)
    for rep in range(args.repeat):
        for position, (_, candidate) in enumerate(entries):
            if candidate is None:
                continue
            try:
                timings[rep, position] = timed(candidate, model_iter_fn, example_inputs)
            except Exception:
                pass  # best effort: keep the 1e10 sentinel for this run

    pvalues = [
        ttest_ind(timings[:, 0], timings[:, column]).pvalue
        for column in range(1, timings.shape[1])
    ]
    medians = np.median(timings, axis=0)
    speedups = medians[0] / medians[1:]
    others = entries[1:]
    for offset, (_, candidate) in enumerate(others):
        if candidate is None:
            speedups[offset] = 0.0

    summary = " ".join(
        format_speedup(value, pvalue, candidate is not None)
        for value, pvalue, (_, candidate) in zip(speedups, pvalues, others)
    )
    output_csv(
        output_filename,
        ("dev", "name", "batch_size") + tuple(label for label, _ in others),
        [current_device, current_name, current_batch_size]
        + [f"{value:.4f}" for value in speedups],
    )
    return summary
1299

1300

1301
def xla(args, model_iter_fn, model, example_inputs):
    """Compare eager ``model`` against a deep copy running on the XLA device.

    The copy and all tensor inputs are moved to the XLA device by way of CPU.
    After 3 warm-up rounds of both, ``args.repeat`` timed rounds are taken. The
    median speedup (eager / XLA) is written to ``output_filename`` and returned
    formatted with its t-test p-value.
    """
    xla_device = xm.xla_device(devkind=current_device)

    def _to_xla(tensor):
        return tensor.to("cpu").to(device=xla_device)

    xla_model = copy.deepcopy(model).to("cpu").to(device=xla_device)
    xla_inputs = tree_map_only(torch.Tensor, _to_xla, example_inputs)

    def _time_both():
        eager_time = timed(model, model_iter_fn, example_inputs)
        xla_time = timed(xla_model, model_iter_fn, xla_inputs)
        return eager_time, xla_time

    for _ in range(3):  # warmup
        _time_both()

    timings = np.full((args.repeat, 2), 1.0e10, dtype=np.float64)
    for rep in range(args.repeat):
        timings[rep, 0], timings[rep, 1] = _time_both()

    pvalue = ttest_ind(timings[:, 0], timings[:, 1]).pvalue
    time_baseline, time_xla = np.median(timings, axis=0)
    speedup = time_baseline / time_xla
    output_csv(
        output_filename,
        ("dev", "name", "batch_size", "speedup", "time_baseline", "time_xla"),
        [
            current_device,
            current_name,
            current_batch_size,
            speedup,
            time_baseline,
            time_xla,
        ],
    )
    return format_speedup(speedup, pvalue)
1332

1333

1334
def try_script(model, example_inputs):
    """Return ``torch.jit.script(model)``, or None if scripting raises anything.

    ``example_inputs`` is accepted for a uniform baseline signature and unused.
    """
    scripted = None
    try:
        scripted = torch.jit.script(model)
    except Exception:
        pass  # deliberate: unscriptable models simply have no TorchScript baseline
    return scripted
1339

1340

1341
class AOTInductorModelCache:
    """Cache of AOTInductor-compiled runners, keyed by a weak reference to the eager model."""

    cache = {}

    @classmethod
    def load(cls, model, example_inputs, device):
        """Return the AOTInductor-compiled callable for ``model``, compiling on first use.

        On a cache miss:
        - a deep copy of the model is run once under ``no_grad`` to get example
          outputs;
        - the output type is registered with pytree (as a namedtuple or via
          ``_register_dataclass_output_as_pytree``);
        - the model is exported (pre-dispatch, non-strict) and AOT-compiled;
        - the resulting shared library is loaded for ``device``.
        """
        import torch._inductor
        import torch.export._trace

        cache_key = weakref.ref(model)
        if cache_key in cls.cache:
            return cls.cache[cache_key]

        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
        with torch.no_grad():
            # copy.deepcopy is required to prevent any surprising side-effect,
            # see https://github.com/pytorch/pytorch/issues/113029
            example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)

        # Register the output container type with pytree.
        if not pytree._is_namedtuple_instance(example_outputs):
            _register_dataclass_output_as_pytree(example_outputs)
        else:
            output_type = type(example_outputs)
            pytree._register_namedtuple(
                output_type,
                serialized_type_name=f"{output_type.__module__}.{output_type.__name__}",
            )

        exported = torch.export._trace._export(
            model,
            example_args,
            example_kwargs,
            pre_dispatch=True,
            strict=False,
        )
        graph_module = exported.module()
        with torch.no_grad():
            so_path = torch._inductor.aot_compile(
                graph_module, example_args, example_kwargs
            )  # type: ignore[arg-type]

        cls.cache[cache_key] = torch._export.aot_load(so_path, device)
        return cls.cache[cache_key]
1382

1383

1384
def export(model, example_inputs):
    """Export ``model`` with ``torch.export.export`` and return a benchmark iter-fn.

    The model is run once eagerly so its output type can be registered with
    pytree (``_register_dataclass_output_as_pytree``) before exporting. The
    returned ``opt_export(_, example_inputs)`` ignores its first argument and
    runs the exported module on ``example_inputs``.
    """
    example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
    example_outputs = model(*example_args, **example_kwargs)
    _register_dataclass_output_as_pytree(example_outputs)

    ep = torch.export.export(model, example_args, example_kwargs)
    # ExportedProgram.module() constructs a new (unlifted) GraphModule every
    # time it is called. Build it once here so the timed iterations measure
    # execution rather than module construction.
    exported_module = ep.module()

    def opt_export(_, example_inputs):
        example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
        return exported_module(*example_args, **example_kwargs)

    return opt_export
1396

1397

1398
def export_aot_inductor(model, example_inputs, device):
    """Return a benchmark iter-fn backed by the AOTInductor-compiled ``model``.

    Compilation goes through ``AOTInductorModelCache.load``. The returned
    function ignores its first argument and its ``collect_outputs`` flag.
    """
    compiled_runner = AOTInductorModelCache.load(model, example_inputs, device)

    def opt_aot_inductor(_, example_inputs, collect_outputs=False):
        call_args, call_kwargs = _normalize_bench_inputs(example_inputs)
        return compiled_runner(*call_args, **call_kwargs)

    return opt_aot_inductor
1406

1407

1408
def download_retry_decorator(download_fn):
    """
    Decorator function for applying retry logic to a download function.

    The wrapped function is attempted once and then retried up to
    ``MAX_DOWNLOAD_ATTEMPTS`` more times, so it runs at most
    ``MAX_DOWNLOAD_ATTEMPTS + 1`` times. After the n-th failed attempt it
    sleeps ``n * 30`` seconds before trying again. If every attempt fails, a
    ``RuntimeError`` is raised, chained to the last underlying exception.

    Usage:
    @download_retry_decorator
    def download_function(model_name: str):
        # download logic goes here
    """

    @functools.wraps(download_fn)
    def wrapper(self, *args, **kwargs) -> Any:
        total_allowed_tries = MAX_DOWNLOAD_ATTEMPTS
        tries = 0
        while True:
            try:
                return download_fn(self, *args, **kwargs)
            except Exception as e:
                tries += 1
                if tries > total_allowed_tries:
                    # Chain the cause so the original traceback is not lost.
                    raise RuntimeError(
                        f"Failed to load model '{args}' with following error(s): {str(e)}."
                    ) from e
                wait = tries * 30
                print(
                    f"Failed to load model: {e}. Trying again ({tries}/{total_allowed_tries}) after {wait}s"
                )
                time.sleep(wait)

    return wrapper
1443

1444

1445
class OnnxModel(abc.ABC):
    """Abstract base for an ONNX export of a benchmark model plus its ORT session.

    Subclasses export the model to ``self.model_path``, create
    ``self.onnx_session`` and implement ``format_pt_inputs`` /
    ``format_pt_outputs``. Those two methods flatten PyTorch inputs/outputs
    into tensor sequences that are matched positionally against the session's
    inputs and outputs.
    """

    # torch dtype -> numpy element type passed to ORT when binding raw buffers.
    TORCH_TO_NUMPY_DTYPE = {
        torch.float16: np.float16,
        torch.float32: np.float32,
        torch.float64: np.float64,
        torch.uint8: np.uint8,
        torch.int8: np.int8,
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.longlong,
        torch.bool: np.bool_,
    }

    # Set by subclasses; used in the export directory and ONNX file names.
    _COMPILER_NAME: str

    def __init__(
        self,
        output_directory,
        model,
        example_inputs,
        dynamic_shapes: bool,
        copy_before_export: bool = False,
        use_experimental_patch: bool = False,
    ):
        """The abstract class for exporting ONNX model.

        Args:
            output_directory: output path
            model: model
            example_inputs: example inputs for exporting
            dynamic_shapes (bool): Whether to export the model with dynamic shapes.
            copy_before_export (bool,): copy before export. Defaults to False.
            use_experimental_patch (bool): Whether to apply torch_onnx patch which exports
                with torch.export and onnx ir. Defaults to False.
        """
        model_name = current_name
        self.copy_before_export = copy_before_export
        self.use_experimental_patch = use_experimental_patch
        # NOTE: torch_onnx patch is using OnnxModelFromTorchScript to export ONNX model.
        if self.use_experimental_patch:
            self._COMPILER_NAME = "torch_onnx_patch"
        self.model_dir = self._generate_onnx_model_directory(
            output_directory, self._COMPILER_NAME, model_name
        )
        self.model_path = str(
            self.model_dir / f"{model_name}_{self._COMPILER_NAME}.onnx"
        )

    def _determine_deepcopy_target_device(self):
        """Device for the pre-export deep copy.

        Returns "cpu" on CPU runs, otherwise "cuda:1" when more than one CUDA
        device is visible and "cuda" when there is only one.
        """
        if current_device == "cpu":
            target_device = "cpu"
        else:
            if torch.cuda.device_count() > 1:
                # Copy to another cuda device to avoid OOM.
                target_device = "cuda:1"
            else:
                target_device = "cuda"
        return target_device

    def deepcopy_model_and_inputs_to_device(self, model, example_inputs, target_device):
        """Return deep copies of ``model`` and of the tensors in ``example_inputs`` on ``target_device``.

        Deepcopy the model before export to avoid modifying the baseline model.
        To avoid OOM, the original is parked on CPU while it is copied and is
        then moved back to the device of its first parameter. Models without
        parameters are copied directly, since there is no device to restore.
        """
        first_param = next(model.parameters(), None)
        if first_param is None:
            model_copy = copy.deepcopy(model).to(target_device)
        else:
            model_device = first_param.device
            model.to("cpu")
            model_copy = copy.deepcopy(model).to(target_device)
            model.to(model_device)

        target_device_example_inputs = tree_map_only(
            torch.Tensor, lambda x: x.to(device=target_device), example_inputs
        )

        return model_copy, target_device_example_inputs

    @classmethod
    def _generate_onnx_model_directory(
        cls, output_directory: str, compiler_name: str, model_name: str
    ) -> Path:
        """Return ``<output_directory>/.onnx_models/<model_name>/<compiler_name>``, recreated empty."""
        model_path = Path(
            output_directory,
            ".onnx_models",
            model_name,
            compiler_name,
        )
        if model_path.exists() and model_path.is_dir():
            shutil.rmtree(model_path)
        model_path.mkdir(parents=True, exist_ok=True)
        return model_path

    @abc.abstractmethod
    def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]:
        ...

    @abc.abstractmethod
    def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]:
        ...

    def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, np.ndarray]:
        """Map each ORT session input name to a CPU numpy copy of the matching pt input."""
        pt_inputs = self.format_pt_inputs(pt_inputs)
        return {
            ort_input.name: pt_input.cpu().numpy()
            for ort_input, pt_input in zip(self.onnx_session.get_inputs(), pt_inputs)
        }

    def adapt_onnx_outputs_to_pt(self, onnx_outputs: List[np.ndarray]) -> Any:
        """Convert ORT numpy outputs to tensors on ``current_device``; a single output is unwrapped."""
        pt_outputs = [
            torch.from_numpy(onnx_output).to(current_device)
            for onnx_output in onnx_outputs
        ]
        if len(pt_outputs) == 1:
            return pt_outputs[0]
        return pt_outputs

    def _init_ort_session(self, model_path: str):
        """Create an ORT InferenceSession for the ONNX file at ``model_path``.

        On CPU runs the CPU provider is used; otherwise the CUDA provider on
        device 1 when more than one CUDA device is visible, else device 0.
        """
        import onnxruntime

        if current_device == "cpu":
            ort_providers = ["CPUExecutionProvider"]
        else:
            # NOTE(bowbao): Reduce OOM by running ORT on another gpu.
            # TODO(bowbao): This works to avoid OOM, but performance is surprisingly very bad.
            cuda_provider_options = {
                "device_id": 1 if torch.cuda.device_count() > 1 else 0,
            }
            ort_providers = [("CUDAExecutionProvider", cuda_provider_options)]
        session_options = onnxruntime.SessionOptions()
        session_options.log_severity_level = 3  # Error

        ort_session = onnxruntime.InferenceSession(
            model_path,
            providers=ort_providers,
            sess_options=session_options,
        )
        return ort_session

    def is_cpu(self) -> bool:
        """True when the session's first (highest-priority) provider is the CPU provider."""
        return self.onnx_session.get_providers()[0] == "CPUExecutionProvider"

    def cpu(self) -> Self:
        """Switch the ORT session to the CPU provider and return ``self``."""
        self.onnx_session.set_providers(["CPUExecutionProvider"])
        return self

    def create_outputs(self, *example_outputs):
        """Allocate uninitialized tensors shaped like ``example_outputs``."""
        return tuple(torch.empty_like(x) for x in example_outputs)

    def create_iobinding(self, pt_inputs, example_outputs):
        """Bind inputs and freshly allocated output buffers to a new ORT I/O binding.

        Returns ``(iobinding, outputs)``, where ``outputs`` are exactly the
        tensors ORT writes into on ``run_with_iobinding``. With more than one
        CUDA device, ORT runs on "cuda:1", so inputs are copied there and the
        output buffers live there.

        ORT only records raw ``data_ptr()`` addresses, so every bound tensor
        must outlive the binding. The bound inputs may be temporary contiguous
        or device copies. They are kept on ``self`` until the next call to this
        method, so only the most recently created binding stays valid. The
        bound outputs are returned to, and kept alive by, the caller.
        """
        pt_inputs = self.format_pt_inputs(pt_inputs)
        example_outputs = self.format_pt_outputs(example_outputs)

        # NOTE: Run ORT on another cuda device to reduce OOM.
        ort_device = "cuda:1" if torch.cuda.device_count() > 1 else None

        def _place(tensor):
            if ort_device is not None:
                return tensor.detach().to(ort_device)
            return tensor

        iobinding = self.onnx_session.io_binding()

        bound_inputs = []
        for ort_input, arg in zip(self.onnx_session.get_inputs(), pt_inputs):
            arg = _place(arg.contiguous())
            bound_inputs.append(arg)
            device = arg.device
            iobinding.bind_input(
                ort_input.name,
                device.type,
                device.index or 0,
                self.TORCH_TO_NUMPY_DTYPE[arg.dtype],
                arg.size(),
                arg.data_ptr(),
            )

        # Return the tensors actually bound, not the pre-move allocations, so
        # callers see what ORT writes.
        outputs = tuple(_place(output) for output in self.create_outputs(*example_outputs))
        for ort_output, output in zip(self.onnx_session.get_outputs(), outputs):
            device = output.device
            iobinding.bind_output(
                ort_output.name,
                device.type,
                device.index or 0,
                self.TORCH_TO_NUMPY_DTYPE[output.dtype],
                output.size(),
                output.data_ptr(),
            )

        # Keep the bound input buffers alive while this binding may be used.
        self._iobinding_inputs = bound_inputs
        return iobinding, outputs

    def run_with_iobinding(self, iobinding, outputs):
        """Run the session through ``iobinding`` and return ``outputs``.

        ``outputs`` are the tensors bound to ``iobinding`` by ``create_iobinding``.
        """
        self.onnx_session.run_with_iobinding(iobinding)
        return outputs

    def run_with_onnx_inputs(self, onnx_inputs):
        """Run the session on a name -> numpy mapping and return all outputs."""
        return self.onnx_session.run(None, onnx_inputs)

    @classmethod
    def save_tensor_data(cls, numpy_tensor, output_path):
        """Serialize ``numpy_tensor`` as an ONNX TensorProto to ``output_path``."""
        from onnx import numpy_helper

        proto_tensor = numpy_helper.from_array(numpy_tensor)
        with open(output_path, "wb") as f:
            f.write(proto_tensor.SerializeToString())

    def run_and_serialize_inputs_outputs(self, pt_inputs):
        """Run on ``pt_inputs`` and dump inputs/outputs as .pb files under ``model_dir/test_data_set_0``."""
        test_data_dir = self.model_dir / "test_data_set_0"
        test_data_dir.mkdir(parents=True, exist_ok=True)

        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
        for i, onnx_input in enumerate(onnx_inputs.values()):
            self.save_tensor_data(onnx_input, str(test_data_dir / f"input_{i}.pb"))

        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)

        for i, onnx_output in enumerate(onnx_outputs):
            self.save_tensor_data(onnx_output, str(test_data_dir / f"output_{i}.pb"))

        return self.adapt_onnx_outputs_to_pt(onnx_outputs)

    def run(self, pt_inputs):
        """Run on PyTorch inputs and return PyTorch outputs (host round trip included).

        NOTE: For CUDA performance testing, use `run_with_iobinding` to exclude
        the memory copying overhead for inputs/outputs between cpu and gpu.
        Otherwise the perf number is inaccurate.
        """
        onnx_inputs = self.adapt_pt_inputs_to_onnx(pt_inputs)
        onnx_outputs = self.run_with_onnx_inputs(onnx_inputs)
        return self.adapt_onnx_outputs_to_pt(onnx_outputs)
1662

1663

1664
class OnnxModelFromTorchScript(OnnxModel):
    """TorchScript based onnx export. `torch.onnx.export`

    TODO(bowbao):
    * large model export failed.
          Onnx Model is larger than 2GB, but exporter makes decision based pt model size, which is
          smaller than 2GB.
    * OOM on slightly larger model.
          Both pt model and ort inference session are on gpu. Attempt has been made to move ORT to
          cuda:1, however ORT perf drop significantly.
          For now running everything with batch_size 1 set in launch script.
    """

    _COMPILER_NAME = "torchscript"

    def __init__(
        self, output_directory, model, example_inputs, dynamic_shapes: bool, **kwargs
    ):
        """Export ``model`` with ``torch.onnx.export`` (opset 17) and open an ORT session on it.

        Raises NotImplementedError when ``dynamic_shapes`` is requested.
        """
        if dynamic_shapes:
            raise NotImplementedError("NYI dynamic shapes for OnnxModelFromTorchScript")
        super().__init__(
            output_directory, model, example_inputs, dynamic_shapes, **kwargs
        )
        export_kwargs = dict(opset_version=17, do_constant_folding=False, verbose=False)
        self._export(model, example_inputs, self.model_path, **export_kwargs)
        self.onnx_session = self._init_ort_session(self.model_path)

    def _export(self, model, example_inputs, output_path: str, /, **kwargs) -> None:
        """Run ``torch.onnx.export`` on ``model`` and write the result to ``output_path``.

        Optionally deep-copies the model and inputs first. Dict inputs are
        adapted through a positional wrapper module. The torch_onnx patch is
        applied when ``use_experimental_patch`` is set; otherwise, if torch_onnx
        is importable, it is unpatched.
        """
        if self.copy_before_export:
            # Deepcopy model before export to avoid modification to baseline model.
            copy_device = self._determine_deepcopy_target_device()
            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
                model, example_inputs, copy_device
            )

        # Hack for huggingface models (kwargs only).
        if isinstance(example_inputs, dict):

            class WrapperModel(torch.nn.Module):
                def __init__(self, model, keys):
                    super().__init__()
                    self.model = model
                    self.keys = keys

                def forward(self, *args):
                    return self.model(**dict(zip(self.keys, args)))

            model = WrapperModel(model, list(example_inputs.keys()))

        if not self.use_experimental_patch:
            # make sure the patch is not in effect
            with contextlib.suppress(ImportError):
                import torch_onnx

                torch_onnx.unpatch_torch()
        else:
            import torch_onnx

            torch_onnx.patch_torch(
                error_report=True,
                profile=True,
                dump_exported_program=True,
                artifacts_dir=os.path.dirname(output_path),
            )

        formatted_inputs = self.format_pt_inputs(example_inputs)
        torch.onnx.export(model, formatted_inputs, output_path, **kwargs)

    def format_pt_inputs(self, pt_inputs):
        """Flatten benchmark inputs into a tuple of contiguous tensors.

        NOTE(bowbao): huggingface inputs are dicts consumed as
        ``model(**pt_inputs)``, so their values are used in dict order. Other
        benchmarks pass a tuple consumed as ``model(*pt_inputs)``. A bare tensor
        is treated as a one-element tuple.
        """
        if isinstance(pt_inputs, dict):
            items = pt_inputs.values()
        elif isinstance(pt_inputs, torch.Tensor):
            items = (pt_inputs,)
        else:
            items = pt_inputs
        return tuple(item.contiguous() for item in items)

    def format_pt_outputs(self, pt_outputs):
        """Flatten model outputs into a list of pytree leaves.

        huggingface ``ModelOutput`` containers, when ``transformers`` is
        importable, are expanded via ``to_tuple()`` first.
        """
        wrapped = (pt_outputs,) if isinstance(pt_outputs, torch.Tensor) else pt_outputs
        leaves = pytree.tree_leaves(wrapped)

        # Hack for huggingface model outputs
        try:
            from transformers import modeling_outputs
        except ImportError:
            return leaves

        def _expand_model_output(node):
            if isinstance(node, modeling_outputs.ModelOutput):
                return node.to_tuple()
            return node

        return pytree.tree_leaves(pytree.tree_map(_expand_model_output, leaves))
1776

1777

1778
class OnnxModelFromDynamo(OnnxModel):
    """ONNX model produced by the Dynamo/FX exporter (`torch.onnx.dynamo_export`).

    On construction the model is exported to ``self.model_path`` and an ORT
    session is created from that file. The resulting ``ONNXProgram`` is kept
    only for its torch<->ONNX input/output adapters.
    """

    _COMPILER_NAME = "dynamo"

    def __init__(
        self, output_directory, model, example_inputs, dynamic_shapes: bool, **kwargs
    ):
        super().__init__(
            output_directory, model, example_inputs, dynamic_shapes, **kwargs
        )
        self._dynamic_shapes = dynamic_shapes
        program = self._export(model, example_inputs, self.model_path)
        self._onnx_program = program
        # The proto has already been written to disk by `_export`; drop the
        # in-memory copy to save memory. The program object stays for the adapters.
        program.model_proto.Clear()
        self.onnx_session = self._init_ort_session(self.model_path)

    def _export(
        self, model, example_inputs, output_path: str
    ) -> torch.onnx.ONNXProgram:
        """Export ``model`` with ``torch.onnx.dynamo_export`` and save it to ``output_path``."""
        if self.copy_before_export:
            # Export a copy so the baseline (eager) model is left untouched.
            target_device = self._determine_deepcopy_target_device()
            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
                model, example_inputs, target_device
            )

        pos_args, kw_args = _normalize_bench_inputs(example_inputs)
        export_options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
        program = torch.onnx.dynamo_export(
            model, *pos_args, **kw_args, export_options=export_options
        )
        program.save(output_path)
        return program

    def format_pt_inputs(self, pt_inputs):
        """Convert benchmark inputs into the flat ONNX input list via the program's adapter."""
        pos_args, kw_args = _normalize_bench_inputs(pt_inputs)
        return self._onnx_program.adapt_torch_inputs_to_onnx(*pos_args, **kw_args)

    def format_pt_outputs(self, pt_outputs):
        """Convert PyTorch outputs into the ONNX output layout via the program's adapter."""
        return self._onnx_program.adapt_torch_outputs_to_onnx(pt_outputs)
1821

1822

1823
class OnnxModelFromDynamoAotInline(OnnxModelFromDynamo):
    """Dynamo/FX export (`torch.onnx.dynamo_export`) followed by ahead-of-time
    inlining of local functions with ``onnx.inliner``."""

    _COMPILER_NAME = "dynamo_aot_inline"

    def _export(
        self, model, example_inputs, output_path: str
    ) -> torch.onnx.ONNXProgram:
        """Export, then rewrite the saved file with local ONNX functions inlined."""
        if self.copy_before_export:
            # Export a copy so the baseline (eager) model is left untouched.
            target_device = self._determine_deepcopy_target_device()
            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
                model, example_inputs, target_device
            )

        pos_args, kw_args = _normalize_bench_inputs(example_inputs)
        export_options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
        onnx_program = torch.onnx.dynamo_export(
            model, *pos_args, **kw_args, export_options=export_options
        )

        # AOT inlining requires onnx >= 1.15.
        import onnx
        import onnx.inliner

        # Workaround: the inliner does not support models larger than 2GB.
        # Write the weights out as external data, reload only the graph
        # (external data not loaded), inline it, and save the graph back.
        onnx.save_model(
            onnx_program.model_proto, output_path, save_as_external_data=True
        )
        graph_only = onnx.load(output_path, load_external_data=False)
        inlined = onnx.inliner.inline_local_functions(graph_only)
        onnx.save_model(inlined, output_path)
        return onnx_program
1856

1857

1858
class OnnxModelFromDynamoAotOptimize(OnnxModelFromDynamo):
    """Dynamo/FX export (`torch.onnx.dynamo_export`) followed by an ahead-of-time
    ``onnxscript`` onnxruntime rewrite of the model proto."""

    _COMPILER_NAME = "dynamo_aot_optimize"

    def _export(
        self, model, example_inputs, output_path: str
    ) -> torch.onnx.ONNXProgram:
        """Export, rewrite the proto for onnxruntime, and save it with external data."""
        if self.copy_before_export:
            # Export a copy so the baseline (eager) model is left untouched.
            target_device = self._determine_deepcopy_target_device()
            model, example_inputs = self.deepcopy_model_and_inputs_to_device(
                model, example_inputs, target_device
            )

        pos_args, kw_args = _normalize_bench_inputs(example_inputs)
        export_options = torch.onnx.ExportOptions(dynamic_shapes=self._dynamic_shapes)
        onnx_program = torch.onnx.dynamo_export(
            model, *pos_args, **kw_args, export_options=export_options
        )

        import onnx
        from onnxscript.rewriter.onnxruntime import rewrite

        rewritten_proto = rewrite(onnx_program.model_proto)
        # All weights go to a single external-data file next to `output_path`.
        onnx.save_model(
            rewritten_proto,
            output_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
        )
        return onnx_program
1890

1891

1892
class _OnnxPatch:
1893
    @classmethod
1894
    def patch_non_tensor_outputs(cls, correct_result, new_result, fp64_outputs):
1895
        """Patch non-tensor outputs to make them comparable with the correct result.
1896

1897
        ONNX model always returns a flat tuple of tensors, but the PyTorch model outputs
1898
        `correct_result` and `fp64_outputs` can be arbitrary types. This function normalizes
1899
        the outputs to make them comparable with the ONNX model output.
1900
        """
1901
        try:
1902
            from transformers import modeling_outputs
1903
        except ImportError:
1904
            has_transformers = False
1905
        else:
1906
            has_transformers = True
1907

1908
        if has_transformers and isinstance(
1909
            correct_result, modeling_outputs.ModelOutput
1910
        ):
1911
            correct_result = correct_result.to_tuple()
1912
            fp64_outputs = fp64_outputs.to_tuple() if fp64_outputs is not None else None
1913
        elif type(correct_result).__name__ in (
1914
            "MaskedLMOutput",
1915
            "Seq2SeqLMOutput",
1916
            "CausalLMOutputWithCrossAttentions",
1917
            "LongformerMaskedLMOutput",
1918
            "Instances",
1919
            "SquashedNormal",
1920
            "Boxes",
1921
            "Normal",
1922
            "TanhTransform",
1923
            "Foo",
1924
            "Variable",
1925
        ):
1926
            # Copied from `same` function in `torch._dynamo.utils`
1927
            correct_result = [
1928
                value
1929
                for key in correct_result.__dict__.keys()
1930
                if (value := getattr(correct_result, key)) is not None
1931
            ]
1932
            fp64_outputs = (
1933
                [
1934
                    value
1935
                    for key in fp64_outputs.__dict__.keys()
1936
                    if (value := getattr(fp64_outputs, key)) is not None
1937
                ]
1938
                if fp64_outputs is not None
1939
                else None
1940
            )
1941

1942
        # Flatten nested tuple of tensors, i.e. past_key_values
1943
        correct_result = pytree.tree_leaves(correct_result)
1944
        # Hack to put results from different runs on same device.
1945
        # This is needed for ONNX CPU fallback benchmark, where PyTorch eager is run on GPU.
1946
        # Assuming outputs from a single run are always on same device!
1947
        devices = [x.device for x in correct_result if isinstance(x, torch.Tensor)]
1948
        assert devices and all(
1949
            x == devices[0] for x in devices
1950
        ), "All tensors must be on same device!"
1951
        device = devices[0]
1952
        new_result = pytree.tree_leaves(new_result)
1953
        new_result = pytree.tree_map(
1954
            lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
1955
            new_result,
1956
        )
1957
        fp64_outputs = pytree.tree_leaves(fp64_outputs)
1958

1959
        return correct_result, new_result, fp64_outputs
1960

1961

1962
@dataclasses.dataclass
class OnnxExportErrorRow:
    """One row of the ONNX export-error CSV.

    A row describes either a diagnostic (``rule_id``, ``rule_name``,
    ``diagnostic_level`` and ``diagnostic_message`` all set) or a raw exception
    (``exception_type_name`` set; ``exception_message`` optional).
    """

    device: str
    model_name: str
    batch_size: int
    rule_id: Optional[str] = None
    rule_name: Optional[str] = None
    diagnostic_level: Optional[str] = None
    diagnostic_message: Optional[str] = None
    exception_type_name: Optional[str] = None
    exception_message: Optional[str] = None

    def __post_init__(self):
        # Reject rows that carry neither a full diagnostic nor an exception type.
        assert (
            self.rule_id is not None
            and self.rule_name is not None
            and self.diagnostic_level is not None
            and self.diagnostic_message is not None
        ) or self.exception_type_name, (
            "Either rule_id, rule_name, diagnostic_level and diagnostic_message "
            "must be set or exception_type_name must be set"
        )

    @property
    def headers(self) -> List[str]:
        """CSV column names, in field declaration order."""
        return [field.name for field in dataclasses.fields(self)]

    @property
    def row(self) -> List[Any]:
        """CSV values aligned with ``headers``.

        Not all values are strings: ``batch_size`` is an int and unset optional
        fields are ``None``, so the return type is ``List[Any]``.
        """
        return [getattr(self, field.name) for field in dataclasses.fields(self)]
1992

1993

1994
class OnnxExportErrorParser:
    """Turn ONNX export diagnostics and exceptions into ``OnnxExportErrorRow``
    records tagged with this parser's device, model name and batch size."""

    def __init__(self, device: str, model_name: str, batch_size: int):
        self.device, self.model_name, self.batch_size = device, model_name, batch_size

    def _qualified_exception_class_name(self, exception: Exception) -> str:
        """Return the exception's class name, module-qualified unless it is a builtin."""
        exc_class = exception.__class__
        module_name = exc_class.__module__
        if module_name != "builtins":
            return f"{module_name}.{exc_class.__name__}"
        return exc_class.__name__

    def _tagged_row(self, **details) -> OnnxExportErrorRow:
        """Build a row pre-filled with this parser's device/model/batch-size tags."""
        return OnnxExportErrorRow(
            device=self.device,
            model_name=self.model_name,
            batch_size=self.batch_size,
            **details,
        )

    def parse_diagnostic_context(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> Generator[OnnxExportErrorRow, Any, Any]:
        """Yield one row per diagnostic whose level is ERROR or higher."""
        from torch.onnx._internal.fx import diagnostics

        for diag in diagnostic_context.diagnostics:
            if not diag.level >= diagnostics.levels.ERROR:
                continue
            yield self._tagged_row(
                rule_id=diag.rule.id,
                rule_name=diag.rule.name,
                diagnostic_level=diag.level.name,
                diagnostic_message=diag.message,
            )

    def parse_exception(self, exception: Exception) -> OnnxExportErrorRow:
        """Build a row describing ``exception`` (qualified type name and message)."""
        return self._tagged_row(
            exception_type_name=self._qualified_exception_class_name(exception),
            exception_message=str(exception),
        )
2031

2032

2033
@dataclasses.dataclass
class OnnxContext:
    """Mutable holder for the ONNX model that ``optimize_onnx_ctx`` exports lazily."""

    # Stays None until the first run exports and caches the model.
    onnx_model: Optional[OnnxModel] = dataclasses.field(default=None)
2036

2037

2038
def optimize_onnx_ctx(
    output_directory: str,
    onnx_model_cls: Type[OnnxModel],
    run_n_iterations: Callable,
    dynamic_shapes: bool = False,
    copy_before_export: bool = False,
    use_experimental_patch: bool = False,
) -> Callable:
    """Build the ONNX version of ``run_n_iterations``.

    NOTE(bowbao): the returned callable
      1. exports ``model`` with ``onnx_model_cls`` on first use and caches it,
      2. runs the exported model ``n`` times; the very first run also
         serializes inputs/outputs to .pb files for offline analysis, so this
         callable must not be used for perf measurement,
      3. on any export or ORT failure, appends rows describing the error to
         ``<output_filename without .csv>_export_error.csv`` (to be summarized
         by '._onnx/reporter.py') and re-raises.

    The cached model is exposed as ``<returned callable>.context``.
    ``run_n_iterations`` is accepted for signature compatibility; it is not
    used here.
    """
    context = OnnxContext()
    test_data_dumped = False

    def run_n_iterations_onnx(model, inputs, n=2):
        """Run the (lazily exported) ONNX model ``n`` times and return the last outputs.

        Returns None when ``n`` is 0.
        """
        nonlocal test_data_dumped
        from torch.onnx._internal import _exporter_legacy
        from torch.onnx._internal.fx import diagnostics

        # TODO: Accuracy mismatch is not reported here in csv.
        # `endswith` (rather than `find(".csv") > 0`) guarantees that the
        # `[:-4]` slice below removes exactly the ".csv" suffix.
        assert output_filename.endswith(
            ".csv"
        ), f"expected output_filename to be a .csv, but got {output_filename}"
        output_error_filename = output_filename[:-4] + "_export_error.csv"
        parser = OnnxExportErrorParser(current_device, current_name, current_batch_size)
        try:
            if context.onnx_model is None:
                context.onnx_model = onnx_model_cls(
                    output_directory,
                    model,
                    copy.deepcopy(inputs),
                    dynamic_shapes=dynamic_shapes,
                    copy_before_export=copy_before_export,
                    use_experimental_patch=use_experimental_patch,
                )
            onnx_model = context.onnx_model

            # Defined up front so n == 0 returns None instead of hitting an
            # unbound local.
            outputs = None
            for _ in range(n):
                if not test_data_dumped:
                    # Serializes inputs and outputs to .pb files for further offline analysis.
                    outputs = onnx_model.run_and_serialize_inputs_outputs(inputs)
                    test_data_dumped = True
                else:
                    outputs = onnx_model.run(inputs)
            return outputs
        except _exporter_legacy.OnnxExporterError as e:
            # `torch.onnx.dynamo_export` raises an error that encloses diagnostics.
            diagnostic_context = e.onnx_program.diagnostic_context
            for parsed_error in parser.parse_diagnostic_context(diagnostic_context):
                output_csv(
                    output_error_filename, parsed_error.headers, parsed_error.row
                )
            if context.onnx_model is not None:
                e.onnx_program.save_diagnostics(
                    f"{context.onnx_model.model_dir}/"
                    f"{current_onnx_compiler}_{current_name}_{current_device}.sarif"
                )

            # Also record the raw exception that caused the export failure,
            # unless there is none or it was already covered by diagnostics.
            cause_of_exception = e.__cause__
            if cause_of_exception is not None and not isinstance(
                cause_of_exception, diagnostics.RuntimeErrorWithDiagnostic
            ):
                parsed_error = parser.parse_exception(cause_of_exception)
                output_csv(
                    output_error_filename, parsed_error.headers, parsed_error.row
                )
            raise
        except Exception as e:
            # `torch.onnx.export` errors and ORT errors.
            parsed_error = parser.parse_exception(e)
            output_csv(output_error_filename, parsed_error.headers, parsed_error.row)
            raise

    run_n_iterations_onnx.context = context

    return run_n_iterations_onnx
2124

2125

2126
def read_batch_size_from_file(args, filename, model_name):
    """Look up ``model_name``'s batch size in a ``name,batch_size`` text file.

    If a ``benchmarks`` directory exists in the current working directory,
    ``filename`` is resolved relative to it. Blank lines are ignored. If the
    model appears more than once, the last entry wins.

    Args:
        args: parsed CLI arguments. Kept for signature compatibility; not read.
        filename: path of the batch-size file.
        model_name: model to look up (compared against the first column).

    Returns:
        The batch size as an int, or None (with a warning) if the model is
        not listed.

    Raises:
        AssertionError: if the resolved file does not exist.
        RuntimeError: if the model's batch size is listed as -1 (unset).
    """
    batch_size = None
    if os.path.exists("benchmarks"):
        filename = os.path.join("benchmarks", filename)
    assert os.path.exists(filename), filename
    with open(filename) as f:
        for line in f:
            if not line.strip():
                continue
            cur_name, b = line.split(",")
            if cur_name == model_name:
                batch_size = int(b)
    if batch_size is None:
        log.warning("Could not find batch size for %s", model_name)
    elif batch_size == -1:
        # Report the file actually read; `args.batch_size_file` may be a
        # different file (or None) when callers pass their own `filename`.
        raise RuntimeError(f"Batch size is unset for {model_name} in {filename}")
    print(f"batch size: {batch_size}")
    return batch_size
2146

2147

2148
class TimeOutException(Exception):
    """Raised by ``alarm_handler`` when a call wrapped with ``exit_after`` times out."""
2150

2151

2152
def alarm_handler(signum, frame):
    """Signal handler that converts the delivered signal into ``TimeOutException``.

    ``signum`` and ``frame`` follow the ``signal.signal`` handler signature and
    are not used.
    """
    raise TimeOutException()
2154

2155

2156
def exit_after(s):
    """
    Decorator that raises ``TimeOutException`` if the wrapped fn takes more
    than ``s`` seconds to run.

    Implemented with ``SIGALRM`` / ``signal.alarm``, so ``s`` is a whole number
    of seconds. When the call finishes (normally or by raising), the alarm is
    cancelled and the previously installed SIGALRM handler is restored, so the
    decorator no longer leaves its own handler installed process-wide.
    """

    def outer(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            previous_handler = signal.signal(signal.SIGALRM, alarm_handler)
            signal.alarm(s)
            try:
                return fn(*args, **kwargs)
            finally:
                signal.alarm(0)
                # `signal.signal` returns None when the previous handler was not
                # installed from Python; such a handler cannot be reinstated.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        return inner

    return outer
2175

2176

2177
def get_peak_memory():
    """Return ``torch.cuda.max_memory_allocated()`` converted to GB (1e9 bytes)."""
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / 1e9
2179

2180

2181
def null_experiment(args, model_iter_fn, model, example_inputs):
    """
    A no-op experiment, useful for making sure TorchBenchmark alone works properly.

    All arguments are ignored; an empty result list is returned.
    """
    results = []
    return results
2187

2188

2189
def cast_to(dtype, model, inputs):
    """Cast ``model`` and the floating-point tensors in ``inputs`` to ``dtype``.

    The model goes through ``.half()`` for float16 and ``.to(dtype)`` otherwise.
    ``inputs`` may be any pytree; non-tensor leaves and non-floating-point
    tensors are returned unchanged.

    Returns:
        ``(model, inputs)`` after casting.
    """
    model = model.half() if dtype == torch.float16 else model.to(dtype)

    def _cast_leaf(leaf):
        if isinstance(leaf, torch.Tensor) and leaf.is_floating_point():
            return leaf.to(dtype)
        return leaf

    return model, tree_map(_cast_leaf, inputs)
2203

2204

2205
def cast_to_bf16(model, inputs):
    """Shorthand for ``cast_to(torch.bfloat16, model, inputs)``."""
    return cast_to(torch.bfloat16, model=model, inputs=inputs)
2207

2208

2209
def cast_to_fp16(model, inputs):
    """Shorthand for ``cast_to(torch.float16, model, inputs)``."""
    return cast_to(torch.float16, model=model, inputs=inputs)
2211

2212

2213
def cast_to_fp64(model, inputs):
    """Shorthand for ``cast_to(torch.float64, model, inputs)``."""
    return cast_to(torch.float64, model=model, inputs=inputs)
2215

2216

2217
def cast_to_fp32(model, inputs):
    """Shorthand for ``cast_to(torch.float32, model, inputs)``."""
    return cast_to(torch.float32, model=model, inputs=inputs)
2219

2220

2221
class DummyGradScaler:
    """No-op gradient scaler: ``scale`` hands the loss back unchanged."""

    def scale(self, loss):
        """Return ``loss`` as-is (no scaling applied)."""
        unscaled_loss = loss
        return unscaled_loss
2224

2225

2226
def get_dynamo_stats():
    """Snapshot selected dynamo/inductor counters as a ``collections.Counter``.

    Returned keys: calls_captured, unique_graphs, graph_breaks (total count),
    unique_graph_breaks (distinct reasons with a positive count),
    autograd_captures, autograd_compiles, cudagraph_skips.
    """
    # TODO: consider deepcopy'ing the entire counters struct and
    # adding a helper to do subtraction on it
    counters = torch._dynamo.utils.counters
    run_stats = counters["stats"]
    graph_breaks = counters["graph_break"]
    autograd = counters["compiled_autograd"]
    snapshot = {
        "calls_captured": run_stats["calls_captured"],
        "unique_graphs": run_stats["unique_graphs"],
        "graph_breaks": sum(graph_breaks.values()),
        # NB: unary plus drops zero (and negative) counts.
        "unique_graph_breaks": len(+graph_breaks),
        "autograd_captures": autograd["captures"],
        "autograd_compiles": autograd["compiles"],
        "cudagraph_skips": counters["inductor"]["cudagraph_skips"],
    }
    return collections.Counter(snapshot)
2247

2248

2249
@contextmanager
def maybe_init_distributed(should_init_distributed, rank, world_size, port="6789"):
    """Optionally set up an NCCL process group for the duration of the block.

    When ``should_init_distributed`` is true: selects CUDA device ``rank``,
    points MASTER_ADDR/MASTER_PORT at ``localhost:<port>``, initializes the
    "nccl" process group, and destroys it on exit. Otherwise this is a no-op.

    Setup happens before the ``try`` so that a failed initialization propagates
    its own error. Previously the ``finally`` would call
    ``destroy_process_group()`` on a group that was never created, raising a
    second error that masked the real one.
    """
    if should_init_distributed:
        torch.cuda.set_device(rank)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = port
        torch.distributed.init_process_group(
            "nccl", rank=rank, world_size=world_size
        )
    try:
        yield
    finally:
        if should_init_distributed:
            torch.distributed.destroy_process_group()
2263

2264

2265
@contextmanager
def maybe_snapshot_memory(should_snapshot_memory, suffix):
    """Optionally record CUDA allocator history for the duration of the block.

    Enables the Memory Snapshot tool for memory deep dives:
    https://pytorch.org/blog/understanding-gpu-memory-1/

    When enabled, history is recorded (max 100000 entries). On exit a snapshot
    is dumped to ``<base_dir>/<output_filename minus ".csv">_<suffix>.pickle``,
    and recording is then stopped. A failure to write the snapshot is logged
    rather than raised.
    """
    try:
        if should_snapshot_memory:
            torch.cuda.memory._record_memory_history(max_entries=100000)
        yield
    finally:
        if should_snapshot_memory:
            try:
                # `str.rstrip(".csv")` strips any trailing '.', 'c', 's', 'v'
                # characters (e.g. "res_vs.csv" -> "res_"), so remove the
                # literal suffix instead.
                if output_filename.endswith(".csv"):
                    stem = output_filename[: -len(".csv")]
                else:
                    stem = output_filename
                torch.cuda.memory._dump_snapshot(
                    os.path.join(
                        torch._dynamo.config.base_dir,
                        f"{stem}_{suffix}.pickle",
                    )
                )
            except Exception as e:
                logging.error("Failed to save memory snapshot, %s", e)

            torch.cuda.memory._record_memory_history(enabled=None)
2286

2287

2288
class BenchmarkRunner:
2289
    def __init__(self):
        # Set later by the driver / subclasses.
        self.model_iter_fn = None
        self.optimizer = None
        self._args = None
        # Defaults: no gradient scaling and no autocast; `setup_amp` may
        # replace `autocast` and fill in `autocast_arg`.
        self.grad_scaler = DummyGradScaler()
        self.autocast = contextlib.nullcontext
        self.autocast_arg = {}
2296

2297
    def setup_amp(self, current_device=None):
2298
        if self.args.only in self.fp32_only_models:
2299
            return
2300

2301
        devices = [current_device] if current_device else self.args.devices
2302
        if self.args.amp:
2303
            # AMP training can lead to small loss values which can undeflow
2304
            # gradient values returning in zero gradients. To solve this
2305
            # problem, PyTorch introduces GradScaler. GradScaler is a stateful
2306
            # structure, that scales the loss values to prevent underflow. Loss
2307
            # values are big at the beginning of training (therefore not
2308
            # requiring scaling), while loss value tends to be small as network
2309
            # starts getting better (requiring scaling). GradScaler manages all
2310
            # of this fine tuning, checking the gradients are turning to inf,
2311
            # discarding such batches.
2312

2313
            # Since we are not running a long iteration, default value of
2314
            # init_scale 65536 is going to turn all gradients to inf. Therefore,
2315
            # we just use a init_scale of 2.0 for benchmarking purpose.
2316

2317
            # Disabling Gradscaler because
2318
            #  1) Benchmark setup runs 2 iterations of fwd-bwd. So, not useful.
2319
            #  2) Current setup shares grad_scaler for eager and dynamo model,
2320
            #  which is bad as Gradscaler has state and can adjust the scaling
2321
            #  factor between eager and dynamo run, making accuracy check
2322
            #  harder.
2323
            # self.grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
2324
            self.autocast = functools.partial(
2325
                torch.amp.autocast, device_type=devices[0]
2326
            )
2327
            if self.args.amp_dtype:
2328
                amp_dtype = (
2329
                    torch.float16
2330
                    if self.args.amp_dtype == "float16"
2331
                    else torch.bfloat16
2332
                )
2333
                self.autocast_arg["dtype"] = amp_dtype
2334

2335
    def init_optimizer(self, name, device, params):
2336
        if device == "cuda" and self.args.training and name not in CI_SKIP_OPTIMIZER:
2337
            if (name in CI_USE_SGD and self.args.ci) or name in BENCHMARK_USE_SGD:
2338
                self.optimizer = torch.optim.SGD(params, lr=0.01, foreach=True)
2339
                # Disable multi_tensor_sgd for benchmarking, there isn't a large performance benefit (~1%) to compiling
2340
                # this optimizer because it is a single foreach add, and increases compile time.
2341
                # After autotuning and fake tensor caching lands, we can enable, becuase the compile time impact will be lower.
2342
                # Fake Tensor caching: https://github.com/pytorch/pytorch/pull/113873
2343
                # Autotuning: https://github.com/pytorch/pytorch/issues/117447
2344
                self.optimizer.step = torch._dynamo.disable(self.optimizer.step)
2345
            else:
2346
                self.optimizer = torch.optim.Adam(
2347
                    params, lr=0.01, capturable=True, foreach=True
2348
                )
2349
        else:
2350
            self.optimizer = None
2351

2352
    @property
2353
    def args(self):
2354
        return self._args
2355

2356
    @args.setter
2357
    def args(self, args):
2358
        self._args = args
2359

2360
    @property
2361
    def skip_models(self):
2362
        return set()
2363

2364
    @property
2365
    def skip_models_for_cuda(self):
2366
        return set()
2367

2368
    @property
2369
    def skip_models_for_cpu(self):
2370
        return set()
2371

2372
    @property
2373
    def skip_models_for_freezing(self):
2374
        return set()
2375

2376
    @property
2377
    def slow_models(self):
2378
        return set()
2379

2380
    @property
2381
    def very_slow_models(self):
2382
        return set()
2383

2384
    @property
2385
    def non_deterministic_models(self):
2386
        return set()
2387

2388
    @property
2389
    def fp32_only_models(self):
2390
        return set()
2391

2392
    @property
2393
    def force_amp_for_fp16_bf16_models(self):
2394
        return set()
2395

2396
    @property
2397
    def force_fp16_for_bf16_models(self):
2398
        return set()
2399

2400
    @property
2401
    def skip_not_suitable_for_training_models(self):
2402
        return set()
2403

2404
    @property
2405
    def failing_torchinductor_models(self):
2406
        return set()
2407

2408
    @property
2409
    def failing_fx2trt_models(self):
2410
        return set()
2411

2412
    @property
2413
    def skip_accuracy_checks_large_models_dashboard(self):
2414
        return set()
2415

2416
    @property
2417
    def skip_accuracy_check_as_eager_non_deterministic(self):
2418
        return set()
2419

2420
    @property
2421
    def skip_multiprocess_models(self):
2422
        return set()
2423

2424
    @property
2425
    def skip_models_due_to_control_flow(self):
2426
        return set()
2427

2428
    @property
2429
    def guard_on_nn_module_models(self):
2430
        return set()
2431

2432
    @property
2433
    def inline_inbuilt_nn_modules_models(self):
2434
        return set()
2435

2436
    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
2437
        raise NotImplementedError
2438

2439
    @property
2440
    def equal_nan(self):
2441
        equal_nan = True
2442
        if self.args.float32:
2443
            equal_nan = False
2444
        return equal_nan
2445

2446
    def use_larger_multiplier_for_smaller_tensor(self, name):
2447
        return False
2448

2449
    def iter_models(self, args):
2450
        for model_name in self.iter_model_names(args):
2451
            for device in args.devices:
2452
                try:
2453
                    yield self.load_model(
2454
                        device,
2455
                        model_name,
2456
                        batch_size=args.batch_size,
2457
                    )
2458
                except NotImplementedError:
2459
                    continue  # bad benchmark implementation
2460

2461
    def deepcopy_model(self, model):
2462
        return copy.deepcopy(model)
2463

2464
    def cast_based_on_args(self, model, example_inputs):
2465
        if self.args.float32 or self.args.only in self.fp32_only_models:
2466
            if not self.args.float32:
2467
                log.warning("Model %s supports float32 only", self.args.only)
2468
            model, example_inputs = cast_to_fp32(model, example_inputs)
2469
        elif self.args.float16:
2470
            if self.args.only in self.force_amp_for_fp16_bf16_models:
2471
                log.warning(
2472
                    "Model %s does not support float16, running with amp instead",
2473
                    self.args.only,
2474
                )
2475
                self.args.amp = True
2476
                self.setup_amp()
2477
            else:
2478
                model, example_inputs = cast_to_fp16(model, example_inputs)
2479
        elif self.args.bfloat16:
2480
            if self.args.only in self.force_amp_for_fp16_bf16_models:
2481
                log.warning(
2482
                    "Model %s does not support bfloat16, running with amp instead",
2483
                    self.args.only,
2484
                )
2485
                self.args.amp = True
2486
                self.setup_amp()
2487
            elif self.args.only in self.force_fp16_for_bf16_models:
2488
                log.warning(
2489
                    "Model %s does not support bfloat16, running with float16 instead",
2490
                    self.args.only,
2491
                )
2492
                model, example_inputs = cast_to_fp16(model, example_inputs)
2493
            else:
2494
                model, example_inputs = cast_to_bf16(model, example_inputs)
2495

2496
        return model, example_inputs
2497

2498
    def validate_model(self, model, example_inputs):
2499
        """
2500
        Runs the eager model with example inputs to ensure that eager passes.
2501
        """
2502
        model = self.deepcopy_model(model)
2503
        example_inputs = clone_inputs(example_inputs)
2504
        model, example_inputs = self.cast_based_on_args(model, example_inputs)
2505
        try:
2506
            self.model_iter_fn(model, example_inputs)
2507
        except Exception as e:
2508
            raise RuntimeError("Eager run failed") from e
2509

2510
    def maybe_cast(self, model, example_inputs):
2511
        model, example_inputs = self.cast_based_on_args(model, example_inputs)
2512
        return model, example_inputs
2513

2514
    def decay_batch_exp(self, batch_size, factor=0.5, divisor=2):
2515
        out_batch_size = batch_size * factor
2516
        if out_batch_size > divisor:
2517
            out_batch_size = (out_batch_size + 1) // divisor * divisor
2518
        else:
2519
            out_batch_size = batch_size - 1
2520
        return max(0, int(out_batch_size))
2521

2522
    def batch_size_finder(self, device, model_name, initial_batch_size=1024):
        """Search downward from ``initial_batch_size`` for a batch size that runs.

        For each candidate, the model is loaded and one ``model_iter_fn`` call is
        made. The first candidate that completes is returned. A RuntimeError
        moves to the next size from ``decay_batch_exp``, except an error that
        mentions "channels_last", which aborts the search.

        Returns:
            The first working batch size, or 1 if none worked or the search
            was aborted.
        """
        batch_size = initial_batch_size
        while batch_size >= 1:
            # Drop references from the previous (failed) attempt so the cache
            # flush below can actually release their memory.
            model = example_inputs = None
            # Flush the cache of the device being probed. The original code
            # used the module-level `current_device` global, which need not
            # match this call's `device`.
            empty_gpu_cache(device)
            try:
                device, _, model, example_inputs, _ = self.load_model(
                    device,
                    model_name,
                    batch_size,
                )
                self.model_iter_fn(model, example_inputs)
                return batch_size
            except RuntimeError as e:
                if "channels_last" in str(e):
                    break
            batch_size = self.decay_batch_exp(batch_size)
        return 1
2540

2541
    def run_n_iterations(self, mod, inputs):
2542
        n = self.args.iterations
2543
        for _ in range(n - 1):
2544
            self.model_iter_fn(mod, inputs, collect_outputs=False)
2545
        return self.model_iter_fn(mod, inputs, collect_outputs=True)
2546

2547
    @torch._disable_dynamo(recursive=True)
2548
    def optimizer_zero_grad(self, mod):
2549
        if self.optimizer is not None:
2550
            self.optimizer.zero_grad(True)
2551
        else:
2552
            mod.zero_grad(True)
2553

2554
    def optimizer_step(self):
        """Step the optimizer; a no-op when no optimizer has been set."""
        if self.optimizer is None:
            return
        self.optimizer.step()
2557

2558
    def get_benchmark_indices(self, length):
        """Return the ``(start, end)`` slice of ``length`` items owned by this partition.

        The items are split into ``total_partitions`` chunks of
        ``length // total_partitions``. The last partition also takes any
        remainder, so its ``end`` is ``length``.
        """
        partition = self._args.partition_id
        n_partitions = self._args.total_partitions
        chunk = length // n_partitions
        start = partition * chunk
        if partition >= n_partitions - 1:
            # Last partition absorbs the leftover items.
            return start, length
        return start, (partition + 1) * chunk
2566

2567
    def get_fsdp_auto_wrap_policy(self, model_name: str):
        """Return the FSDP ``auto_wrap_policy`` to use for ``model_name``.

        Models listed in the handcrafted table below are wrapped at the
        granularity of one layer class via ``ModuleWrapPolicy``. Every other
        model (including ``model_name=None``) falls back to
        ``size_based_auto_wrap_policy`` with ``min_num_params=1e5``.

        The layer classes live in optional third-party packages (diffusers,
        transformers, torchbenchmark). They are imported lazily, and only for
        the requested model, so the size-based default works even when those
        packages are not installed. Previously every call imported all of them
        up front and failed with ImportError if any was missing.
        """
        from torch.distributed.fsdp.wrap import (
            ModuleWrapPolicy,
            size_based_auto_wrap_policy,
        )

        # handcrafted wrap policy: model name -> (module path, layer class name)
        MODEL_FSDP_WRAP = {
            "stable_diffusion_unet": (
                "diffusers.models.transformer_2d",
                "Transformer2DModel",
            ),
            "hf_T5": ("transformers.models.t5.modeling_t5", "T5Block"),
            "hf_T5_base": ("transformers.models.t5.modeling_t5", "T5Block"),
            "hf_T5_large": ("transformers.models.t5.modeling_t5", "T5Block"),
            "hf_Whisper": (
                "transformers.models.whisper.modeling_whisper",
                "WhisperEncoderLayer",
            ),
            "llama_v2_7b_16h": (
                "transformers.models.llama.modeling_llama",
                "LlamaDecoderLayer",
            ),
            "nanogpt": ("torchbenchmark.models.nanogpt.model", "Block"),
        }

        if model_name not in MODEL_FSDP_WRAP:
            # default to using wrap policy based on module size
            return functools.partial(
                size_based_auto_wrap_policy, recurse=True, min_num_params=int(1e5)
            )

        module_path, class_name = MODEL_FSDP_WRAP[model_name]
        layer_cls = getattr(importlib.import_module(module_path), class_name)
        return ModuleWrapPolicy((layer_cls,))
2597

2598
    def deepcopy_and_maybe_parallelize(self, model):
        """Deep-copy ``model`` (via ``self.deepcopy_model``) and optionally wrap it.

        - ``self.args.ddp``: wrap the copy in ``DistributedDataParallel`` with
          ``find_unused_parameters=True``.
        - otherwise ``self.args.fsdp``: wrap the copy in
          ``FullyShardedDataParallel``. The mixed-precision policy uses one dtype
          for params, gradient reduction and buffers: float16 if
          ``self.args.float16``, else bfloat16 if ``self.args.bfloat16``, else
          float32.
        - neither flag: return the plain copy.

        Raises:
            AssertionError: if DDP/FSDP is requested but ``torch.distributed``
                is not available in this build.
        """
        copied = self.deepcopy_model(model)

        if self.args.ddp:
            assert (
                torch.distributed.is_available()
            ), "Can't use DDP without a distributed enabled build"
            from torch.nn.parallel import DistributedDataParallel as DDP

            return DDP(copied, find_unused_parameters=True)

        if not self.args.fsdp:
            return copied

        assert (
            torch.distributed.is_available()
        ), "Can't use FSDP without a distributed enabled build"
        from torch.distributed.fsdp import (
            FullyShardedDataParallel as FSDP,
            MixedPrecision,
        )

        precision = (
            torch.float16
            if self.args.float16
            else torch.bfloat16
            if self.args.bfloat16
            else torch.float32
        )
        # The same dtype governs parameters, gradient communication and buffers.
        precision_policy = MixedPrecision(
            param_dtype=precision,
            reduce_dtype=precision,
            buffer_dtype=precision,
        )
        target_device = (
            torch.cuda.current_device() if self.args.devices[-1] == "cuda" else None
        )
        wrap_policy = self.get_fsdp_auto_wrap_policy(self.args.only)
        return FSDP(
            copied,
            use_orig_params=True,
            device_id=target_device,
            mixed_precision=precision_policy,
            limit_all_gathers=True,
            auto_wrap_policy=wrap_policy,
        )
2642

2643
    def check_accuracy(
        self, name, model, example_inputs, optimize_ctx, experiment, tag
    ):
        """
        Checks accuracy.
        1) Collect the outputs with fp64 datatype. This is useful for error checking.
        2) Checks if eager itself has variations.

        Full sequence:
          a. Optionally compute an fp64 reference (``fp64_outputs``); on failure
             fall back to cosine-similarity checking.
          b. Run eager twice (RNG reset before each) and require identical
             results (``tol=0``), unless the model is listed in
             ``skip_accuracy_check_as_eager_non_deterministic``.
          c. Run the model through ``optimize_ctx`` (export / AOT-inductor path
             or ``run_n_iterations`` path) and compare against the first eager
             result with ``same(...)``.

        Every exit path writes one CSV row via ``record_status`` (device, name,
        batch size, optional tag, accuracy status, plus per-key dynamo stat
        deltas since entry) and returns the recorded status string. Statuses
        produced here: "pass", "pass_due_to_skip", "eager_1st_run_OOM",
        "eager_1st_run_fail", "eager_2nd_run_OOM", "eager_2nd_run_fail",
        "eager_two_runs_differ", "OOM", "fail_to_run", "fail_accuracy". For
        models in ``self.non_deterministic_models``, "eager_two_runs_differ" and
        "fail_accuracy" are recorded as "pass".

        NOTE(review): ``experiment`` is accepted but not used in this method.
        """
        start_stats = get_dynamo_stats()

        def record_status(accuracy_status, dynamo_start_stats):
            """
            Records the status in the csv file
            """
            # NOTE(review): keyed on the module-level ``current_name`` rather than
            # the ``name`` argument — presumably they match; confirm with caller.
            if current_name in self.non_deterministic_models:
                if accuracy_status in (
                    "pass",
                    "eager_two_runs_differ",
                    "fail_accuracy",
                ):
                    accuracy_status = "pass"

            headers = ["dev", "name", "batch_size", "accuracy"]
            fields = [current_device, current_name, current_batch_size, accuracy_status]

            if tag is not None:
                headers.insert(3, "tag")
                fields.insert(3, tag)

            # Append the change in each dynamo counter since this check started.
            dynamo_stats = get_dynamo_stats()
            dynamo_stats.subtract(dynamo_start_stats)
            for k, v in dynamo_stats.items():
                headers.append(k)
                fields.append(v)

            output_csv(output_filename, headers, fields)
            return accuracy_status

        if name in self.skip_accuracy_checks_large_models_dashboard:
            return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)

        # Skip all accuracy check for the torchao backend
        if self.args.backend == "torchao":
            return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)

        with self.pick_grad(name, self.args.training):
            # Collect the fp64 reference outputs to be used later for accuracy checking.
            fp64_outputs = None
            model_fp64 = None
            inputs_fp64 = None
            try:
                model_fp64, inputs_fp64 = cast_to_fp64(
                    self.deepcopy_and_maybe_parallelize(model),
                    clone_inputs(example_inputs),
                )
                self.init_optimizer(name, current_device, model_fp64.parameters())
                fp64_outputs = self.run_n_iterations(model_fp64, inputs_fp64)
                # Ensure every floating-point tensor in the reference is float64.
                fp64_outputs = tree_map(
                    lambda x: x.to(torch.float64)
                    if isinstance(x, torch.Tensor) and x.is_floating_point()
                    else x,
                    fp64_outputs,
                )
            except Exception:
                log.warning(
                    "fp64 golden ref were not generated for %s. Setting accuracy check to cosine",
                    name,
                )
                # NOTE(review): this mutates the shared args object, so cosine
                # checking stays enabled for any later model checked with the
                # same runner, not just this one.
                self.args.cosine = True
                fp64_outputs = None
            finally:
                # Free the fp64 copies before the eager runs below.
                del model_fp64, inputs_fp64
                empty_gpu_cache(current_device)

            tolerance, cos_similarity = self.get_tolerance_and_cosine_flag(
                self.args.training, current_device, name
            )

            # Cast the model to float16/float32 as necessary
            model, example_inputs = self.maybe_cast(model, example_inputs)
            accuracy_status = "pass"

            # Get results of native pytorch
            reset_rng_state()
            model_copy = None
            try:
                model_copy = self.deepcopy_and_maybe_parallelize(model)
                self.init_optimizer(name, current_device, model_copy.parameters())
                correct_result = self.run_n_iterations(
                    model_copy, clone_inputs(example_inputs)
                )
            except Exception as e:
                accuracy_status = (
                    "eager_1st_run_OOM"
                    if isinstance(e, torch.cuda.OutOfMemoryError)
                    else "eager_1st_run_fail"
                )
                log.exception("")
                return record_status(accuracy_status, dynamo_start_stats=start_stats)
            finally:
                del model_copy
                empty_gpu_cache(current_device)

            # Rerun native pytorch
            reset_rng_state()
            model_copy = None
            try:
                model_copy = self.deepcopy_and_maybe_parallelize(model)
                self.init_optimizer(name, current_device, model_copy.parameters())
                correct_rerun_result = self.run_n_iterations(
                    model_copy, clone_inputs(example_inputs)
                )
            except Exception as e:
                accuracy_status = (
                    "eager_2nd_run_OOM"
                    if isinstance(e, torch.cuda.OutOfMemoryError)
                    else "eager_2nd_run_fail"
                )
                log.exception("")
                return record_status(accuracy_status, dynamo_start_stats=start_stats)
            finally:
                del model_copy
                empty_gpu_cache(current_device)

            # Two eager runs should have exactly same result
            # (tol=0, no fp64 reference, no cosine similarity).
            is_same = True
            try:
                if (
                    name not in self.skip_accuracy_check_as_eager_non_deterministic
                    and not same(
                        correct_result,
                        correct_rerun_result,
                        fp64_ref=None,
                        cos_similarity=False,
                        tol=0,
                        equal_nan=self.equal_nan,
                        use_larger_multiplier_for_smaller_tensor=self.use_larger_multiplier_for_smaller_tensor(
                            name
                        ),
                    )
                ):
                    is_same = False
            except Exception as e:
                # Sometimes torch.allclose may throw RuntimeError
                is_same = False

            if not is_same:
                accuracy_status = "eager_two_runs_differ"
                return record_status(accuracy_status, dynamo_start_stats=start_stats)

            # The rerun result is only needed for the determinism check above.
            correct_rerun_result = None

            # Run with Dynamo
            reset_rng_state()
            torch._dynamo.reset()
            model_copy = None
            try:
                model_copy = self.deepcopy_and_maybe_parallelize(model)
                self.init_optimizer(name, current_device, model_copy.parameters())
                if self.args.export or self.args.export_aot_inductor:
                    # apply export on module directly
                    # no need for n iterations
                    # the logic should be the same to self.model_iter_fn (forward_pass)
                    with self.autocast(**self.autocast_arg):
                        optimized_model_iter_fn = optimize_ctx(
                            model_copy, example_inputs
                        )
                        new_result = optimized_model_iter_fn(model_copy, example_inputs)
                else:
                    optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
                    with maybe_enable_compiled_autograd(
                        self.args.compiled_autograd,
                        fullgraph=self.args.nopython,
                        dynamic=self.args.dynamic_shapes,
                    ):
                        new_result = optimized_model_iter_fn(model_copy, example_inputs)
            except Exception as e:
                log.exception("")
                print(
                    "TorchDynamo optimized model failed to run because of following error"
                )
                accuracy_status = (
                    "OOM"
                    if isinstance(e, torch.cuda.OutOfMemoryError)
                    else "fail_to_run"
                )
                return record_status(accuracy_status, dynamo_start_stats=start_stats)
            finally:
                del model_copy

            # The compiled run still had to succeed; only the comparison is skipped.
            if name in self.skip_accuracy_check_as_eager_non_deterministic:
                return record_status("pass_due_to_skip", dynamo_start_stats=start_stats)

            if (
                current_onnx_compiler == "torchscript"
                or current_onnx_compiler == "dynamo"
            ):
                # Workaround for ONNX for non-tensor outputs
                (
                    correct_result,
                    new_result,
                    fp64_outputs,
                ) = _OnnxPatch.patch_non_tensor_outputs(
                    correct_result, new_result, fp64_outputs
                )
                # Relax tolerance for ONNX cuda
                if current_device == "cuda":
                    tolerance = 1e-2

                # TODO: store correct_result into the dumped file for offline onnx model validation.
                # The downside and potential problem, is that the output formats may be different.
                # E.g., the output order might not match, None might be part of output, etc.

            try:
                if self.args.training and self.args.amp:
                    # Optional per-model post-processing applied identically to
                    # eager, compiled and fp64 outputs before comparison.
                    if process_fn := self.get_output_amp_train_process_func.get(
                        name, None
                    ):
                        correct_result = process_fn(correct_result)
                        new_result = process_fn(new_result)
                        fp64_outputs = process_fn(fp64_outputs)

                if not same(
                    correct_result,
                    new_result,
                    fp64_outputs,
                    equal_nan=self.equal_nan,
                    use_larger_multiplier_for_smaller_tensor=self.use_larger_multiplier_for_smaller_tensor(
                        name
                    ),
                    cos_similarity=cos_similarity,
                    tol=tolerance,
                ):
                    is_same = False
            except Exception as e:
                # Sometimes torch.allclose may throw RuntimeError
                is_same = False

            if not is_same:
                if self.args.skip_accuracy_check:
                    accuracy_status = "pass_due_to_skip"
                else:
                    accuracy_status = "fail_accuracy"
                return record_status(accuracy_status, dynamo_start_stats=start_stats)

        return record_status(accuracy_status, dynamo_start_stats=start_stats)
2889

2890
    def check_tolerance(
        self, name, model, example_inputs, optimize_ctx, base_device="cpu"
    ):
        """
        Checks tolerance based on https://pytorch.org/docs/stable/generated/torch.allclose.html.

        Runs a deep copy of the model eagerly on ``base_device`` as the
        reference. It then runs ``model`` through ``optimize_ctx`` and records
        the max / mean / std of the element-wise error
        ``|ref - res| / (1 + |ref|)`` over all tensor outputs as one CSV row.

        Returns:
            "pass_due_to_skip" for models in
            ``skip_accuracy_checks_large_models_dashboard``, "fail_to_run" if
            the optimized run raises, otherwise "pass" (after recording stats).
        """
        tolerance_status = "pass"
        if name in self.skip_accuracy_checks_large_models_dashboard:
            tolerance_status = "pass_due_to_skip"
            return tolerance_status
        # Cast the model to float16/float32 as necessary
        model, example_inputs = self.maybe_cast(model, example_inputs)

        with self.pick_grad(name, self.args.training):
            # Get results of native pytorch on the reference device
            reset_rng_state()
            model_copy = copy.deepcopy(model)
            model_copy = model_copy.to(base_device)
            example_inputs_copy = copy.deepcopy(example_inputs)
            # Only tensor leaves can be moved; non-tensor leaves (ints, None,
            # ...) are passed through instead of crashing on ``.to``.
            example_inputs_copy = tree_map(
                lambda x: x.to(base_device) if isinstance(x, torch.Tensor) else x,
                example_inputs_copy,
            )
            self.init_optimizer(name, base_device, model_copy.parameters())
            correct_result = self.run_n_iterations(model_copy, example_inputs_copy)

            # Run with Dynamo
            # Sometime CI fails with random triton compilation failure which will be skipped for now
            # TODO: revisit this after switching to new Triton runtime
            reset_rng_state()
            torch._dynamo.reset()
            try:
                self.init_optimizer(name, current_device, model.parameters())
                optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
                new_result = optimized_model_iter_fn(model, example_inputs)
            except Exception:
                log.exception("")
                print(
                    "TorchDynamo optimized model failed to run because of following error"
                )
                return "fail_to_run"

            def dump_max_mean_values(tol, ref, res):
                """Append flattened float32 relative errors for each tensor leaf to ``tol``."""
                if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)):
                    for refi, resi in zip(ref, res):
                        dump_max_mean_values(tol, refi, resi)
                elif isinstance(ref, dict):
                    for k in ref.keys():
                        dump_max_mean_values(tol, ref[k], res[k])
                elif isinstance(ref, torch.Tensor):
                    res = res.to(base_device)
                    t = torch.abs(ref - res) / (1 + torch.abs(ref))
                    tol.append(t.flatten().to(torch.float32))
                return tol

            tol = []
            dump_max_mean_values(tol, correct_result, new_result)
            # torch.cat already returns a new tensor; wrapping it again in
            # torch.tensor(...) was a redundant copy that emits a UserWarning.
            tol = torch.cat(tol)
            max_err = torch.max(tol)
            mean_err = torch.mean(tol)
            std_err = torch.std(tol)
            headers = ["dev", "name", "batch_size", "max", "mean", "std"]
            fields = [
                current_device,
                current_name,
                current_batch_size,
                max_err.item(),
                mean_err.item(),
                std_err.item(),
            ]
            output_csv(output_filename, headers, fields)
        return tolerance_status
2962

2963
    def run_performance_test_non_alternate(
        self, name, model, example_inputs, optimize_ctx, experiment, tag=None
    ):
        """Run the performance test non-alternately.

        The eager baseline is warmed up and timed first, then the optimized
        model is warmed up and timed. The two timing series are summarized
        with ``latency_experiment_summary``. ``experiment`` must wrap
        ``latency_experiment``.

        Returns:
            The result summary, as a space-joined string.
        """
        assert (
            experiment.func is latency_experiment
        ), "Must run with latency_experiment."

        def warmup(fn, model, example_inputs, mode, niters=10):
            """Run ``fn`` ``niters`` times and return ``(latency_s, mem_gb, stats_delta)``.

            On any exception, logs it, writes a "warmup_failed" CSV row and exits
            the process with status -1.
            """
            peak_mem = 0
            start_stats = get_dynamo_stats()
            try:
                if current_device == "cuda":
                    torch.cuda.reset_peak_memory_stats()
                    empty_gpu_cache(current_device)
                t0 = time.perf_counter()
                for _ in range(niters):
                    fn(model, example_inputs)
                t1 = time.perf_counter()
                latency = t1 - t0
                if current_device == "cuda":
                    peak_mem = get_peak_memory()
                elif current_device == "cpu":
                    total = psutil.virtual_memory().total
                    # memory_percent() is a percentage (0-100) of physical RAM,
                    # so divide by 100 before converting to GB. Note that this
                    # samples the process's current memory, not a tracked peak.
                    percentage = psutil.Process(os.getpid()).memory_percent()
                    peak_mem = percentage / 100 * total / 10**9
            except Exception:
                log.exception("Backend %s failed in warmup()", mode)
                write_csv_when_exception(
                    self.args, current_name, "warmup_failed", current_device
                )
                sys.exit(-1)
            dynamo_stats = get_dynamo_stats()
            dynamo_stats.subtract(start_stats)
            return latency, peak_mem, dynamo_stats

        # Cast the model to float16/float32 as necessary
        model, example_inputs = self.maybe_cast(model, example_inputs)

        # Use distributed wrapping as necessary
        model = self.deepcopy_and_maybe_parallelize(model)

        self.init_optimizer(name, current_device, model.parameters())

        # The self.autocast context is needed for the model we export with aot_compile,
        # similar to what we do in the check_accuracy function
        ctx = (
            self.autocast(**self.autocast_arg)
            if self.args.export_aot_inductor
            else contextlib.nullcontext()
        )

        with self.pick_grad(name, self.args.training), ctx:
            # Called for its side effect of resetting the frame counters.
            Stats.reset_counters()
            experiment_kwargs = {}
            if tag is not None:
                experiment_kwargs["tag"] = tag
            results = []

            with maybe_snapshot_memory(
                self.args.snapshot_memory, f"eager_{self.args.only}"
            ):
                eager_latency, eager_peak_mem, _ = warmup(
                    self.model_iter_fn, model, example_inputs, "eager"
                )
                if self.args.use_warm_peak_memory:
                    _, eager_peak_mem, _ = warmup(
                        self.model_iter_fn, model, example_inputs, "eager", niters=1
                    )

            baseline_timings = experiment(
                model, example_inputs, mark="expected", **experiment_kwargs
            )

            if self.args.export_aot_inductor:
                t_0 = time.perf_counter()
                optimized_model_iter_fn = optimize_ctx
                t_1 = time.perf_counter()
                aot_compilation_time = t_1 - t_0
            else:
                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
                aot_compilation_time = 0

            with maybe_enable_compiled_autograd(
                self.args.compiled_autograd,
                fullgraph=self.args.nopython,
                dynamic=self.args.dynamic_shapes,
            ), maybe_snapshot_memory(
                self.args.snapshot_memory, f"compiled_{self.args.only}"
            ):
                dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
                    optimized_model_iter_fn, model, example_inputs, "dynamo"
                )
                if self.args.use_warm_peak_memory:
                    _, dynamo_peak_mem, _ = warmup(
                        optimized_model_iter_fn,
                        model,
                        example_inputs,
                        "dynamo",
                        niters=1,
                    )

            if self.args.profile_dynamo_cache_lookup:
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU]
                ) as prof:
                    with maybe_enable_compiled_autograd(
                        self.args.compiled_autograd,
                        fullgraph=self.args.nopython,
                        dynamic=self.args.dynamic_shapes,
                    ):
                        warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")

                events = list(
                    filter(
                        lambda event: "TorchDynamo Cache Lookup" in event.key,
                        prof.key_averages(),
                    )
                )
                dynamo_cache_lookup_latency = events[0].self_cpu_time_total

            # Warmup of the compiled model includes compilation, so the excess
            # over the eager warmup approximates compile time.
            compilation_time = dynamo_latency - eager_latency + aot_compilation_time
            compression_ratio = (
                eager_peak_mem / dynamo_peak_mem if dynamo_peak_mem else 0.0
            )
            if self.args.print_memory:
                print(
                    f"memory: eager: {eager_peak_mem:.2f} GB, "
                    f"dynamo: {dynamo_peak_mem:.2f} GB, "
                    f"ratio: {compression_ratio:.2f}"
                )

            if self.args.print_compilation_time:
                print(f"Compilation time: {compilation_time:.2f}")

            # NOTE(review): given the latency_experiment assertion above, this
            # branch and the ONNX one below look unreachable unless asserts are
            # disabled (-O).
            if experiment.func is speedup_experiment:
                experiment_kwargs["compilation_latency"] = compilation_time
                experiment_kwargs["compression_ratio"] = compression_ratio
                experiment_kwargs["eager_peak_mem"] = eager_peak_mem
                experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
                experiment_kwargs["dynamo_stats"] = dynamo_stats
                if self.args.profile_dynamo_cache_lookup:
                    experiment_kwargs[
                        "cache_lookup_latency"
                    ] = dynamo_cache_lookup_latency

            if experiment.func is speedup_experiment_onnx:
                experiment = functools.partial(
                    experiment, optimized_model_iter_fn.context.onnx_model
                )
            backend_timings = experiment(
                model, example_inputs, mark="expected", **experiment_kwargs
            )
            timings = np.stack((baseline_timings, backend_timings), axis=1)
            result_summary = latency_experiment_summary(
                self.args, model, timings, **experiment_kwargs
            )
            # Only name the model if it has no ``name`` attribute yet. The old
            # check looked up an attribute named after the model's value
            # (e.g. "resnet50"), so it almost always overwrote ``model.name``.
            if not hasattr(model, "name"):
                model.name = name
            results.append(result_summary)
            return " ".join(map(str, results))
3124

3125
    def run_performance_test(
3126
        self, name, model, example_inputs, optimize_ctx, experiment, tag=None
3127
    ):
3128
        if self.args.xla:
3129
            with self.pick_grad(name, self.args.training):
3130
                return experiment(*self.maybe_cast(model, example_inputs))
3131

3132
        def warmup(fn, model, example_inputs, mode, niters=5):
3133
            peak_mem = 0
3134
            start_stats = get_dynamo_stats()
3135
            try:
3136
                if current_device == "cuda":
3137
                    torch.cuda.reset_peak_memory_stats()
3138
                    empty_gpu_cache(current_device)
3139
                t0 = time.perf_counter()
3140
                for _ in range(niters):
3141
                    fn(model, example_inputs)
3142
                t1 = time.perf_counter()
3143
                latency = t1 - t0
3144
                if current_device == "cuda":
3145
                    peak_mem = get_peak_memory()
3146
                elif current_device == "cpu":
3147
                    total = psutil.virtual_memory().total
3148
                    percentage = psutil.Process(os.getpid()).memory_percent()
3149
                    peak_mem = percentage * total / 10**9
3150
            except Exception:
3151
                log.exception("Backend %s failed in warmup()", mode)
3152
                write_csv_when_exception(
3153
                    self.args, current_name, "warmup_failed", current_device
3154
                )
3155
                return sys.exit(-1)
3156
            dynamo_stats = get_dynamo_stats()
3157
            dynamo_stats.subtract(start_stats)
3158
            return latency, peak_mem, dynamo_stats
3159

3160
        # Cast the model to float16/float32 as necessary
3161
        model, example_inputs = self.maybe_cast(model, example_inputs)
3162

3163
        # Use distributed wrapping as necessary
3164
        model = self.deepcopy_and_maybe_parallelize(model)
3165

3166
        self.init_optimizer(name, current_device, model.parameters())
3167

3168
        # The self.autocast context is needed for the model we export with aot_compile,
3169
        # similar to what we do in the check_accuracy function
3170
        ctx = (
3171
            self.autocast(**self.autocast_arg)
3172
            if self.args.export_aot_inductor
3173
            else contextlib.nullcontext()
3174
        )
3175

3176
        with self.pick_grad(name, self.args.training), ctx:
3177
            ok, total = Stats.reset_counters()
3178
            experiment_kwargs = {}
3179
            if tag is not None:
3180
                experiment_kwargs["tag"] = tag
3181
            results = []
3182
            with maybe_snapshot_memory(
3183
                self.args.snapshot_memory, f"eager_{self.args.only}"
3184
            ):
3185
                eager_latency, eager_peak_mem, _ = warmup(
3186
                    self.model_iter_fn, model, example_inputs, "eager"
3187
                )
3188
                if self.args.use_warm_peak_memory:
3189
                    _, eager_peak_mem, _ = warmup(
3190
                        self.model_iter_fn, model, example_inputs, "eager", niters=1
3191
                    )
3192

3193
            if self.args.export_aot_inductor:
3194
                t_0 = time.perf_counter()
3195
                optimized_model_iter_fn = optimize_ctx
3196
                t_1 = time.perf_counter()
3197
                aot_compilation_time = t_1 - t_0
3198
            else:
3199
                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
3200
                aot_compilation_time = 0
3201

3202
            with maybe_enable_compiled_autograd(
3203
                self.args.compiled_autograd,
3204
                fullgraph=self.args.nopython,
3205
                dynamic=self.args.dynamic_shapes,
3206
            ), maybe_snapshot_memory(
3207
                self.args.snapshot_memory, f"compiled_{self.args.only}"
3208
            ):
3209
                dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
3210
                    optimized_model_iter_fn, model, example_inputs, "dynamo"
3211
                )
3212
                if self.args.use_warm_peak_memory:
3213
                    _, dynamo_peak_mem, _ = warmup(
3214
                        optimized_model_iter_fn,
3215
                        model,
3216
                        example_inputs,
3217
                        "dynamo",
3218
                        niters=1,
3219
                    )
3220

3221
            if self.args.profile_dynamo_cache_lookup:
3222
                with torch.profiler.profile(
3223
                    activities=[torch.profiler.ProfilerActivity.CPU]
3224
                ) as prof:
3225
                    with maybe_enable_compiled_autograd(
3226
                        self.args.compiled_autograd,
3227
                        fullgraph=self.args.nopython,
3228
                        dynamic=self.args.dynamic_shapes,
3229
                    ):
3230
                        warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
3231

3232
                events = list(
3233
                    filter(
3234
                        lambda event: "TorchDynamo Cache Lookup" in event.key,
3235
                        prof.key_averages(),
3236
                    )
3237
                )
3238
                dynamo_cache_lookup_latency = events[0].self_cpu_time_total
3239

3240
            compilation_time = dynamo_latency - eager_latency + aot_compilation_time
3241
            compression_ratio = (
3242
                eager_peak_mem / dynamo_peak_mem if dynamo_peak_mem else 0.0
3243
            )
3244
            if self.args.print_memory:
3245
                print(
3246
                    f"memory: eager: {eager_peak_mem:.2f} GB, "
3247
                    f"dynamo: {dynamo_peak_mem:.2f} GB, "
3248
                    f"ratio: {compression_ratio:.2f}"
3249
                )
3250

3251
            if self.args.print_compilation_time:
3252
                print(f"Compilation time: {compilation_time:.2f}")
3253

3254
            if experiment.func is speedup_experiment:
3255
                experiment_kwargs["compilation_latency"] = compilation_time
3256
                experiment_kwargs["compression_ratio"] = compression_ratio
3257
                experiment_kwargs["eager_peak_mem"] = eager_peak_mem
3258
                experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem
3259
                experiment_kwargs["dynamo_stats"] = dynamo_stats
3260
                if self.args.profile_dynamo_cache_lookup:
3261
                    experiment_kwargs[
3262
                        "cache_lookup_latency"
3263
                    ] = dynamo_cache_lookup_latency
3264

3265
            if experiment.func is coverage_experiment:
3266
                ok, total = Stats.reset_counters()
3267
                results = []
3268
                # run with torch._dynamo few times to populate the cache
3269
                for _ in range(3):
3270
                    optimized_model_iter_fn(model, example_inputs)
3271
                _, frames_second_pass = Stats.reset_counters()  # should be 0
3272
                if frames_second_pass > 0:
3273
                    optimized_model_iter_fn(model, example_inputs)
3274
                    _, frames_third_pass = Stats.reset_counters()  # should be 0
3275
                else:
3276
                    frames_third_pass = 0
3277

3278
                results.append(
3279
                    f"{ok:3}/{total:3} +{frames_third_pass} frames {compilation_time:3.0f}s"
3280
                )
3281

3282
            if experiment.func is speedup_experiment_onnx:
3283
                experiment = functools.partial(
3284
                    experiment, optimized_model_iter_fn.context.onnx_model
3285
                )
3286

3287
            if not hasattr(model, name):
3288
                model.name = name
3289
            results.append(experiment(model, example_inputs, **experiment_kwargs))
3290
            return " ".join(map(str, results))
3291

3292
    def minify_model(
        self,
        name,
        model,
        example_inputs,
        optimize_ctx,
        experiment,
        tag,
    ):
        """Re-run the accuracy check with repro dumping enabled and keep the repro.

        Sets ``TORCH_COMPILE_DEBUG=1``, ``TORCHDYNAMO_REPRO_AFTER=dynamo`` and
        ``TORCHDYNAMO_REPRO_LEVEL=4`` in the process environment (they are left
        set afterwards), re-runs ``check_accuracy``, and then moves the
        ``repro.py`` expected in the current working directory to
        ``<repro_dir>/<name>_repro.py``. ``repro_dir`` is ``--output-directory``
        when given, otherwise ``torch._dynamo.config.base_dir``.

        A failed move is logged, not raised.
        """
        logging.info("Minifying %s...", name)
        os.environ["TORCH_COMPILE_DEBUG"] = "1"
        os.environ["TORCHDYNAMO_REPRO_AFTER"] = "dynamo"
        os.environ["TORCHDYNAMO_REPRO_LEVEL"] = "4"

        self.check_accuracy(name, model, example_inputs, optimize_ctx, experiment, tag)

        repro_dir = self.args.output_directory or torch._dynamo.config.base_dir

        try:
            shutil.move("repro.py", f"{repro_dir}/{name}_repro.py")
        except OSError as e:
            # A missing repro.py is one cause, but any OSError from the move
            # (permissions, bad destination, ...) lands here too, so include the
            # underlying error instead of discarding it.
            logging.error(
                "Could not find/move repro script for model %s: %s", name, e
            )
        else:
            logging.info(
                "Repro script for model %s with minified graph saved to %s",
                name,
                repro_dir,
            )
3323

3324
    def maybe_preserve_compile_debug(self, name, status):
        """Move this run's compile-debug directory into the CI debug tree.

        Acts only when ``name`` is a key of ``CI_PRESERVE_COMPILE_DEBUG`` whose
        entry contains ``status``, and the directory returned by
        ``torch._dynamo.utils.get_debug_dir()`` exists. That directory is renamed
        to ``<cwd>/test/debug/torch_compile_debug/<basename>``. An ``OSError``
        during the move is logged with its traceback and otherwise ignored.
        """
        preserve_table = CI_PRESERVE_COMPILE_DEBUG
        if name not in preserve_table or status not in preserve_table[name]:
            return

        src_dir = torch._dynamo.utils.get_debug_dir()
        if not os.path.isdir(src_dir):
            return

        target_root = os.path.join(os.getcwd(), "test", "debug", "torch_compile_debug")
        target_dir = os.path.join(target_root, os.path.basename(src_dir))
        try:
            os.makedirs(target_root, exist_ok=True)
            os.rename(src_dir, target_dir)
            log.warning("Moved %s to %s", src_dir, target_dir)
        except OSError:
            log.exception("Failed to preserve %s", src_dir)
3341

3342
    def run_one_model(
        self,
        name,
        model,
        example_inputs,
        optimize_ctx,
        experiment,
        explain=False,
        tag=None,
    ):
        """Run one model in the mode selected by ``self.args`` and report on it.

        Prints a header line and runs exactly one of the following, printing the
        returned status:

        * ``check_accuracy`` (``--accuracy``), followed by ``minify_model`` on a
          ``"fail_accuracy"`` status when ``--minify`` is set;
        * ``check_tolerance`` (``--tolerance``);
        * ``run_performance_test_non_alternate`` for the ``torchao`` backend, or
          ``run_performance_test`` otherwise (``--performance``).

        It then calls ``empty_gpu_cache`` and ``maybe_preserve_compile_debug``.
        Finally it emits the optional timing report (``--timing``), the
        graph-break summary and CSV (``explain`` / ``--log-graph-breaks`` /
        ``--print-graph-breaks``) and the counter summary (``--stats``).

        Args:
            name: model name used for reporting and debug preservation.
            model, example_inputs, optimize_ctx, experiment: forwarded to the
                selected check.
            explain: print dynamo graph/op counts for this model and write its
                graph-break reasons to CSV.
            tag: optional label appended to the header line and forwarded.
        """
        mode = "train" if self.args.training else "eval"
        msg = f"{current_device:4} {mode:5} {current_name:34} "
        if tag:
            msg += f" {tag:26}"
        print(msg, flush=True)

        start_stats = get_dynamo_stats()

        # Bound up front so that, if no run-mode flag is set, the debug
        # preservation call below receives None instead of raising
        # UnboundLocalError.
        status = None
        if self.args.accuracy:
            status = self.check_accuracy(
                name, model, example_inputs, optimize_ctx, experiment, tag
            )
            print(status)
            if status == "fail_accuracy" and self.args.minify:
                self.minify_model(
                    name, model, example_inputs, optimize_ctx, experiment, tag
                )
        elif self.args.tolerance:
            status = self.check_tolerance(name, model, example_inputs, optimize_ctx)
            print(status)
        elif self.args.performance:
            if self.args.backend == "torchao":
                status = self.run_performance_test_non_alternate(
                    name, model, example_inputs, optimize_ctx, experiment, tag
                )
            else:
                status = self.run_performance_test(
                    name, model, example_inputs, optimize_ctx, experiment, tag
                )
            print(status)
        empty_gpu_cache(current_device)

        self.maybe_preserve_compile_debug(name, status)

        if self.args.timing:
            from torch._dynamo.utils import op_count, print_time_report
            from torch.utils._stats import simple_call_counter

            print_time_report()
            timing_summary = "STATS: " + " | ".join(
                itertools.chain(
                    [f"call_* op count: {op_count}"],
                    (f"{key}:{value}" for key, value in simple_call_counter.items()),
                )
            )
            print(timing_summary)

        # Dynamo counters for this model only: current totals minus the
        # snapshot taken before the run.
        stats = get_dynamo_stats()
        stats.subtract(start_stats)

        if explain:
            print(
                f"Dynamo produced {stats['unique_graphs']} graphs "
                f"covering {stats['calls_captured']} ops with "
                f"{stats['graph_breaks']} graph breaks ({stats['unique_graph_breaks']} unique)"
            )

        if explain or self.args.log_graph_breaks or self.args.print_graph_breaks:
            # Build "<output>_graph_breaks.csv". Strip only a literal ".csv"
            # suffix: str.rstrip(".csv") would remove any trailing run of
            # '.', 'c', 's', 'v' characters (e.g. "docs.csv" -> "do").
            base_name = output_filename
            if base_name.endswith(".csv"):
                base_name = base_name[: -len(".csv")]
            filename = f"{base_name}_graph_breaks.csv"

            def add_double_quotes(x):
                # Delimiter because reason could have comma
                return f'"{x}"'

            for graph_break in graph_break_reasons:
                reason = add_double_quotes(graph_break.reason)
                user_stack = add_double_quotes(
                    ", ".join([str(x) for x in graph_break.user_stack])
                )
                output_csv(
                    filename,
                    ["model", "reason", "user_stack"],
                    [current_name, reason, user_stack],
                )

        if self.args.stats:
            Stats.print_summary()
3429

3430

3431
def help(fn):
    """Return ``fn``'s docstring so an experiment function can supply its own CLI help text."""
    docstring = getattr(fn, "__doc__")
    return docstring
3433

3434

3435
# Sentinel used as the default of --diff-branch, so "flag not given" can be
# distinguished from any real branch name passed on the command line.
diff_branch_default = "DIFF-BRANCH-DEFAULT"


def should_diff_branch(args):
    """Return True when --diff-branch was given an actual branch (not the sentinel)."""
    chosen_branch = args.diff_branch
    return not chosen_branch == diff_branch_default
3440

3441

3442
def parse_args(args=None):
    """Build the benchmark command-line parser and parse ``args``.

    Args:
        args: list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Exactly one of
        ``--accuracy``/``--performance``/``--tolerance`` and exactly one of
        ``--training``/``--inference`` must be supplied (both groups are
        ``required=True``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--filter", "-k", action="append", help="filter benchmarks with regexp"
    )
    parser.add_argument(
        "--exclude", "-x", action="append", help="exclude benchmarks with regexp"
    )
    parser.add_argument(
        "--exclude-exact", action="append", help="exclude benchmarks with exact match"
    )
    parser.add_argument(
        "--total-partitions",
        type=int,
        default=1,
        choices=range(1, 16),
        help="Total number of partitions we want to divide the benchmark suite into",
    )
    parser.add_argument(
        "--partition-id",
        type=int,
        default=0,
        help="ID of the benchmark suite partition to be run. Used to divide CI tasks",
    )
    parser.add_argument(
        "--devices", "--device", "-d", action="append", help="cpu or cuda"
    )
    parser.add_argument("--device-index", help="CUDA device index")
    parser.add_argument(
        "--repeat", "-n", type=int, default=30, help="number of timing runs"
    )
    iterations_per_run_help = """
        Run this many iterations for each time measurement. This is mainly used for
        XLA training. We want to run multiple iterations per measurement so the
        tracing and computation for different iterations can overlap with each
        other. This makes sure we have an accurate xla baseline.
    """
    parser.add_argument(
        "--iterations-per-run", type=int, default=1, help=iterations_per_run_help
    )
    parser.add_argument(
        "--randomize-input",
        action="store_true",
        help="Whether to randomize the input values. Dimensions will be kept the same.",
    )
    parser.add_argument(
        "--threads",
        "-t",
        type=int,
        help="number of threads to use for eager and inductor",
    )
    parser.add_argument(
        "--nopython", action="store_true", help="Turn graph breaks into errors"
    )
    parser.add_argument(
        "--no-skip",
        action="store_true",
        help="run models that are in the global SKIP list",
    )
    parser.add_argument(
        "--prims-nvfuser", action="store_true", help="use prims + nvfuser backend"
    )
    parser.add_argument(
        "--dump-raw-metrics",
        action="store_true",
        help="dump raw timing metrics from speedup experiment",
    )
    parser.add_argument(
        "--log-operator-inputs",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--channels-last",
        action="store_true",
        default=False,
        help="use channels last format",
    )
    parser.add_argument(
        "--batch-size", "--batch_size", type=int, help="batch size for benchmarking"
    )
    parser.add_argument(
        "--iterations", type=int, default=2, help="how many iterations to run"
    )
    parser.add_argument(
        "--batch-size-file", type=str, help="File to load batch size from"
    )
    parser.add_argument("--cosine", action="store_true", help="use cosine similarity")
    parser.add_argument(
        "--freezing", action="store_true", help="turn on freezing", default=False
    )
    parser.add_argument(
        "--inductor-config",
        "-c",
        action="append",
        help="key=value in torch._inductor.config",
    )
    parser.add_argument(
        "--ci", action="store_true", help="Flag to tell that it's a CI run"
    )
    parser.add_argument(
        "--dashboard", action="store_true", help="Flag to tell that it's a Dashboard run"
    )
    parser.add_argument(
        "--skip-fp64-check", action="store_true", help="skip accuracy check using fp64"
    )
    parser.add_argument(
        "--fast", "-f", action="store_true", help="skip slow benchmarks"
    )
    parser.add_argument(
        "--only",
        help="""Run just one model from torchbench. Or
        specify the path and class name of the model in format like:
        --only=path:<MODEL_FILE_PATH>,class:<CLASS_NAME>

        Due to the fact that dynamo changes current working directory,
        the path should be an absolute path.

        The class should have a method get_example_inputs to return the inputs
        for the model. An example looks like
        ```
        class LinearModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(10, 10)

            def forward(self, x):
                return self.linear(x)

            def get_example_inputs(self):
                return (torch.randn(2, 10),)
        ```
    """,
    )
    parser.add_argument(
        "--multiprocess",
        action="store_true",
        help="Create n processes based on the number of devices (distributed use case).",
    )
    parser.add_argument(
        "--ddp",
        action="store_true",
        help="Wraps model in DDP before running it, and uses dynamo DDPOptimizer (graph breaks) by default.",
    )
    parser.add_argument(
        "--fsdp",
        action="store_true",
        help="""Wraps model in FSDP before running it.
        Doesn't recursively wrap, mainly useful for checking dynamo UnspecNNModule compatibility
    """,
    )
    parser.add_argument(
        "--optimize-ddp-mode",
        type=str,
        default="ddp_optimizer",
        help="Specify the DDP optimization mode -- the value of torch._dynamo.config.optimize_ddp.",
    )
    parser.add_argument(
        "--distributed-master-port",
        default="6789",
        help="Port to bind for torch.distributed.  Use the default unless it's conflicting with another user",
    )
    parser.add_argument(
        "--dynamic-shapes",
        action="store_true",
        help="Runs a dynamic shapes version of the benchmark, if available.",
    )
    parser.add_argument(
        "--propagate-real-tensors",
        action="store_true",
        help="Capture as much data-dependent behavior as you can by unsoundly propagating real tensors",
    )
    parser.add_argument(
        "--dynamic-batch-only",
        action="store_true",
        help="Only assume batch dimension is dynamic.  Implies --dynamic-shapes",
    )
    parser.add_argument(
        "--specialize-int", action="store_true", help="Run with specialize_int=True."
    )
    parser.add_argument(
        "--use-eval-mode",
        action="store_true",
        help="sets model.eval() to reduce randomness",
    )
    parser.add_argument(
        "--skip-accuracy-check",
        action="store_true",
        help="keeps running even when accuracy fails",
    )
    parser.add_argument(
        "--generate-aot-autograd-stats",
        action="store_true",
        help="Generates AOT Autograd stats like how many graphs are sent to AOT",
    )
    parser.add_argument(
        "--inductor-settings",
        action="store_true",
        help="Use same settings as --inductor for baseline comparisons",
    )
    parser.add_argument(
        "--suppress-errors",
        action="store_true",
        help="Suppress errors instead of raising them",
    )
    parser.add_argument(
        "--output",
        help="Overrides the output filename",
    )
    parser.add_argument(
        "--output-directory",
        help="Overrides the directory to place output files.",
    )
    parser.add_argument(
        "--disable-output",
        action="store_true",
        help="Disable writing of output files, e.g., for warm-up runs",
    )
    parser.add_argument(
        "--baseline",
        help="Compare with a prior --output",
    )
    parser.add_argument(
        "--part",
        default=None,
        help="Specify the part of the model to run.",
    )
    parser.add_argument(
        "--export-profiler-trace",
        action="store_true",
        help="exports trace of kineto profiler",
    )
    parser.add_argument(
        "--profiler-trace-name",
        "--profiler_trace_name",
        help="Overwrites exported trace name",
    )
    parser.add_argument(
        "--diff-branch",
        default=diff_branch_default,
        help="delta current branch against given branch.",
    )
    parser.add_argument(
        "--tag", default=None, help="Specify a tag to be included in csv files."
    )
    parser.add_argument(
        "--explain",
        action="store_true",
        help="print some graph/op statistics during the run, similar to .explain()",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="print graph counter stats",
    )
    parser.add_argument(
        "--use-warm-peak-memory",
        "--use_warm_peak_memory",
        action="store_true",
        help="Measure peak memory using a warm run to reduce autotuning noise",
    )
    parser.add_argument(
        "--print-memory",
        action="store_true",
        help="print extra memory statistics",
    )
    parser.add_argument(
        "--print-compilation-time",
        action="store_true",
        help="print compilation latency",
    )
    parser.add_argument(
        "--print-dataframe-summary",
        action="store_true",
        help="print dataframe result used for calculating accuracy",
    )
    parser.add_argument(
        "--disable-cudagraphs",
        action="store_true",
        help="Disables cudagraphs for Inductor",
    )
    parser.add_argument(
        "--disable-split-reductions",
        action="store_true",
        help="Disables split reductions for Inductor",
    )
    parser.add_argument(
        "--disable-persistent-reductions",
        action="store_true",
        help="Disables persistent reductions for Inductor",
    )
    parser.add_argument(
        "--disable-divisible-by-16",
        action="store_true",
        help="Disables divisible by 16 hint to Triton for Inductor",
    )
    parser.add_argument(
        "--inductor-compile-mode",
        default=None,
        help="torch.compile mode argument for inductor runs.",
    )
    parser.add_argument(
        "--print-graph-breaks",
        action="store_true",
        help="Show a warning whenever a graph break occurs",
    )
    parser.add_argument(
        "--log-graph-breaks",
        action="store_true",
        help="log graph breaks in a file",
    )
    parser.add_argument(
        "--trace-on-xla",
        action="store_true",
        help="Whether to trace the model on XLA or on eager device",
    )
    parser.add_argument(
        "--xla-tolerance",
        type=float,
        default=1e-2,
        help="XLA needs a loose tolerance to pass the correctness check",
    )
    parser.add_argument(
        "--collect-outputs",
        action="store_true",
        help="""Whether to collect outputs for training. Set this to true if we
        want to verify the numerical correctness of gradients. But that may
        make time measurements inaccurate""",
    )
    parser.add_argument(
        "--enable-activation-checkpointing",
        action="store_true",
        help="Enables activation checkpointing for HF models",
    )
    parser.add_argument("--timing", action="store_true", help="Emits phase timing")

    parser.add_argument(
        "--progress",
        action="store_true",
        help="Print n/k models message between each model run.",
    )

    parser.add_argument(
        "--timeout",
        type=int,
        default=2000,
        help="timeout (second) for benchmarking.",
    )

    # The underscore spelling comes first so the destination stays
    # `per_process_memory_fraction`; the dash spelling matches the other flags.
    parser.add_argument(
        "--per_process_memory_fraction",
        "--per-process-memory-fraction",
        type=float,
        default=1,
        help="Set per-process GPU memory fraction (limit) for reducing usable size and reproducing OOMs",
    )

    parser.add_argument(
        "--no-translation-validation",
        action="store_true",
        help="Disable translation validation for accuracy builds.",
    )

    parser.add_argument(
        "--minify",
        action="store_true",
        help="Enable minification when failure is below tolerance. Save repro script for each model.",
    )

    parser.add_argument(
        "--compiled-autograd",
        action="store_true",
        help="Enables compiled autograd on compiled benchmark",
    )

    parser.add_argument(
        "--profile_dynamo_cache_lookup",
        "--profile-dynamo-cache-lookup",
        action="store_true",
        help="profiles TorchDynamo cache lookup",
    )

    parser.add_argument(
        "--snapshot-memory",
        "--snapshot_memory",
        action="store_true",
        help="Enables Memory Snapshot tool for memory deep dives: https://pytorch.org/blog/understanding-gpu-memory-1/",
    )

    group_latency = parser.add_mutually_exclusive_group()
    group_latency.add_argument(
        "--cold-start-latency",
        "--cold_start_latency",
        action="store_true",
        help="Use a fresh triton cachedir when running each model, to force cold-start compile.",
    )
    group_latency.add_argument(
        "--warm-start-latency",
        "--warm_start_latency",
        action="store_true",
        help="Run model(s) twice and preserve caches in between to enable a 'warm start' on the 2nd run",
    )

    group_fuser = parser.add_mutually_exclusive_group()
    # --nvfuser is now the default, keep the option to not break scripts
    group_fuser.add_argument("--nvfuser", action="store_true", help=argparse.SUPPRESS)
    group_fuser.add_argument("--nnc", action="store_true", help="enable NNC for GPUs")

    group_prec = parser.add_mutually_exclusive_group()
    group_prec.add_argument("--float16", action="store_true", help="cast model to fp16")
    group_prec.add_argument(
        "--bfloat16", action="store_true", help="cast model to bf16"
    )
    group_prec.add_argument("--float32", action="store_true", help="cast model to fp32")
    group_prec.add_argument(
        "--amp", action="store_true", help="use automatic mixed precision"
    )
    parser.add_argument(
        "--amp-dtype",
        choices=("bfloat16", "float16"),
        help="the data type used with automatic mixed precision",
    )
    group_printout = parser.add_mutually_exclusive_group()
    group_printout.add_argument(
        "--verbose", "-v", action="store_true", help="enable verbose debug printouts"
    )
    group_printout.add_argument(
        "--quiet", "-q", action="store_true", help="suppress debug printouts"
    )

    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--coverage", action="store_true", help="(default) " + help(coverage_experiment)
    )
    group.add_argument(
        "--overhead", action="store_true", help=help(overhead_experiment)
    )
    group.add_argument(
        "--speedup-dynamo-ts",
        action="store_true",
        help="TorchDynamo frontend with torchscript backend",
    )
    group.add_argument(
        "--speedup-fx2trt", action="store_true", help=help(speedup_experiment_fx2trt)
    )
    group.add_argument(
        "--speedup-fx2trt-fp16",
        action="store_true",
        help=help(speedup_experiment_fx2trt),
    )
    group.add_argument(
        "--print-fx",
        action="store_true",
        help="Print fx traces captured from model",
    )
    group.add_argument(
        "--print-aten-ops",
        action="store_true",
        help="Print traces of aten ops captured by AOT autograd",
    )
    group.add_argument(
        "--inductor",
        action="store_true",
        help="Measure speedup with TorchInductor",
    )
    group.add_argument(
        "--quantization",
        choices=[
            "int8dynamic",
            "int8weightonly",
            "int4weightonly",
            "autoquant",
            "noquant",
        ],
        default=None,
        help="Measure speedup of torchao quantization with TorchInductor baseline",
    )
    group.add_argument(
        "--export",
        action="store_true",
        help="Measure pass rate with export",
    )
    group.add_argument(
        "--export-aot-inductor",
        action="store_true",
        help="Measure pass rate with Export+AOTInductor",
    )
    group.add_argument(
        "--xla", action="store_true", help="Compare TorchXLA to eager PyTorch"
    )
    group.add_argument(
        "--torchscript-onnx",
        "--torchscript_onnx",
        action="store_true",
        help="Measure speedup with TorchScript ONNX, i.e. `torch.onnx.export`",
    )
    group.add_argument(
        "--torch-onnx-patch",
        "--torch_onnx_patch",
        action="store_true",
        help="Measure speedup with dynamo ONNX patch, i.e. `torch_onnx`",
    )
    group.add_argument(
        "--dynamo-onnx",
        "--dynamo_onnx",
        action="store_true",
        help="Measure speedup with Dynamo ONNX, i.e. `torch.onnx.dynamo_export`",
    )
    group.add_argument(
        "--dynamo-onnx-aot-inline",
        "--dynamo_onnx_aot_inline",
        action="store_true",
        help="Measure speedup with Dynamo ONNX AOT Inline, i.e. `torch.onnx.dynamo_export`",
    )
    group.add_argument(
        "--dynamo-onnx-aot-optimize",
        "--dynamo_onnx_aot_optimize",
        action="store_true",
        help="Measure speedup with Dynamo ONNX w/ ort fusions, i.e. `torch.onnx.dynamo_export`",
    )
    group.add_argument(
        "--backend",
        choices=torch._dynamo.list_backends(exclude_tags=None),
        help="measure speedup with a given backend",
    )
    group.add_argument("--nothing", action="store_true", help=help(null_experiment))
    group.add_argument(
        "--log-conv-args",
        action="store_true",
        help="Dump convolution input/weight/bias's shape/stride/dtype and other options to json",
    )
    group.add_argument(
        "--recompile-profiler",
        "--recompile_profiler",
        action="store_true",
        help="Run the dynamo recompilation profiler on each model.",
    )
    group.add_argument(
        "--find-batch-sizes",
        action="store_true",
        help="finds the largest batch size that could fit on GPUs",
    )

    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument(
        "--accuracy",
        action="store_true",
        help="Checks accuracy with small batch size and eval mode",
    )
    mode_group.add_argument(
        "--performance", action="store_true", help="Measures performance speedup"
    )
    mode_group.add_argument(
        "--tolerance",
        action="store_true",
        help="extracts the tolerance for each model with small batch size and eval mode",
    )
    run_mode_group = parser.add_mutually_exclusive_group(required=True)
    run_mode_group.add_argument(
        "--training",
        action="store_true",
        help="Performs training",
    )
    run_mode_group.add_argument(
        "--inference", action="store_true", help="Performs inference"
    )
    return parser.parse_args(args)
4008

4009

4010
def process_entry(rank, runner, original_dir, args):
    """Per-process entry point: tag ``args`` with ``rank`` and run the benchmarks.

    ``main`` calls this directly (rank 0) in single-process mode and passes it
    to ``mp.spawn`` in multiprocess mode. ``run`` executes inside the
    ``maybe_init_distributed`` context, which receives ``args.init_distributed``,
    this rank, ``args.world_size`` and ``args.distributed_master_port``.
    Returns whatever ``run`` returns.
    """
    args.rank = rank
    distributed_ctx = maybe_init_distributed(
        args.init_distributed,
        rank=rank,
        world_size=args.world_size,
        port=args.distributed_master_port,
    )
    with distributed_ctx:
        return run(runner, args, original_dir)
4019

4020

4021
def maybe_fresh_cache(args):
    """Return the context manager that selects the inductor cache for this run.

    If ``TORCHINDUCTOR_CACHE_DIR`` is already set in the environment it takes
    precedence and a no-op ``contextlib.nullcontext()`` is returned. Otherwise
    cold-start, warm-start and CI runs get ``fresh_inductor_cache()``, and every
    other run gets ``contextlib.nullcontext()``.
    """
    if "TORCHINDUCTOR_CACHE_DIR" in os.environ:
        # Respect an explicitly configured cache location.
        return contextlib.nullcontext()
    needs_fresh_cache = args.cold_start_latency or args.warm_start_latency or args.ci
    if needs_fresh_cache:
        return fresh_inductor_cache()
    return contextlib.nullcontext()
4029

4030

4031
def main(runner, original_dir=None, args=None):
    """Command-line entry point for the benchmark runner.

    Optionally ``chdir``s to ``original_dir``, parses ``args`` (``sys.argv``
    when ``args`` is falsy), makes ``--baseline`` absolute, and validates the
    git state when ``--diff-branch`` is used. Then, inside
    ``maybe_fresh_cache(args)``, it runs the benchmarks in one of three ways:

    * ``--only`` + ``--multiprocess``: ``mp.spawn`` one ``process_entry`` per
      CUDA device reported by ``torch.cuda.device_count()``.
    * ``--only`` + ``--warm-start-latency``: re-run the current command line in
      two subprocesses with ``TORCHINDUCTOR_FX_GRAPH_CACHE=1``, first a
      ``--repeat=1 --disable-output`` cold-start run and then the real run.
    * otherwise: ``process_entry`` in this process with ``world_size=1``.

    Raises:
        RuntimeError: ``--diff-branch`` on a dirty tree, or naming the
            current branch.
        ValueError: warm-start mode, but neither spelling of the flag appears
            verbatim in ``sys.argv`` (so it cannot be stripped for the children).
        subprocess.CalledProcessError, subprocess.TimeoutExpired: a warm-start
            child run failed or exceeded ``--timeout``.
    """
    if original_dir:
        os.chdir(original_dir)
    args = parse_args() if not args else parse_args(args)
    if args.baseline:
        args.baseline = os.path.abspath(args.baseline)

    if should_diff_branch(args):
        import git

        # We do this here so we error out earlier if there's an issue
        repo = git.Repo()
        if repo.is_dirty():
            raise RuntimeError(
                "--diff-branch called on dirty branch. Commit, stash, or reset."
            )
        main_branch = repo.active_branch.name
        if main_branch == args.diff_branch:
            raise RuntimeError(
                f"--diff-branch: current branch is same as {args.diff_branch} branch, what are you diffing?"
            )

    with maybe_fresh_cache(args):
        args.init_distributed = args.only and args.multiprocess
        if args.init_distributed:
            # NB: Do NOT query device count before CUDA initialization; we're
            # going to overwrite CUDA_VISIBLE_DEVICES and this will result in
            # https://github.com/pytorch/pytorch/issues/107300
            device_count = torch.cuda.device_count()
            if device_count <= 1:
                log.warning(
                    "The use multiprocess flag is set but there are <= 1 devices available."
                )
            # multiprocess path
            args.world_size = device_count
            mp.spawn(
                process_entry, args=(runner, original_dir, args), nprocs=device_count
            )
        elif args.only and args.warm_start_latency:
            # Warm start mode. Enable FX graph caching and perform back-to-back runs in
            # separate processes (but ensure the inductor cache is preserved across runs).
            env = os.environ.copy()
            env["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            # Strip every occurrence of the warm-start flag, in either accepted
            # spelling. The child runs then take the single-process path instead
            # of re-entering this branch and spawning children of their own.
            warm_start_flags = ("--warm-start-latency", "--warm_start_latency")
            child_argv = [arg for arg in sys.argv if arg not in warm_start_flags]
            if len(child_argv) == len(sys.argv):
                # e.g. an abbreviated flag (argparse accepts unique prefixes) or
                # `args` passed programmatically rather than via sys.argv.
                raise ValueError(
                    "--warm-start-latency: flag not found verbatim in sys.argv; "
                    "spell it out in full on the command line"
                )
            cmd = [sys.executable] + child_argv

            print(f"Performing cold-start run for {args.only}")
            warmup_cmd = cmd + ["--repeat=1", "--disable-output"]
            subprocess.check_call(warmup_cmd, timeout=args.timeout, env=env)

            print(f"Performing warm-start run for {args.only}")
            subprocess.check_call(cmd, timeout=args.timeout, env=env)
        else:
            # single process path just uses the main process
            args.world_size = 1
            process_entry(0, runner, original_dir, args)
4087

4088

4089
def write_csv_when_exception(args, name: str, status: str, device=None):
    """Append placeholder CSV rows for a model that produced no real result.

    ``status`` is printed first. Then one row per device is written to the
    module-level ``output_filename`` through ``output_csv``. If ``device`` is
    given, only that device gets a row; otherwise every entry of
    ``args.devices`` does.

    Row layout by mode:
      * accuracy:    dev, name, batch_size=0, <status>
      * performance: dev, name, batch_size=0, speedup=0.0, abs_latency=0.0
      * otherwise:   dev, name, batch_size=0, 0.0   (written with no headers)
    """
    print(status)
    target_devices = args.devices if device is None else [device]
    dummy_batch_size = 0

    if args.accuracy:
        headers = ["dev", "name", "batch_size", "accuracy"]
        trailing_cols = [status]
    elif args.performance:
        headers = ["dev", "name", "batch_size", "speedup", "abs_latency"]
        trailing_cols = [0.0, 0.0]
    else:
        headers = []
        trailing_cols = [0.0]

    for dev in target_devices:
        output_csv(output_filename, headers, [dev, name, dummy_batch_size, *trailing_cols])
4105

4106

4107
def run(runner, args, original_dir=None):
    """Configure the benchmark environment from ``args`` and run the benchmark(s).

    Stores ``args`` on ``runner`` and applies dynamo/inductor/backend settings.
    Selects the compile context (``optimize_ctx``), the experiment function and
    the output CSV name; these are module-level globals. Then it does one of:

    * runs ``args.only`` in this process once per device (or, when a diff
      branch is requested, re-invokes this script on both git branches), or
    * re-invokes this script once per model yielded by
      ``runner.iter_model_names(args)`` with ``--only=<name>``, then calls
      ``print_summary`` on the output CSV.

    Args:
        runner: benchmark-suite runner object. It receives ``args``,
            ``model_iter_fn`` and skip-list updates.
        args: parsed command-line namespace. Several fields are normalized or
            filled in place (``filter``, ``backend``, ``devices``, ...).
        original_dir: directory to ``chdir`` into before iterating all models.

    Invalid option combinations call ``sys.exit(-1)``, which raises SystemExit.
    """
    global synchronize
    global current_name, current_device, current_batch_size, output_filename, disable_output, optimize_ctx, current_onnx_compiler

    # Pass the parsed args object to benchmark runner object
    runner.args = args

    args.filter = args.filter or [r"."]
    args.exclude = args.exclude or [r"^$"]
    args.exclude_exact = args.exclude_exact or []

    if args.inductor:
        assert args.backend is None
        args.backend = "inductor"
    if args.quantization:
        assert args.backend is None
        args.backend = "torchao"
    if args.dynamic_batch_only:
        args.dynamic_shapes = True
        torch._dynamo.config.assume_static_by_default = True
    if args.dynamic_shapes:
        if not args.dynamic_batch_only:
            torch._dynamo.config.assume_static_by_default = False
    if args.propagate_real_tensors:
        # TODO: Separate flag for data dependent
        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
        torch._functorch.config.fake_tensor_propagate_real_tensors = True
    if args.specialize_int:
        torch._dynamo.config.specialize_int = True
    if args.ci:
        if args.accuracy:
            # Run fewer iterations when checking accuracy
            args.repeat = min(args.repeat, 2)

            # Set translation validation on by default on CI accuracy runs.
            torch.fx.experimental._config.translation_validation = True
    if args.ddp:
        assert args.training, "DDP benchmark requires --training mode"
        torch._dynamo.config.optimize_ddp = args.optimize_ddp_mode
        if args.only == "dlrm":
            log.error(
                "DLRM+DDP is unsupported as it requires sharding the embedding layer separately from DDP"
            )
            return sys.exit(-1)
    if args.accuracy:
        # Use small batch size. We use >1 batch size to ensure we test
        # batch_norm type of operators that work on batch dims.
        # TODO - Go through the failures for batch size = 2
        if args.batch_size is None:
            if runner.suite_name == "huggingface":
                args.batch_size = 1
            elif runner.suite_name == "torchbench":
                args.batch_size = 4
            else:
                # Larger batch size of TIMM models to have stable batch_norm
                assert runner.suite_name == "timm_models"
                args.batch_size = 8

        # Remove sources of randomness
        if runner.suite_name not in ("timm_models", "huggingface"):
            # TODO - Using train mode for timm_models and HF models. Move to train mode for Torchbench as well.
            args.use_eval_mode = True
        inductor_config.fallback_random = True
        # The models listed here do not support use_deterministic_algorithms.
        if args.only is not None and args.only not in {
            "alexnet",
            "Background_Matting",
            "pytorch_CycleGAN_and_pix2pix",
            "pytorch_unet",
            "Super_SloMo",
            "vgg16",
            # https://github.com/pytorch/pytorch/issues/96724
            "Wav2Vec2ForCTC",
            "Wav2Vec2ForPreTraining",
            "sam",
            "sam_fast",
            "resnet50_quantized_qat",
            "mobilenet_v2_quantized_qat",
        }:
            torch.use_deterministic_algorithms(True)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cudnn.benchmark = False
        torch.backends.cuda.matmul.allow_tf32 = False

        torch.backends.mkldnn.deterministic = True

        # Remove randomness when torch manual seed is called
        patch_torch_manual_seed()

        # Some models e.g. yolov3 assert batch size on n_gpus
        if "CUDA_VISIBLE_DEVICES" not in os.environ and not args.multiprocess:
            args.device_index = "0"

        # Stricter check to disable fallbacks
        args.suppress_errors = False

    if args.device_index is not None:
        if args.multiprocess:
            print("Cannot specify both --device_index and --multiprocess")
            return sys.exit(-1)
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device_index

    elif args.performance:
        # Ensure that we test on real scenarios
        args.use_eval_mode = False

    # Partition ids are zero-based (negative ids are rejected), so the valid
    # range is [0, total_partitions). The previous `>` check let
    # partition_id == total_partitions through.
    if args.partition_id >= args.total_partitions or args.partition_id < 0:
        print("Invalid partition id")
        return sys.exit(-1)

    if not args.devices:
        if torch.cuda.is_available():
            args.devices = ["cuda"]
        else:
            log.warning("torch.cuda.is_available() == False, using CPU")
            args.devices = ["cpu"]

    if args.devices != ["cpu"] and (HAS_CUDA or HAS_XPU):
        synchronize = torch.cuda.synchronize if HAS_CUDA else torch.xpu.synchronize

    if (
        args.devices == ["cuda"]
        and torch.cuda.get_device_properties(0).total_memory < 25 * 2**30
    ):
        # OOM errors on an RTX 3090 with 24gb RAM
        runner.skip_models.update(
            {
                # torchbench
                "hf_Longformer",
                "timm_nfnet",
                "timm_efficientdet",
            }
        )
        if args.training:
            runner.skip_models.add("hf_T5")

    if args.nnc:
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_set_nvfuser_enabled(False)

    if args.threads:
        torch.set_num_threads(args.threads)

    if args.verbose:
        torch._logging.set_logs(dynamo=logging.DEBUG)

    if args.print_graph_breaks:
        torch._logging.set_logs(graph_breaks=True)

    if args.quiet:
        torch._logging.set_logs(dynamo=logging.ERROR)

    torch._dynamo.config.suppress_errors = args.suppress_errors

    if args.training:
        runner.model_iter_fn = runner.forward_and_backward_pass
        runner.skip_models.update(runner.skip_not_suitable_for_training_models)
    else:
        runner.model_iter_fn = runner.forward_pass

    if args.fast:
        runner.skip_models.update(runner.slow_models)

    if args.devices == ["cpu"]:
        runner.skip_models.update(runner.very_slow_models)
        runner.skip_models.update(runner.skip_models_for_cpu)
    elif args.devices == ["cuda"]:
        runner.skip_models.update(runner.skip_models_for_cuda)

    if not args.multiprocess:
        runner.skip_models.update(runner.skip_multiprocess_models)

    if args.freezing:
        runner.skip_models.update(runner.skip_models_for_freezing)

    if args.no_skip:
        runner.skip_models.clear()

    experiment = null_experiment
    optimize_ctx = contextlib.nullcontext()

    if args.disable_output:
        disable_output = True

    if args.overhead:
        optimize_ctx = torch._dynamo.optimize(dummy_fx_compile, nopython=args.nopython)
        experiment = speedup_experiment
        output_filename = "overheads.csv"
    elif args.inductor:
        inductor_config.debug = args.verbose
        if args.threads:
            inductor_config.cpp.threads = args.threads

        optimize_ctx = functools.partial(
            torch.compile,
            backend="inductor",
            fullgraph=args.nopython,
            mode=args.inductor_compile_mode,
        )
        experiment = speedup_experiment
        output_filename = "inductor.csv"
    elif args.export:
        optimize_ctx = export
        experiment = speedup_experiment
        output_filename = "export.csv"
    elif args.xla:
        (dev,) = args.devices
        os.environ["PJRT_DEVICE"] = {"cuda": "GPU", "cpu": "CPU"}[dev]
        torch._dynamo.mark_dynamic = MagicMock()
        experiment = xla
        output_filename = "xla.csv"
    elif args.torchscript_onnx:
        optimize_ctx = functools.partial(
            optimize_onnx_ctx,
            args.output_directory or ".",
            OnnxModelFromTorchScript,
            copy_before_export=args.performance,  # Accuracy bench already did deepcopy
        )
        experiment = speedup_experiment_onnx
        output_filename = "torchscript_onnx.csv"
        current_onnx_compiler = "torchscript"
    elif args.torch_onnx_patch:
        optimize_ctx = functools.partial(
            optimize_onnx_ctx,
            args.output_directory or ".",
            OnnxModelFromTorchScript,
            copy_before_export=args.performance,
            use_experimental_patch=True,
        )
        experiment = speedup_experiment_onnx
        output_filename = "torch_onnx_patch.csv"
        current_onnx_compiler = "dynamo"
    elif args.dynamo_onnx:
        optimize_ctx = functools.partial(
            optimize_onnx_ctx,
            args.output_directory or ".",
            OnnxModelFromDynamo,
            dynamic_shapes=args.dynamic_shapes,
            copy_before_export=args.performance,
        )
        experiment = speedup_experiment_onnx
        output_filename = "dynamo_onnx.csv"
        current_onnx_compiler = "dynamo"
    elif args.dynamo_onnx_aot_inline:
        optimize_ctx = functools.partial(
            optimize_onnx_ctx,
            args.output_directory or ".",
            OnnxModelFromDynamoAotInline,
            dynamic_shapes=args.dynamic_shapes,
            copy_before_export=args.performance,
        )
        experiment = speedup_experiment_onnx
        output_filename = "dynamo_onnx_aot_inline.csv"
        current_onnx_compiler = "dynamo"
    elif args.dynamo_onnx_aot_optimize:
        optimize_ctx = functools.partial(
            optimize_onnx_ctx,
            args.output_directory or ".",
            OnnxModelFromDynamoAotOptimize,
            dynamic_shapes=args.dynamic_shapes,
            copy_before_export=args.performance,
        )
        experiment = speedup_experiment_onnx
        output_filename = "dynamo_onnx_aot_optimize.csv"
        current_onnx_compiler = "dynamo"
    elif args.speedup_dynamo_ts:
        optimize_ctx = torch._dynamo.optimize("ts", nopython=args.nopython)
        experiment = speedup_experiment
        output_filename = "speedup_dynamo_ts.csv"
    elif args.prims_nvfuser:
        optimize_ctx = torch._dynamo.optimize("prims_nvfuser", nopython=args.nopython)
        experiment = speedup_experiment
        backend_str = "prims_nvfuser"
        output_filename = f"accuracy_aot_{backend_str}.csv"
    elif args.print_fx:
        optimize_ctx = torch._dynamo.optimize(
            print_fx,
            nopython=args.nopython,
        )
    elif args.print_aten_ops:
        optimize_ctx = torch._dynamo.optimize(
            print_aten_ops,
            nopython=args.nopython,
        )
    elif args.nothing:
        optimize_ctx = nothing
        experiment = speedup_experiment
        output_filename = "nothing.csv"
    elif args.backend or args.export_aot_inductor:
        if args.export_aot_inductor:
            assert not args.training, "AOTInductor only supports inference"
            optimize_ctx = functools.partial(
                export_aot_inductor, device=args.devices[0]
            )

            # AOTInductor doesn't support control flow yet
            runner.skip_models.update(runner.skip_models_due_to_control_flow)
        elif args.backend == "torchao":
            assert "cuda" in args.devices, "Quantization requires CUDA device."
            assert args.bfloat16, "Quantization requires dtype bfloat16."
            try:
                from torchao_backend import setup_baseline, torchao_optimize_ctx
            except ImportError:
                try:
                    from .torchao_backend import setup_baseline, torchao_optimize_ctx
                except ImportError:
                    from userbenchmark.dynamo.dynamobench.torchao_backend import (
                        setup_baseline,
                        torchao_optimize_ctx,
                    )

            setup_baseline()
            baseline_ctx = functools.partial(
                torch.compile,
                backend="inductor",
                fullgraph=args.nopython,
                mode=args.inductor_compile_mode,
            )
            model_iter_fn = baseline_ctx(runner.model_iter_fn)

            # needed to avoid error that causes inconsistent timing due to:
            # Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards
            def model_iter_fn_and_mark_step(*fn_args, **fn_kwargs):
                torch.compiler.cudagraph_mark_step_begin()
                # Propagate the result: callers of runner.model_iter_fn may
                # consume its outputs (e.g. when checking accuracy).
                return model_iter_fn(*fn_args, **fn_kwargs)

            runner.model_iter_fn = model_iter_fn_and_mark_step
            optimize_ctx = torchao_optimize_ctx(args.quantization)
        else:
            optimize_ctx = torch._dynamo.optimize(args.backend, nopython=args.nopython)
        experiment = (
            latency_experiment if args.backend == "torchao" else speedup_experiment
        )
        if args.accuracy:
            output_filename = f"accuracy_{args.backend}.csv"
        elif args.tolerance:
            output_filename = f"tolerance_{args.backend}.csv"
        else:
            output_filename = f"speedup_{args.backend}.csv"
    elif args.recompile_profiler:
        output_filename = "recompile_profiler_log.csv"
        experiment = recompile_profiler_experiment
    else:
        optimize_ctx = torch._dynamo.optimize(
            fx_insert_profiling, nopython=args.nopython
        )
        experiment = coverage_experiment
        output_filename = "coverage.csv"

    if args.inductor or args.backend == "inductor" or args.export_aot_inductor:
        inductor_config.triton.cudagraphs = not args.disable_cudagraphs
        inductor_config.triton.persistent_reductions = (
            not args.disable_persistent_reductions
        )
        inductor_config.split_reductions = not args.disable_split_reductions
        inductor_config.triton.divisible_by_16 = not args.disable_divisible_by_16
        if args.inference:
            inductor_config.freezing = args.freezing
        if args.inductor_config:
            for config in args.inductor_config:
                # Split on the first "=" only so values may themselves contain "=".
                key, value = config.split("=", 1)
                typ = type(inductor_config.__getattr__(key))
                if issubclass(typ, bool):
                    assert value in ("0", "1", "True", "False")
                    value = value in ("1", "True")
                elif issubclass(typ, (str, int, float)):
                    value = typ(value)
                else:
                    raise NotImplementedError(typ)
                inductor_config.__setattr__(key, value)

    runner.setup_amp()

    if args.output:
        output_filename = args.output

    if output_filename:
        if args.output_directory:
            output_filename = os.path.join(args.output_directory, output_filename)
        else:
            output_filename = os.path.join(
                torch._dynamo.config.base_dir, output_filename
            )

    if args.find_batch_sizes and args.only:
        for device in args.devices:
            batch_size = runner.batch_size_finder(device, args.only)
            print(args.only, batch_size)
            output_csv(output_filename, [], [args.only, batch_size])
        return

    if args.export_profiler_trace and args.profiler_trace_name is None:
        if args.backend:
            args.profiler_trace_name = args.backend
        elif args.inductor:
            args.profiler_trace_name = "inductor"
        else:
            args.profiler_trace_name = "profile"

    if args.no_translation_validation:
        # Overwrite 'translation_validation' config, if specified.
        torch.fx.experimental._config.translation_validation = False

    experiment = functools.partial(experiment, args, runner.model_iter_fn)

    if args.only and should_diff_branch(args):
        import git

        repo = git.Repo()
        main_branch = repo.active_branch.name
        try:
            # Adding diff-branch again to the args will override previous value
            call_args = (
                [sys.executable] + sys.argv + [f"--diff-branch={diff_branch_default}"]
            )
            # Run for main branch
            subprocess.check_call(call_args + [f"--tag={main_branch}"])
            # Run for comparison branch
            repo.git.checkout(args.diff_branch)
            subprocess.check_call(call_args + [f"--tag={args.diff_branch}"])
        finally:
            # Go back to main branch
            repo.git.checkout(main_branch)
    elif args.only:
        model_name = args.only
        for device in args.devices:
            batch_size = args.batch_size
            if args.batch_size_file:
                batch_size = read_batch_size_from_file(
                    args, args.batch_size_file, model_name
                )
            if model_specified_by_path(args.only):
                model, example_inputs = load_model_from_path(args.only)
                name = model.__class__.__name__
                model = model.to(device=device)
                example_inputs = tree_map_only(
                    torch.Tensor, lambda x: x.to(device=device), example_inputs
                )
            else:
                name = model_name
                try:
                    with tqdm(desc="loading model"):
                        extra_args = []
                        if hasattr(args, "rank") and hasattr(args, "world_size"):
                            extra_args += [
                                "--rank",
                                str(args.rank),
                                "--world_size",
                                str(args.world_size),
                            ]

                        if args.part:
                            (
                                device,
                                name,
                                model,
                                example_inputs,
                                batch_size,
                            ) = runner.load_model(
                                device,
                                model_name,
                                batch_size=batch_size,
                                part=args.part,
                                extra_args=extra_args,
                            )
                        elif args.fsdp:
                            # Always load model on cpu for fsdp
                            # When initializing FSDP, we will use the cuda device if args.cuda is set
                            (
                                _,
                                name,
                                model,
                                example_inputs,
                                batch_size,
                            ) = runner.load_model(
                                "cpu",
                                model_name,
                                batch_size=batch_size,
                                extra_args=extra_args,
                            )
                        else:
                            (
                                device,
                                name,
                                model,
                                example_inputs,
                                batch_size,
                            ) = runner.load_model(
                                device,
                                model_name,
                                batch_size=batch_size,
                                extra_args=extra_args,
                            )
                except Exception as e:
                    import traceback

                    mode = "train" if args.training else "eval"
                    print(f"{device:4} {mode:5} {name:34} ")
                    print(traceback.format_exc())
                    status = (
                        "model_fail_to_load"
                        if isinstance(e, NotImplementedError)
                        else "eager_fail_to_run"
                    )
                    write_csv_when_exception(args, name, status, device)
                    continue  # bad benchmark implementation

            if args.trace_on_xla:
                xla_dev = xm.xla_device()
                model = model.to(device=xla_dev)
                example_inputs = tree_map_only(
                    torch.Tensor, lambda x: x.to(device=xla_dev), example_inputs
                )

            current_name = name
            current_device = device
            current_batch_size = batch_size
            set_model_name(name)

            # Look for stuff that looks like batch size, and mark it dynamic.
            # Better integration would integrate directly with benchmark suite
            # but cannot conveniently do this
            # NB: This must be done late enough so that we don't do more
            # conversions on the inputs
            # NB: Assumes only the first batch-y like dimension is the batch
            marked = False

            def detect_and_mark_batch(t):
                nonlocal marked
                for i, s in enumerate(t.size()):
                    if s == batch_size:
                        torch._dynamo.mark_dynamic(t, i)
                        marked = True
                        break

            if (
                args.dynamic_batch_only
                and batch_size > 1
                and model_name not in CI_SKIP_DYNAMIC_BATCH_ONLY
            ):
                tree_map_only(torch.Tensor, detect_and_mark_batch, example_inputs)
                assert marked, f"nothing in example_inputs had a dim with {batch_size}"

            if args.log_operator_inputs:
                log_operator_inputs(
                    model, example_inputs, runner.model_iter_fn, name, args
                )
                continue

            if args.per_process_memory_fraction != 1:
                torch.cuda.set_per_process_memory_fraction(
                    args.per_process_memory_fraction
                )
            if model_name in DO_NOT_CAST_INPUTS:
                model, _ = runner.cast_based_on_args(model, example_inputs)
            else:
                model, example_inputs = runner.cast_based_on_args(model, example_inputs)
            runner.setup_amp(current_device)
            guard_ctx = contextlib.nullcontext()
            if name in runner.guard_on_nn_module_models:
                guard_ctx = torch._dynamo.config.patch(guard_nn_modules=True)

            inline_ctx = contextlib.nullcontext()
            if name in runner.inline_inbuilt_nn_modules_models:
                inline_ctx = torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)

            with guard_ctx, inline_ctx:
                runner.run_one_model(
                    name,
                    model,
                    example_inputs,
                    optimize_ctx,
                    experiment,
                    explain=args.explain,
                    tag=args.tag,
                )
        if args.generate_aot_autograd_stats:
            stats_file = output_filename.split(".csv")[0] + "_stats.csv"
            output_csv(
                stats_file,
                ("dev", "name", "batch_size", "total_aot_graphs", "ok_aot_graphs"),
                [
                    current_device,
                    current_name,
                    current_batch_size,
                    *Stats.aot_summary(),
                ],
            )
    else:
        metrics.purge_old_log_files()
        if output_filename and os.path.exists(output_filename):
            os.unlink(output_filename)
        if original_dir:
            os.chdir(original_dir)
        model_names = list(runner.iter_model_names(args))
        nmodels = len(model_names)
        for i, name in enumerate(model_names):
            current_name = name
            if args.progress:
                print(f"Running model {i + 1}/{nmodels}", flush=True)

            try:
                timeout = args.timeout
                if should_diff_branch(args):
                    timeout *= 2
                env = os.environ.copy()
                if args.ci and name in CI_PRESERVE_COMPILE_DEBUG:
                    env["TORCH_COMPILE_DEBUG"] = "1"
                subprocess.check_call(
                    [sys.executable] + sys.argv + [f"--only={name}"],
                    timeout=timeout,
                    env=env,
                )
            except subprocess.TimeoutExpired:
                write_csv_when_exception(args, name, "timeout")
            except subprocess.CalledProcessError as e:
                print("Run failed with return code: ", e.returncode, file=sys.stderr)
                print("Output: ", e.output, file=sys.stderr)
                print("Error: ", e.stderr, file=sys.stderr)
        print_summary(output_filename, print_dataframe=args.print_dataframe_summary)
4741

4742

4743
def log_operator_inputs(model, example_inputs, model_iter_fn, name, args):
    """Run ``model`` under ``OperatorInputsMode`` and log the recorded inputs to a file.

    The destination is ``<dirname(args.output)>/<name>_<mode>.txt``, where
    ``mode`` is ``"training"`` or ``"eval"`` depending on ``args.training``.
    When ``args.output`` is not set, the file goes to the current working
    directory. An existing destination file is never overwritten; the model is
    skipped instead.

    The model is first run on fake-tensor deep copies of ``model`` and
    ``example_inputs``. If that raises, a fresh ``OperatorInputsMode`` is used
    to run the real model and inputs. If the real run also fails, its
    exception is re-raised.

    Args:
        model: the module to run.
        example_inputs: inputs passed to ``model_iter_fn`` along with ``model``.
        model_iter_fn: callable invoked as
            ``model_iter_fn(model, inputs, collect_outputs=False)``.
        name: model name, used in the output file name and log messages.
        args: parsed CLI namespace; ``training`` and ``output`` are read.
    """
    mode = "training" if args.training else "eval"
    # --output may be omitted (None); os.path.dirname(None) raises TypeError,
    # so fall back to "" which makes the path relative to the cwd.
    output_dir = os.path.dirname(args.output or "")
    output = os.path.join(output_dir, f"{name}_{mode}.txt")

    # TODO - add option for coalescing inputs over multiple runs
    if os.path.exists(output):
        print(f"Skipping {name}, {output} already exists")
        return

    print(f"Running {name}")
    try:
        from .microbenchmarks.operator_inp_utils import OperatorInputsMode
    except ImportError:
        from microbenchmarks.operator_inp_utils import OperatorInputsMode

    operator_mode = OperatorInputsMode()
    fake_tensor_mode = FakeTensorMode()

    # Deep-copy the model and inputs into fake tensors so the first attempt
    # does not touch the real tensors.
    with torch._subclasses.fake_tensor.FakeCopyMode(fake_tensor_mode):
        model_fake = copy.deepcopy(model)
        example_inputs_fake = copy.deepcopy(example_inputs)
    try:
        with fake_tensor_mode, operator_mode:
            model_iter_fn(model_fake, example_inputs_fake, collect_outputs=False)
    except Exception as e:
        print(f"{name} failed to run with fake tensors, trying real. Exception: {e}")
        # Start from a fresh recorder so nothing from the failed fake run is kept.
        operator_mode = OperatorInputsMode()
        try:
            with operator_mode:
                model_iter_fn(model, example_inputs, collect_outputs=False)
        except Exception as e2:
            print(f"{name} failed to run with real. Exception: {e2}")
            raise

    print(f"Writing output to {output}")
    operator_mode.log_to_file(output)
4779

4780

4781
if __name__ == "__main__":
    # This file is not a runnable entry point on its own; the error points the
    # user at the per-suite driver scripts instead.
    _invoked_as = sys.argv[0]
    raise RuntimeError(
        f"You shouldn't run {_invoked_as} directly, "
        "instead try timm_model.py, torchbench.py or huggingface.py"
    )
4785

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.