pytorch

Форк
0
/
run_test.py 
1888 строк · 66.1 Кб
1
#!/usr/bin/env python3
2

3
import argparse
4
import copy
5
import glob
6
import json
7
import os
8
import re
9
import shutil
10
import signal
11
import subprocess
12
import sys
13
import tempfile
14
import time
15
from collections import defaultdict
16
from contextlib import ExitStack
17
from datetime import datetime
18
from pathlib import Path
19
from typing import Any, cast, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
20

21
import pkg_resources
22

23
import torch
24
import torch.distributed as dist
25
from torch.multiprocessing import current_process, get_context
26
from torch.testing._internal.common_utils import (
27
    get_report_path,
28
    IS_CI,
29
    IS_MACOS,
30
    IS_WINDOWS,
31
    retry_shell,
32
    set_cwd,
33
    shell,
34
    TEST_CUDA,
35
    TEST_WITH_ASAN,
36
    TEST_WITH_CROSSREF,
37
    TEST_WITH_ROCM,
38
    TEST_WITH_SLOW_GRADCHECK,
39
)
40

41

42
# using tools/ to optimize test run.
43
REPO_ROOT = Path(__file__).resolve().parent.parent
44
sys.path.insert(0, str(REPO_ROOT))
45

46
from tools.stats.import_test_stats import (
47
    ADDITIONAL_CI_FILES_FOLDER,
48
    TEST_CLASS_TIMES_FILE,
49
    TEST_TIMES_FILE,
50
)
51
from tools.stats.upload_metrics import add_global_metric, emit_metric
52
from tools.testing.discover_tests import (
53
    CPP_TEST_PATH,
54
    CPP_TEST_PREFIX,
55
    CPP_TESTS_DIR,
56
    parse_test_module,
57
    TESTS,
58
)
59
from tools.testing.do_target_determination_for_s3 import import_results
60
from tools.testing.target_determination.gen_artifact import gen_ci_artifact
61
from tools.testing.target_determination.heuristics.previously_failed_in_pr import (
62
    gen_additional_test_failures_file,
63
)
64
from tools.testing.target_determination.heuristics.utils import get_pr_number
65
from tools.testing.test_run import TestRun
66
from tools.testing.test_selections import (
67
    calculate_shards,
68
    get_test_case_configs,
69
    NUM_PROCS,
70
    ShardedTest,
71
    THRESHOLD,
72
)
73

74

75
# Make sure to remove REPO_ROOT after import is done
76
sys.path.remove(str(REPO_ROOT))
77

78

79
# Run-configuration flags read from the environment once, at import time.
HAVE_TEST_SELECTION_TOOLS = True
TEST_CONFIG = os.environ.get("TEST_CONFIG", "")
BUILD_ENVIRONMENT = os.environ.get("BUILD_ENVIRONMENT", "")
# Only the exact value "1" enables rerunning disabled tests.
RERUN_DISABLED_TESTS = os.environ.get("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
# Test-name prefixes used to classify test files.
DISTRIBUTED_TEST_PREFIX = "distributed"
INDUCTOR_TEST_PREFIX = "inductor"
# A "slow" run is signalled by either the test config or the build environment name.
IS_SLOW = any("slow" in setting for setting in (TEST_CONFIG, BUILD_ENVIRONMENT))
86

87

88
# Note [ROCm parallel CI testing]
89
# https://github.com/pytorch/pytorch/pull/85770 added file-granularity parallel testing.
90
# In .ci/pytorch/test.sh, when TEST_CONFIG == "default", CUDA_VISIBLE_DEVICES and HIP_VISIBLE_DEVICES are set to 0.
91
# This results in multiple test files sharing the same GPU.
92
# This should be a supported use case for ROCm, but it exposed issues in the kernel driver resulting in hangs.
93
# See https://github.com/pytorch/pytorch/issues/90940.
94
#
95
# Further, ROCm self-hosted runners have up to 4 GPUs.
96
# Device visibility was set to 0 to match CUDA test behavior, but this was wasting available GPU resources.
97
# Assigning each Pool worker their own dedicated GPU avoids the ROCm oversubscription issues.
98
# This should also result in better overall wall clock time since all GPUs can be utilized.
99
def maybe_set_hip_visible_devies():
    """Pin a ROCm parallel-test worker process to a single GPU.

    Does nothing unless this is a ROCm build (``torch.version.hip`` is set)
    and the current process is not the main process, i.e. it is a worker of
    the parallel test Pool. In that case ``HIP_VISIBLE_DEVICES`` in
    ``os.environ`` is set to the worker's pool identity modulo ``NUM_PROCS``
    (see Note [ROCm parallel CI testing]).
    """
    if not torch.version.hip:
        return
    proc = current_process()
    if proc.name == "MainProcess":
        return
    # ``_identity`` is multiprocessing's private per-worker ordinal tuple.
    os.environ["HIP_VISIBLE_DEVICES"] = str(proc._identity[0] % NUM_PROCS)
106

107

108
def strtobool(s):
    """Interpret the string ``s`` as a boolean flag.

    "", "0", "false" and "off" (compared case-insensitively) are False;
    every other string, including e.g. "no", is True.
    """
    falsy_values = frozenset(("", "0", "false", "off"))
    return not (s.lower() in falsy_values)
110

111

112
class TestChoices(list):
    """A list of test names whose ``in`` check normalizes the queried item.

    ``item in choices`` looks up ``parse_test_module(item)`` rather than
    ``item`` itself. The constructor keeps only its first positional argument
    (the iterable of names); any other arguments are ignored.
    """

    def __init__(self, *args, **kwargs):
        names = args[0]
        super().__init__(names)

    def __contains__(self, item):
        normalized = parse_test_module(item)
        return super().__contains__(normalized)
118

119

120
# Every discovered FSDP test file.
FSDP_TEST = [test for test in TESTS if test.startswith("distributed/fsdp")]

# Test files blocklisted on Windows; all FSDP test files are appended.
# NB: every entry needs its own trailing comma. Adjacent string literals are
# concatenated implicitly, which would silently fuse two paths into a single
# name that matches no test file.
WINDOWS_BLOCKLIST = [
    "distributed/nn/jit/test_instantiator",
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/test_share_memory",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/pipeline/sync/skip/test_api",
    "distributed/pipeline/sync/skip/test_gpipe",
    "distributed/pipeline/sync/skip/test_inspect_skip_layout",
    "distributed/pipeline/sync/skip/test_leak",
    "distributed/pipeline/sync/skip/test_portal",
    "distributed/pipeline/sync/skip/test_stash_pop",
    "distributed/pipeline/sync/skip/test_tracker",
    "distributed/pipeline/sync/skip/test_verify_skippables",
    "distributed/pipeline/sync/test_balance",
    "distributed/pipeline/sync/test_bugs",
    "distributed/pipeline/sync/test_checkpoint",
    "distributed/pipeline/sync/test_copy",
    "distributed/pipeline/sync/test_deferred_batch_norm",
    "distributed/pipeline/sync/test_dependency",
    "distributed/pipeline/sync/test_inplace",
    "distributed/pipeline/sync/test_microbatch",
    "distributed/pipeline/sync/test_phony",
    "distributed/pipeline/sync/test_pipe",
    "distributed/pipeline/sync/test_pipeline",
    "distributed/pipeline/sync/test_stream",
    "distributed/pipeline/sync/test_transparency",
    "distributed/pipeline/sync/test_worker",
    "distributed/elastic/agent/server/test/api_test",
    "distributed/elastic/multiprocessing/api_test",
    "distributed/_shard/checkpoint/test_checkpoint",
    "distributed/_shard/checkpoint/test_file_system_checkpoint",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharding_plan/test_sharding_plan",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
    "distributed/_shard/sharded_tensor/ops/test_init",
    "distributed/_shard/sharded_optim/test_sharded_optim",
] + FSDP_TEST
164

165
# Test files blocklisted on ROCm.
# NB: every entry needs its own trailing comma. Adjacent string literals are
# concatenated implicitly, which would silently fuse two paths into a single
# name that matches no test file.
ROCM_BLOCKLIST = [
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/test_share_memory",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/_shard/checkpoint/test_checkpoint",
    "distributed/_shard/checkpoint/test_file_system_checkpoint",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharding_plan/test_sharding_plan",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
    "distributed/_shard/sharded_tensor/ops/test_init",
    "distributed/_shard/sharded_optim/test_sharded_optim",
    "test_determination",
    "test_jit_legacy",
    "test_cuda_nvml_based_avail",
    "test_jit_cuda_fuser",
    "distributed/_tensor/test_attention",
    "test_transformers",
]
188

189
# Test files blocklisted for XPU: autograd plus a set of profiler test files.
XPU_BLOCKLIST = ["test_autograd"] + [
    f"profiler/{name}"
    for name in (
        "test_cpp_thread",
        "test_execution_trace",
        "test_memory_profiler",
        "test_profiler",
        "test_profiler_tree",
        "test_record_function",
        "test_torch_tidy",
    )
]

XPU_TEST = ["test_xpu"]
203

204
# Test files whose individual tests must never run in parallel with each
# other; all FSDP test files are included as well.
RUN_PARALLEL_BLOCKLIST = [
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_cpp_extensions_stream_and_event",
    "test_cpp_extensions_mtia_backend",
    "test_jit_disabled",
    "test_mobile_optimizer",
    "test_multiprocessing",
    "test_multiprocessing_spawn",
    "test_namedtuple_return_api",
    "test_overrides",
    "test_show_pickle",
    "test_tensorexpr",
    "test_cuda_primary_ctx",
    "test_cuda_trace",
    "inductor/test_benchmark_fusion",
    "test_cuda_nvml_based_avail",
    # Temporarily sets a global config.
    "test_autograd_fallback",
    *FSDP_TEST,
]
225

226
# Test files that must run serially with respect to other test files; the
# tests inside each file may still run in parallel with each other.
CI_SERIAL_LIST = [
    "test_nn",
    "test_fake_tensor",
    "test_cpp_api_parity",
    "test_reductions",
    "test_fx_backends",
    "test_cpp_extensions_jit",
    "test_torch",
    "test_tensor_creation_ops",
    "test_dispatch",
    # torch.library creation and deletion must be serialized.
    "test_python_dispatch",
    # Causes a CUDA illegal memory access:
    # https://github.com/pytorch/pytorch/issues/88916
    "test_spectral_ops",
    "nn/test_pooling",
    # Doesn't respect set_per_process_memory_fraction, which results in OOM
    # for other tests under slow gradcheck.
    "nn/test_convolution",
    "distributions/test_distributions",
    # Gets SIGKILL.
    "test_fx",
    # Causes CUDA OOM on ROCm.
    "functorch/test_memory_efficient_fusion",
    # OOM:
    "test_utils",
    "test_sort_and_select",
    "test_backward_compatible_arguments",
    "test_autocast",
    "test_native_mha",
    "test_module_hooks",
    "inductor/test_max_autotune",
    # Slow due to many nvcc compilation steps.
    "inductor/test_cutlass_backend",
    # OOM.
    "inductor/test_flex_attention",
]

# The ONNX test files that cannot run in parallel due to high memory usage.
ONNX_SERIAL_LIST = [
    f"onnx/{name}"
    for name in (
        "test_models",
        "test_models_quantized_onnxruntime",
        "test_models_onnxruntime",
        "test_custom_ops",
        "test_utility_funs",
    )
]
263

264
# The subset of test files that validates that PyTorch's ops, modules and
# autograd behave as expected.
CORE_TEST_LIST = [
    "test_autograd",
    "test_autograd_fallback",
    "test_modules",
    "test_nn",
    "test_ops",
    "test_ops_gradients",
    "test_ops_fwd_gradients",
    "test_ops_jit",
    "test_torch",
]


# In seconds (5 minutes): a test file slower than this is added to TARGET_DET_LIST.
SLOW_TEST_THRESHOLD = 5 * 60
280

281
# Per-backend extra environment variables for the distributed test runs.
# Stays empty when torch.distributed is not available.
DISTRIBUTED_TESTS_CONFIG = {}


def _multi_process_world_size():
    """Return "2" when exactly two CUDA devices are visible, otherwise "3"."""
    return "2" if torch.cuda.device_count() == 2 else "3"


if dist.is_available():
    DISTRIBUTED_TESTS_CONFIG["test"] = {"WORLD_SIZE": "1"}
    if not TEST_WITH_ROCM and dist.is_mpi_available():
        DISTRIBUTED_TESTS_CONFIG["mpi"] = dict(
            WORLD_SIZE="3",
            TEST_REPORT_SOURCE_OVERRIDE="dist-mpi",
        )
    if dist.is_nccl_available():
        DISTRIBUTED_TESTS_CONFIG["nccl"] = dict(
            WORLD_SIZE=_multi_process_world_size(),
            TEST_REPORT_SOURCE_OVERRIDE="dist-nccl",
        )
    if dist.is_gloo_available():
        DISTRIBUTED_TESTS_CONFIG["gloo"] = dict(
            WORLD_SIZE=_multi_process_world_size(),
            TEST_REPORT_SOURCE_OVERRIDE="dist-gloo",
        )
    if dist.is_ucc_available():
        DISTRIBUTED_TESTS_CONFIG["ucc"] = dict(
            WORLD_SIZE=_multi_process_world_size(),
            TEST_REPORT_SOURCE_OVERRIDE="dist-ucc",
            UCX_TLS="tcp,cuda",
            UCC_TLS="nccl,ucp,cuda",
            # Don't use the UCP TL on CUDA; it is not well supported.
            UCC_TL_UCP_TUNE="cuda:0",
            # CI nodes (M60) fail when cooperative launch is on.
            UCC_EC_CUDA_USE_COOPERATIVE_LAUNCH="n",
        )
310

311
# Map signal numbers to their SIG* names, skipping helper constants that
# contain "_" (e.g. SIG_DFL). When several names share a number, the name
# that sorts last in dir(signal) wins.
# https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
SIGNALS_TO_NAMES_DICT = dict(
    (getattr(signal, signame), signame)
    for signame in dir(signal)
    if signame.startswith("SIG") and "_" not in signame
)
315

316
# Printed when the ninja-based C++ extension tests are requested but ninja is
# missing. (The multi-line literal previously carried stray line-number lines
# that would have been printed verbatim.)
CPP_EXTENSIONS_ERROR = """
Ninja (https://ninja-build.org) is required for some of the C++ extensions
tests, but it could not be found. Install ninja with `pip install ninja`
or `conda install ninja`. Alternatively, disable said tests with
`run_test.py --exclude test_cpp_extensions_aot_ninja test_cpp_extensions_jit`.
"""

# Any non-empty value of PYTORCH_COLLECT_COVERAGE enables coverage collection.
PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE"))
324

325
JIT_EXECUTOR_TESTS = [
    "test_jit_profiling",
    "test_jit_legacy",
    "test_jit_fuser_legacy",
]


def _tests_with_prefix(prefix):
    """Return the entries of TESTS that start with ``prefix``, in TESTS order."""
    return [name for name in TESTS if name.startswith(prefix)]


# Test files grouped by their path prefix.
INDUCTOR_TESTS = _tests_with_prefix(INDUCTOR_TEST_PREFIX)
DISTRIBUTED_TESTS = _tests_with_prefix(DISTRIBUTED_TEST_PREFIX)
TORCH_EXPORT_TESTS = _tests_with_prefix("export")
FUNCTORCH_TESTS = _tests_with_prefix("functorch")
ONNX_TESTS = _tests_with_prefix("onnx")
CPP_TESTS = _tests_with_prefix(CPP_TEST_PREFIX)
337

338
TESTS_REQUIRING_LAPACK = [
    f"distributions/{name}" for name in ("test_constraints", "test_distributions")
]

# Only the slowest such files are listed; this is not exhaustive.
# To skip only some tests of a file (e.g. test_mps) under slow gradcheck, use
# skipIfSlowGradcheckEnv instead of adding the whole file here.
TESTS_NOT_USING_GRADCHECK = [
    "doctests",
    "test_meta",
    "test_hub",
    "test_fx",
    "test_decomp",
    "test_cpp_extensions_jit",
    "test_jit",
    "test_ops",
    "test_ops_jit",
    "dynamo/test_recompile_ux",
    "inductor/test_smoke",
    "test_quantization",
]
360

361

362
def print_to_stderr(message):
    """Print ``message`` (formatted as ``print`` would) to the current sys.stderr."""
    stream = sys.stderr
    print(message, file=stream)
364

365

366
def get_executable_command(options, disable_coverage=False, is_cpp_test=False):
    """Return the command prefix used to launch a test file.

    With coverage enabled (``options.coverage`` and not ``disable_coverage``),
    Python tests run under ``coverage run`` and C++ tests get an empty list
    (unsupported). Otherwise Python tests run under this interpreter with
    ``-bb`` and C++ tests run through ``pytest``.
    """
    if not (options.coverage and not disable_coverage):
        return ["pytest"] if is_cpp_test else [sys.executable, "-bb"]
    # TODO: C++ with coverage is not yet supported
    return [] if is_cpp_test else ["coverage", "run", "--parallel-mode", "--source=torch"]
380

381

382
def run_test(
    test_module: ShardedTest,
    test_directory,
    options,
    launcher_cmd=None,
    extra_unittest_args=None,
    env=None,
    print_log=True,
) -> int:
    """Run one (possibly sharded) test file in a subprocess.

    Builds the command line for ``test_module`` from ``options`` (sharding,
    verbosity, pytest flags, CI-only flags) and runs it from
    ``test_directory``. The run goes through ``run_test_retries`` when
    retrying is allowed, otherwise once through ``retry_shell``.

    Args:
        test_module: the test file plus its shard information.
        test_directory: working directory of the subprocess; also used to
            locate C++ test binaries when ``CPP_TESTS_DIR`` is empty.
        options: parsed command-line options of this script.
        launcher_cmd: optional list placed in front of the executable.
        extra_unittest_args: optional list of extra arguments for the test.
        env: environment for the subprocess; defaults to a copy of
            ``os.environ`` taken when the test starts.
        print_log: when logs are piped to a file, whether to pass that file
            to ``handle_log_file`` afterwards.

    Returns:
        The subprocess exit code, with pytest's "no tests collected" (5)
        mapped to 0. Returns 0 without running anything for C++ tests in
        rerun-disabled-tests mode, and when no executable applies (e.g. C++
        tests under coverage).
    """
    scribe_token = os.getenv("SCRIBE_GRAPHQL_ACCESS_TOKEN", "")
    if scribe_token:
        print_to_stderr("SCRIBE_GRAPHQL_ACCESS_TOKEN is set")
    else:
        print_to_stderr("SCRIBE_GRAPHQL_ACCESS_TOKEN is NOT set")

    # Select this worker's GPU *before* snapshotting os.environ, so the copy
    # handed to the test subprocess already carries HIP_VISIBLE_DEVICES.
    # Copying first would leave the first test in every ROCm pool worker on
    # the inherited device setting.
    maybe_set_hip_visible_devies()
    env = env or os.environ.copy()
    unittest_args = options.additional_args.copy()
    test_file = test_module.name

    is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX)
    is_cpp_test = test_file.startswith(CPP_TEST_PREFIX)
    # NB: Rerun disabled tests depends on pytest-flakefinder and it doesn't work with
    # pytest-cpp atm. We also don't have support to disable C++ test yet, so it's ok
    # to just return successfully here
    if is_cpp_test and RERUN_DISABLED_TESTS:
        print_to_stderr(
            "Skipping C++ tests when running under RERUN_DISABLED_TESTS mode"
        )
        return 0

    # The random suffix makes the stepcurrent cache key unique per invocation.
    if is_cpp_test:
        stepcurrent_key = f"{test_file}_{os.urandom(8).hex()}"
    else:
        unittest_args.extend(
            [
                f"--shard-id={test_module.shard}",
                f"--num-shards={test_module.num_shards}",
            ]
        )
        stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}"

    if options.verbose:
        unittest_args.append(f'-{"v" * options.verbose}')  # in case of pytest

    if test_file in RUN_PARALLEL_BLOCKLIST:
        unittest_args = [
            arg for arg in unittest_args if not arg.startswith("--run-parallel")
        ]

    if extra_unittest_args:
        assert isinstance(extra_unittest_args, list)
        unittest_args.extend(extra_unittest_args)

    # If using pytest, replace -f with equivalent -x
    if options.pytest:
        unittest_args.extend(
            get_pytest_args(
                options,
                is_cpp_test=is_cpp_test,
                is_distributed_test=is_distributed_test,
            )
        )
        unittest_args.extend(test_module.get_pytest_args())
        unittest_args = ["-x" if arg == "-f" else arg for arg in unittest_args]

    if options.showlocals:
        if options.pytest:
            unittest_args.extend(["--showlocals", "--tb=long", "--color=yes"])
        else:
            unittest_args.append("--locals")

    # NB: These features are not available for C++ tests, but there is little incentive
    # to implement it because we have never seen a flaky C++ test before.
    if IS_CI and not is_cpp_test:
        ci_args = ["--import-slow-tests", "--import-disabled-tests"]
        if RERUN_DISABLED_TESTS:
            ci_args.append("--rerun-disabled-tests")
        unittest_args.extend(ci_args)

    if test_file in PYTEST_SKIP_RETRIES:
        if not options.pytest:
            raise RuntimeError(
                "A test running without pytest cannot skip retries using "
                "the PYTEST_SKIP_RETRIES set."
            )
        unittest_args = [arg for arg in unittest_args if "--reruns" not in arg]

    executable = get_executable_command(options, is_cpp_test=is_cpp_test)
    if not executable:
        # If there is no eligible executable returning here, it means an unsupported
        # case such as coverage for C++ test. So just returning ok makes sense
        return 0

    if is_cpp_test:
        # C++ tests are not in the regular test directory
        binary_name = test_file.replace(f"{CPP_TEST_PREFIX}/", "")
        if CPP_TESTS_DIR:
            cpp_test = os.path.join(CPP_TESTS_DIR, binary_name)
        else:
            cpp_test = os.path.join(
                Path(test_directory).parent, CPP_TEST_PATH, binary_name
            )
        if sys.platform == "win32":
            cpp_test += ".exe"
        argv = [cpp_test] + unittest_args
    else:
        # Can't call `python -m unittest test_*` here because it doesn't run code
        # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
        argv = [test_file + ".py"] + unittest_args

    os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
    if options.pipe_logs:
        log_fd, log_path = tempfile.mkstemp(
            dir=REPO_ROOT / "test" / "test-reports",
            prefix=f"{sanitize_file_name(str(test_module))}_",
            suffix="_toprint.log",
        )
        os.close(log_fd)

    command = (launcher_cmd or []) + executable + argv
    should_retry = (
        "--subprocess" not in command
        and not RERUN_DISABLED_TESTS
        and not is_cpp_test
        and "-n" not in command
    )

    # Timeouts only apply with --enable-timeout: 6x THRESHOLD for slow runs,
    # 3x for retryable tests with known timing data and for C++ tests.
    if not options.enable_timeout:
        timeout = None
    elif IS_SLOW:
        timeout = THRESHOLD * 6
    elif (
        should_retry
        and isinstance(test_module, ShardedTest)
        and test_module.time is not None
    ):
        timeout = THRESHOLD * 3
    elif is_cpp_test:
        timeout = THRESHOLD * 3
    else:
        timeout = None
    print_to_stderr(f"Executing {command} ... [{datetime.now()}]")

    with ExitStack() as stack:
        output = None
        if options.pipe_logs:
            output = stack.enter_context(open(log_path, "w"))

        if should_retry:
            ret_code, was_rerun = run_test_retries(
                command,
                test_directory,
                env,
                timeout,
                stepcurrent_key,
                output,
                options.continue_through_error,
            )
        else:
            command.extend([f"--sc={stepcurrent_key}", "--print-items"])
            ret_code, was_rerun = retry_shell(
                command,
                test_directory,
                stdout=output,
                stderr=output,
                env=env,
                timeout=timeout,
                retries=0,
            )

            # Pytest return code 5 means no test is collected. Exit code 4 is
            # returned when the binary is not a C++ test executable, but 4 can
            # also be returned if the file fails before running any tests. All
            # binary files under build/bin that are not C++ test at the time of
            # this writing have been excluded and new ones should be added to
            # the list of exclusions in tools/testing/discover_tests.py
            ret_code = 0 if ret_code == 5 else ret_code

    if options.pipe_logs and print_log:
        handle_log_file(
            test_module, log_path, failed=(ret_code != 0), was_rerun=was_rerun
        )
    return ret_code
573

574

575
def try_set_cpp_stack_traces(env, command, set=True):
    """Return ``env`` with TORCH_SHOW_CPP_STACKTRACES set to "1" or "0".

    Full C++ stack traces are enabled while a failing test is retried.

    Args:
        env: environment mapping for the test subprocess; it is updated in
            place and returned. ``None`` means "inherit the current
            environment", so a copy of ``os.environ`` is used. (Substituting
            an empty dict would strip PATH and every other variable from the
            subprocess.)
        command: unused; kept for call compatibility.
        set: True for "1", False for "0". (The name shadows the builtin but
            is part of the keyword interface callers use.)

    Returns:
        The updated environment mapping.
    """
    if env is None:
        env = os.environ.copy()
    env["TORCH_SHOW_CPP_STACKTRACES"] = "1" if set else "0"
    return env
580

581

582
def run_test_retries(
    command,
    test_directory,
    env,
    timeout,
    stepcurrent_key,
    output,
    continue_through_error,
):
    """Run ``command`` repeatedly, re-running failing tests in fresh processes.

    The test file is first run with ``--sc=<key>`` (plus ``--print-items``).
    Whenever it exits non-zero, the test recorded in the pytest stepcurrent
    cache entry ``stepcurrent_key`` is rerun by itself (``--rs=<key>``). If
    that rerun passes, the rest of the tests continue in a new process with
    the passing test skipped (``--scs=<key>``). Once the same test has
    failed 3 times, the loop stops, unless ``continue_through_error`` is set.
    In that case the test is skipped and the loop keeps going, while the run
    still reports failure at the end.

    Args:
        command: base command line (list) for the test file.
        test_directory: working directory of the subprocess.
        env: environment mapping for the subprocess.
        timeout: per-attempt timeout passed to ``retry_shell``.
        stepcurrent_key: name of this run's stepcurrent cache entry.
        output: file object receiving subprocess output and the messages
            printed here, or None for stdout.
        continue_through_error: keep going past consistently failing tests.

    Returns:
        ``(exit_code, was_rerun)``. ``(1, True)`` if any test failed 3 times.
        Otherwise, the last exit code (5 mapped to 0) and whether any test
        failed at least once.
    """

    def print_to_file(s):
        print(s, file=output, flush=True)

    num_failures = defaultdict(int)
    stepcurrent_file = REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key

    print_items = ["--print-items"]
    sc_command = f"--sc={stepcurrent_key}"
    while True:
        ret_code, _ = retry_shell(
            command + [sc_command] + print_items,
            test_directory,
            stdout=output,
            stderr=output,
            env=env,
            timeout=timeout,
            retries=0,  # no retries here, we do it ourselves, this is because it handles timeout exceptions well
        )
        ret_code = 0 if ret_code == 5 else ret_code
        if ret_code == 0 and not sc_command.startswith("--rs="):
            break  # Got to the end of the test suite successfully
        signal_name = ""
        if ret_code < 0:
            # Not every signal number has a SIG* name in SIGNALS_TO_NAMES_DICT;
            # fall back rather than raising KeyError and aborting the retries.
            signal_name = f" ({SIGNALS_TO_NAMES_DICT.get(-ret_code, f'signal {-ret_code}')})"
        print_to_file(f"Got exit code {ret_code}{signal_name}")

        # Read what just failed/ran
        try:
            with open(stepcurrent_file) as f:
                current_failure = f.read()
        except FileNotFoundError:
            print_to_file(
                "No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
                + " or file got deleted (contact dev infra)"
            )
            break

        env = try_set_cpp_stack_traces(env, command, set=False)
        if ret_code != 0:
            num_failures[current_failure] += 1

        if ret_code == 0:
            # Rerunning the previously failing test succeeded, so now we can
            # skip it and move on
            sc_command = f"--scs={stepcurrent_key}"
            print_to_file(
                "Test succeeeded in new process, continuing with the rest of the tests"
            )
        elif num_failures[current_failure] >= 3:
            if not continue_through_error:
                print_to_file("Stopping at first consistent failure")
                break
            sc_command = f"--scs={stepcurrent_key}"
            print_to_file(
                "Test failed consistently, "
                "continuing with the rest of the tests due to continue-through-error being set"
            )
        else:
            env = try_set_cpp_stack_traces(env, command, set=True)
            sc_command = f"--rs={stepcurrent_key}"
            print_to_file("Retrying single test...")
        print_items = []  # do not continue printing them, massive waste of space

    # Drop the first and last character of each stepcurrent value (presumably
    # the quotes of a JSON-encoded string — confirm) before reporting.
    consistent_failures = [x[1:-1] for x, n in num_failures.items() if n >= 3]
    flaky_failures = [x[1:-1] for x, n in num_failures.items() if 0 < n < 3]
    if flaky_failures:
        print_to_file(
            "The following tests failed and then succeeded when run in a new process: "
            + f"{flaky_failures}",
        )
    if consistent_failures:
        print_to_file(f"The following tests failed consistently: {consistent_failures}")
        return 1, True
    return ret_code, any(x > 0 for x in num_failures.values())
676

677

678
def run_test_with_subprocess(test_module, test_directory, options):
    """Run ``test_module`` through ``run_test`` with ``--subprocess`` appended."""
    subprocess_flag = ["--subprocess"]
    return run_test(
        test_module,
        test_directory,
        options,
        extra_unittest_args=subprocess_flag,
    )
682

683

684
def _test_cpp_extensions_aot(test_directory, options, use_ninja):
    """Build the C++ extension test modules ahead of time and run their tests.

    Steps:
    1. Wipe ``<test_directory>/cpp_extensions/build``.
    2. Run ``setup.py install --root ./install`` there (and, except on
       Windows, in ``no_python_abi_suffix_test``) with ``USE_NINJA`` set.
    3. Copy ``test_cpp_extensions_aot.py`` to
       ``test_cpp_extensions_aot_{ninja,no_ninja}.py``.
    4. Run that copy via ``run_test``, with the installed ``*-packages``
       directory prepended to ``PYTHONPATH``.

    The caller's ``PYTHONPATH`` and ``USE_NINJA`` are restored exactly
    afterwards (including "unset"), and the copied test file is removed,
    even if a step fails.

    Returns:
        1 if ninja was requested but is unavailable, the first non-zero
        build return code, or the return code of ``run_test``.
    """
    if use_ninja:
        try:
            from torch.utils import cpp_extension

            cpp_extension.verify_ninja_availability()
        except RuntimeError:
            print_to_stderr(CPP_EXTENSIONS_ERROR)
            return 1

    # Wipe the build folder, if it exists already
    cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
    cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
    if os.path.exists(cpp_extensions_test_build_dir):
        shutil.rmtree(cpp_extensions_test_build_dir)

    # Build the test cpp extensions modules
    shell_env = os.environ.copy()
    shell_env["USE_NINJA"] = str(1 if use_ninja else 0)
    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=shell_env)
    if return_code != 0:
        return return_code
    if sys.platform != "win32":
        return_code = shell(
            cmd,
            cwd=os.path.join(cpp_extensions_test_dir, "no_python_abi_suffix_test"),
            env=shell_env,
        )
        if return_code != 0:
            return return_code

    # "install" the test modules and run tests. run_test() is called without
    # an explicit env, so it copies os.environ; the variables are therefore
    # set on os.environ for the duration of the run.
    test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja")
    test_copy = test_directory + "/" + test_module + ".py"
    # None records "was not set", so it can be unset again afterwards.
    saved_env = {name: os.environ.get(name) for name in ("PYTHONPATH", "USE_NINJA")}
    python_path = saved_env["PYTHONPATH"] or ""
    try:
        os.environ["USE_NINJA"] = shell_env["USE_NINJA"]
        shutil.copyfile(test_directory + "/test_cpp_extensions_aot.py", test_copy)

        install_directory = ""
        # install directory is the one that is named site-packages
        for root, directories, _ in os.walk(
            os.path.join(cpp_extensions_test_dir, "install")
        ):
            for directory in directories:
                if "-packages" in directory:
                    install_directory = os.path.join(root, directory)

        assert install_directory, "install_directory must not be empty"
        os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
        return run_test(ShardedTest(test_module, 1, 1), test_directory, options)
    finally:
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        if os.path.exists(test_copy):
            os.remove(test_copy)
743

744

745
def test_cpp_extensions_aot_ninja(test_module, test_directory, options):
    """Build the AOT C++ extension tests with ninja and run them.

    ``test_module`` is unused; it is part of the shared handler signature.
    """
    return _test_cpp_extensions_aot(
        test_directory=test_directory, options=options, use_ninja=True
    )
747

748

749
def test_cpp_extensions_aot_no_ninja(test_module, test_directory, options):
    """Build the AOT C++ extension tests without ninja and run them.

    ``test_module`` is unused; it is part of the shared handler signature.
    """
    return _test_cpp_extensions_aot(
        test_directory=test_directory, options=options, use_ninja=False
    )
751

752

753
def test_autoload_enable(test_module, test_directory, options):
    """Run test_autoload.py with TORCH_DEVICE_BACKEND_AUTOLOAD=1.

    ``test_module`` is unused; it is part of the shared handler signature.
    """
    return _test_autoload(test_directory=test_directory, options=options, enable=True)
755

756

757
def test_autoload_disable(test_module, test_directory, options):
    """Run test_autoload.py with TORCH_DEVICE_BACKEND_AUTOLOAD=0.

    ``test_module`` is unused; it is part of the shared handler signature.
    """
    return _test_autoload(test_directory=test_directory, options=options, enable=False)
759

760

761
def _test_autoload(test_directory, options, enable=True):
    """Build the cpp_extensions test package and run test_autoload.py against it.

    Steps:
    1. Wipe ``<test_directory>/cpp_extensions/build``.
    2. Run ``setup.py install --root ./install`` there.
    3. Run ``test_autoload.py`` from ``test_directory``, with the installed
       ``*-packages`` directory prepended to ``PYTHONPATH`` and
       ``TORCH_DEVICE_BACKEND_AUTOLOAD`` set to "1"/"0" from ``enable``.

    Both variables are restored to their previous state (including "unset")
    afterwards. This also happens when locating the install directory fails,
    so the original error is not masked by a KeyError raised during cleanup.
    ``options`` is unused.

    Returns:
        The first non-zero build return code, or the return code of the
        test_autoload.py run.
    """
    # Wipe the build folder, if it exists already
    cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
    cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
    if os.path.exists(cpp_extensions_test_build_dir):
        shutil.rmtree(cpp_extensions_test_build_dir)

    # Build the test cpp extensions modules
    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=os.environ)
    if return_code != 0:
        return return_code

    # "install" the test modules and run tests.
    # None records "was not set", so it can be unset again afterwards.
    saved_env = {
        name: os.environ.get(name)
        for name in ("PYTHONPATH", "TORCH_DEVICE_BACKEND_AUTOLOAD")
    }
    python_path = saved_env["PYTHONPATH"] or ""
    try:
        install_directory = ""
        # install directory is the one that is named site-packages
        for root, directories, _ in os.walk(
            os.path.join(cpp_extensions_test_dir, "install")
        ):
            for directory in directories:
                if "-packages" in directory:
                    install_directory = os.path.join(root, directory)

        assert install_directory, "install_directory must not be empty"
        os.environ["PYTHONPATH"] = os.pathsep.join([install_directory, python_path])
        os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = str(int(enable))

        cmd = [sys.executable, "test_autoload.py"]
        return shell(cmd, cwd=test_directory, env=os.environ)
    finally:
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
796

797

798
def _mpiexec_succeeds(command: str) -> bool:
    """Return True if the shell ``command`` exits with status 0 (all output discarded)."""
    return (
        subprocess.call(
            command,
            shell=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
        )
        == 0
    )


def _mpiexec_launcher_cmd() -> List[str]:
    """Build the ``mpiexec`` launcher command used for the MPI backend.

    ``--allow-run-as-root`` and ``--noprefix`` are included only if a trivial
    ``mpiexec`` invocation accepts them. Unsupported flags are omitted
    entirely instead of being passed as empty-string arguments.
    """
    allowrunasroot_opt = (
        "--allow-run-as-root"
        if _mpiexec_succeeds('mpiexec --allow-run-as-root -n 1 bash -c ""')
        else ""
    )
    # test mpiexec for --noprefix option
    noprefix_opt = (
        "--noprefix"
        if _mpiexec_succeeds(
            f'mpiexec {allowrunasroot_opt} -n 1 --noprefix bash -c ""'
        )
        else ""
    )
    optional_flags = [opt for opt in (noprefix_opt, allowrunasroot_opt) if opt]
    return ["mpiexec", "-n", "3", *optional_flags]


def test_distributed(test_module, test_directory, options):
    """Run a distributed test module once per backend / init-method combination.

    For each backend in ``DISTRIBUTED_TESTS_CONFIG`` the module is run with a
    fresh temporary directory exported as ``TEMP_DIR`` (containing ``barrier``
    and ``test_dir`` subfolders), ``BACKEND`` set to the backend name, and the
    backend's extra environment variables. On Windows only the ``gloo``
    backend with the file init method runs. The ``mpi`` backend runs only if
    ``mpiexec`` is on PATH and Python < 3.9. The environment and the temporary
    directory are restored/removed after every run.

    Returns:
        0 if every run passed, otherwise the first non-zero return code.
    """
    # MPI tests are broken with Python-3.9
    mpi_available = subprocess.call(
        "command -v mpiexec", shell=True
    ) == 0 and sys.version_info < (3, 9)
    if options.verbose and not mpi_available:
        print_to_stderr("MPI not available -- MPI backend tests will be skipped")

    config = DISTRIBUTED_TESTS_CONFIG
    for backend, env_vars in config.items():
        if sys.platform == "win32" and backend != "gloo":
            continue
        if backend == "mpi" and not mpi_available:
            continue
        # Explicit tuple in the same order CPython iterates {True, False}.
        for with_init_file in (False, True):
            if sys.platform == "win32" and not with_init_file:
                continue
            if options.verbose:
                init_method = "file" if with_init_file else "env"
                print_to_stderr(
                    f"Running distributed tests for the {backend} backend "
                    f"with {init_method} init_method"
                )
            tmp_dir = tempfile.mkdtemp()
            old_environ = dict(os.environ)
            try:
                os.environ["TEMP_DIR"] = tmp_dir
                os.environ["BACKEND"] = backend
                os.environ.update(env_vars)
                os.mkdir(os.path.join(tmp_dir, "barrier"))
                os.mkdir(os.path.join(tmp_dir, "test_dir"))
                if backend == "mpi":
                    return_code = run_test(
                        test_module,
                        test_directory,
                        options,
                        launcher_cmd=_mpiexec_launcher_cmd(),
                    )
                else:
                    return_code = run_test(
                        test_module,
                        test_directory,
                        options,
                        extra_unittest_args=["--subprocess"],
                    )
                if return_code != 0:
                    return return_code
            finally:
                shutil.rmtree(tmp_dir)
                os.environ.clear()
                os.environ.update(old_environ)
    return 0

875

876
def run_doctests(test_module, test_directory, options):
    """
    Assumes the incoming test module is called doctest, and simply executes the
    xdoctest runner on the torch library itself.

    Feature-gated doctests are switched on through ``TORCH_DOCTEST_*``
    environment variables. These are set only for the duration of the
    xdoctest run and are restored to their previous values afterwards, so they
    do not leak into later tests. ``test_module`` and ``test_directory`` are
    not used.

    Returns:
        1 if xdoctest reports any failed doctest, otherwise 0.
    """
    import xdoctest

    pkgpath = os.path.dirname(torch.__file__)

    exclude_module_list = ["torch._vendor.*"]
    enabled = {
        # TODO: expose these options to the user
        # For now disable all feature-conditional tests
        # 'lapack': 'auto',
        # 'cuda': 'auto',
        # 'cuda1': 'auto',
        # 'qengine': 'auto',
        "lapack": 0,
        "cuda": 0,
        "cuda1": 0,
        "qengine": 0,
        "autograd_profiler": 0,
        "cpp_ext": 0,
        "monitor": 0,
        "onnx": "auto",
    }

    # Resolve every "auto" entry to a real boolean. An unresolved "auto" is a
    # truthy string and would wrongly enable the feature below.
    if enabled["cuda"] == "auto":
        enabled["cuda"] = torch.cuda.is_available()

    if enabled["cuda1"] == "auto":
        enabled["cuda1"] = (
            torch.cuda.is_available() and torch.cuda.device_count() > 1
        )

    if enabled["lapack"] == "auto":
        enabled["lapack"] = bool(torch._C.has_lapack)

    if enabled["qengine"] == "auto":
        try:
            # Is there a better check if quantization is enabled?
            import torch.ao.nn.quantized as nnq  # NOQA: F401

            torch.backends.quantized.engine = "qnnpack"
            torch.backends.quantized.engine = "fbgemm"
        except (ImportError, RuntimeError):
            enabled["qengine"] = False
        else:
            enabled["qengine"] = True

    if enabled["onnx"] == "auto":
        try:
            import onnx  # NOQA: F401
            import onnxruntime  # NOQA: F401
            import onnxscript  # NOQA: F401
        except ImportError:
            exclude_module_list.append("torch.onnx.*")
            enabled["onnx"] = False
        else:
            enabled["onnx"] = True

    # Feature flag -> environment variable that enables its doctests.
    # TODO: could try to enable some of these as well:
    # TORCH_DOCTEST_QUANTIZED_DYNAMIC, TORCH_DOCTEST_ANOMALY,
    # TORCH_DOCTEST_AUTOGRAD, TORCH_DOCTEST_HUB, TORCH_DOCTEST_DATALOADER,
    # TORCH_DOCTEST_FUTURES
    feature_env_vars = {
        "cuda": "TORCH_DOCTEST_CUDA",
        "cuda1": "TORCH_DOCTEST_CUDA1",
        "lapack": "TORCH_DOCTEST_LAPACK",
        "qengine": "TORCH_DOCTEST_QENGINE",
        "autograd_profiler": "TORCH_DOCTEST_AUTOGRAD_PROFILER",
        "cpp_ext": "TORCH_DOCTEST_CPP_EXT",
        "monitor": "TORCH_DOCTEST_MONITOR",
        "onnx": "TORCH_DOCTEST_ONNX",
    }
    env_vars_to_set = [
        env_var for feature, env_var in feature_env_vars.items() if enabled[feature]
    ]

    xdoctest_config = {
        "global_exec": r"\n".join(
            [
                "from torch import nn",
                "import torch.nn.functional as F",
                "import torch",
            ]
        ),
        "analysis": "static",  # set to "auto" to test doctests in compiled modules
        "style": "google",
        "options": "+IGNORE_WHITESPACE",
    }
    xdoctest_verbose = max(1, options.verbose)

    saved_env = {env_var: os.environ.get(env_var) for env_var in env_vars_to_set}
    try:
        for env_var in env_vars_to_set:
            os.environ[env_var] = "1"
        run_summary = xdoctest.runner.doctest_module(
            os.fspath(pkgpath),
            config=xdoctest_config,
            verbose=xdoctest_verbose,
            command=options.xdoctest_command,
            argv=[],
            exclude=exclude_module_list,
        )
    finally:
        for env_var, value in saved_env.items():
            if value is None:
                os.environ.pop(env_var, None)
            else:
                os.environ[env_var] = value

    return 1 if run_summary.get("n_failed", 0) else 0

1000

1001
def sanitize_file_name(file: str):
    """Turn ``file`` into a flat, space-free file name.

    Backslashes and forward slashes are replaced with ``.`` and spaces with
    ``_``; every other character is kept as is.
    """
    replacements = str.maketrans({"\\": ".", "/": ".", " ": "_"})
    return file.translate(replacements)

1004

1005
def handle_log_file(
    test: ShardedTest, file_path: str, failed: bool, was_rerun: bool
) -> None:
    """Archive a test's log file under test/test-reports and echo it to stderr.

    The log at ``file_path`` is renamed into ``test/test-reports/`` (relative
    to the repo root) under a sanitized name with a random hex suffix. If the
    run failed, was rerun, or the log contains "=== RERUNS ===", the whole log
    is printed. Otherwise only the "Running ... items in this shard:" lines
    are printed.
    """
    test_label = str(test)
    with open(file_path, errors="ignore") as log_file:
        log_text = log_file.read()

    random_suffix = os.urandom(8).hex()
    archived_name = "test/test-reports/" + sanitize_file_name(
        f"{test_label}_{random_suffix}_.log"
    )
    os.rename(file_path, REPO_ROOT / archived_name)

    needs_full_log = failed or was_rerun or "=== RERUNS ===" in log_text
    if needs_full_log:
        print_to_stderr(f"\nPRINTING LOG FILE of {test_label} ({archived_name})")
        print_to_stderr(log_text)
        print_to_stderr(
            f"FINISHED PRINTING LOG FILE of {test_label} ({archived_name})\n"
        )
        return

    # Clean run with no retries (test-level retries can't be detected without
    # reparsing the XML): only report which tests ran.
    print_to_stderr(
        f"\n{test_label} was successful, full logs can be found in artifacts with path {archived_name}"
    )
    shard_summary = (
        line.rstrip()
        for line in log_text.splitlines()
        if re.search("Running .* items in this shard:", line)
    )
    for summary_line in shard_summary:
        print_to_stderr(summary_line)
    print_to_stderr("")

1034

1035
def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False):
    """Build the pytest command-line arguments for a test run.

    Args:
        options: parsed run_test.py options; only ``pytest_k_expr`` is read.
        is_cpp_test: C++ tests are run in parallel via ``-n NUM_PROCS`` and,
            in CI, write their own JUnit XML report.
        is_distributed_test: selects the reduced rerun count in
            rerun-disabled-tests mode.

    Returns:
        A new list of pytest arguments, ending with the rerun options.
    """
    if not RERUN_DISABLED_TESTS:
        # Normal mode: stop at the first failure (-x) and retry a failed test
        # 2 more times.
        rerun_options = ["-x", "--reruns=2"]
    else:
        # Rerun-disabled-tests mode runs the same tests many times to
        # determine their flakiness status (50 runs by default). Distributed
        # and ASAN tests are too slow for that (jobs would time out after 3+
        # hours), so they get 15 runs: we need at least 150 instances of the
        # test every 2 weeks to satisfy the Rockset query (15 x 14 = 210).
        slow_suite = is_distributed_test or TEST_WITH_ASAN
        flake_runs = 15 if slow_suite else 50
        rerun_options = ["--flake-finder", f"--flake-runs={flake_runs}"]

    pytest_args = ["-vv", "-rfEX"]
    if is_cpp_test:
        # Use pytest-xdist to run C++ tests in parallel, as running them
        # sequentially using run_test is much slower than running them directly.
        pytest_args += ["-n", str(NUM_PROCS)]
        if IS_CI:
            # C++ tests don't go through common_utils, so request the XML
            # test report here.
            pytest_args += ["--junit-xml-reruns", get_report_path(pytest=True)]
    else:
        # Non-C++ tests: disable the xdist plugin (a custom pytest shard
        # plugin conflicts with it) and ask the test file to use pytest.
        pytest_args += ["-p", "no:xdist", "--use-pytest"]

    if options.pytest_k_expr:
        pytest_args += ["-k", options.pytest_k_expr]

    return pytest_args + rerun_options

1075

1076
def run_ci_sanity_check(test: ShardedTest, test_directory, options):
    """Run ``test_ci_sanity_check_fail`` and check that it fails as expected.

    The test must exit with return code exactly 1; any other code makes this
    handler return 1. When it fails as expected, its ``*.log`` files and
    per-test report directories under ``test/test-reports`` are deleted and
    0 is returned.
    """
    assert (
        test.name == "test_ci_sanity_check_fail"
    ), f"This handler only works for test_ci_sanity_check_fail, got {test.name}"

    # This test is supposed to fail
    if run_test(test, test_directory, options, print_log=False) != 1:
        return 1

    reports_root = str(REPO_ROOT / "test/test-reports")
    # Delete the log files and xmls generated by the test
    for log_path in glob.glob(f"{reports_root}/{test.name}*.log"):
        os.remove(log_path)
    for report_dir in glob.glob(f"{reports_root}/**/{test.name}"):
        shutil.rmtree(report_dir)
    return 0

1092

1093
# Test modules that need a non-default runner. run_test_module looks the test
# name up here and falls back to run_test; each handler is called as
# handler(test, test_directory, options) and must return an int return code.
# Every module listed here is also treated as must-serial by must_serial.
CUSTOM_HANDLERS = dict(
    [
        ("test_cuda_primary_ctx", run_test_with_subprocess),
        ("test_cuda_nvml_based_avail", run_test_with_subprocess),
        ("test_cuda_trace", run_test_with_subprocess),
        ("test_cpp_extensions_aot_no_ninja", test_cpp_extensions_aot_no_ninja),
        ("test_cpp_extensions_aot_ninja", test_cpp_extensions_aot_ninja),
        ("distributed/test_distributed_spawn", test_distributed),
        ("distributed/algorithms/quantization/test_quantization", test_distributed),
        ("distributed/test_c10d_nccl", run_test_with_subprocess),
        ("distributed/test_c10d_gloo", run_test_with_subprocess),
        ("distributed/test_c10d_ucc", run_test_with_subprocess),
        ("distributed/test_c10d_common", run_test_with_subprocess),
        ("distributed/test_c10d_spawn_gloo", run_test_with_subprocess),
        ("distributed/test_c10d_spawn_nccl", run_test_with_subprocess),
        ("distributed/test_c10d_spawn_ucc", run_test_with_subprocess),
        ("distributed/test_store", run_test_with_subprocess),
        ("distributed/test_pg_wrapper", run_test_with_subprocess),
        ("distributed/rpc/test_faulty_agent", run_test_with_subprocess),
        ("distributed/rpc/test_tensorpipe_agent", run_test_with_subprocess),
        ("distributed/rpc/test_share_memory", run_test_with_subprocess),
        ("distributed/rpc/cuda/test_tensorpipe_agent", run_test_with_subprocess),
        ("doctests", run_doctests),
        ("test_ci_sanity_check_fail", run_ci_sanity_check),
        ("test_autoload_enable", test_autoload_enable),
        ("test_autoload_disable", test_autoload_disable),
    ]
)

1120

1121
# Test modules whose pytest retries are presumably meant to be skipped.
# NOTE(review): this set is consumed outside this section — confirm the exact semantics there.
PYTEST_SKIP_RETRIES = {"test_public_bindings"}
1122

1123

1124
def parse_args():
    """Parse the run_test.py command line.

    Arguments argparse does not recognize are kept, with a bare ``--``
    separator removed, on the returned namespace as ``additional_args``.
    """
    parser = argparse.ArgumentParser(
        description="Run the PyTorch unit test suite",
        epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="Print verbose information and test-by-test results",
    )
    showlocals_default = strtobool(os.environ.get("TEST_SHOWLOCALS", "False"))
    showlocals_help = (
        "Show local variables in tracebacks "
        "(default taken from the TEST_SHOWLOCALS env var, False if unset)"
    )
    if sys.version_info >= (3, 9):
        parser.add_argument(
            "--showlocals",
            action=argparse.BooleanOptionalAction,
            default=showlocals_default,
            help=showlocals_help,
        )
    else:
        parser.add_argument(
            "--showlocals",
            action="store_true",
            default=showlocals_default,
            help=showlocals_help,
        )
        parser.add_argument("--no-showlocals", dest="showlocals", action="store_false")
    parser.add_argument("--jit", action="store_true", help="run all jit tests")
    parser.add_argument(
        "--distributed-tests",
        action="store_true",
        help="Run all distributed tests",
    )
    parser.add_argument(
        "--functorch",
        action="store_true",
        help=(
            "If this flag is present, we will only run functorch tests. "
            "If this flag is not present, we will run all tests "
            "(including functorch tests)."
        ),
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help=(
            "If this flag is present, we will only run test_mps, test_metal, "
            "test_modules and test_nn"
        ),
    )
    parser.add_argument(
        "--xpu",
        action="store_true",
        help="If this flag is present, we will run xpu tests except XPU_BLOCKLIST",
    )
    parser.add_argument(
        "--cpp",
        action="store_true",
        help="If this flag is present, we will only run C++ tests",
    )
    parser.add_argument(
        "-core",
        "--core",
        action="store_true",
        help="Only run core tests, or tests that validate PyTorch's ops, modules, "
        "and autograd. They are defined by CORE_TEST_LIST.",
    )
    parser.add_argument(
        "--onnx",
        action="store_true",
        help=(
            "Only run ONNX tests, or tests that validate PyTorch's ONNX export. "
            "If this flag is not present, we will exclude ONNX tests."
        ),
    )
    parser.add_argument(
        "-k",
        "--pytest-k-expr",
        default="",
        help="Pass to pytest as its -k expr argument",
    )
    parser.add_argument(
        "-c",
        "--coverage",
        action="store_true",
        help="enable coverage",
        default=PYTORCH_COLLECT_COVERAGE,
    )
    parser.add_argument(
        "-i",
        "--include",
        nargs="+",
        choices=TestChoices(TESTS),
        default=TESTS,
        metavar="TESTS",
        help="select a set of tests to include (defaults to ALL tests)."
        " tests must be a part of the TESTS list defined in run_test.py",
    )
    parser.add_argument(
        "-x",
        "--exclude",
        nargs="+",
        choices=TESTS,
        metavar="TESTS",
        default=[],
        help="select a set of tests to exclude",
    )
    parser.add_argument(
        "--ignore-win-blocklist",
        action="store_true",
        help="always run blocklisted windows tests",
    )
    # NS: Target determination (formerly a --determine-from option) is
    # disabled until it can be made more reliable.
    parser.add_argument(
        "--continue-through-error",
        "--keep-going",
        action="store_true",
        help="Runs the full test suite despite one of the tests failing",
        default=strtobool(os.environ.get("CONTINUE_THROUGH_ERROR", "False")),
    )
    parser.add_argument(
        "--pipe-logs",
        action="store_true",
        help=(
            "Print logs to output file while running tests. "
            "True if in CI and VERBOSE_TEST_LOGS is not set"
        ),
        default=IS_CI and not strtobool(os.environ.get("VERBOSE_TEST_LOGS", "False")),
    )
    parser.add_argument(
        "--enable-timeout",
        action="store_true",
        help=(
            "Set a timeout based on the test times json file. "
            "Only works if there are test times available"
        ),
        default=IS_CI and not strtobool(os.environ.get("NO_TEST_TIMEOUT", "False")),
    )
    parser.add_argument(
        "--enable-td",
        action="store_true",
        help="Enables removing tests based on TD",
        default=IS_CI
        and (
            TEST_WITH_CROSSREF
            or TEST_WITH_ASAN
            or (TEST_CONFIG == "distributed" and TEST_CUDA)
            or (IS_WINDOWS and not TEST_CUDA)
            or TEST_CONFIG == "nogpu_AVX512"
            or TEST_CONFIG == "nogpu_NO_AVX2"
            or TEST_CONFIG == "default"
        )
        and get_pr_number() is not None
        and not strtobool(os.environ.get("NO_TD", "False"))
        and not TEST_WITH_ROCM
        and not IS_MACOS
        and "xpu" not in BUILD_ENVIRONMENT
        and "onnx" not in BUILD_ENVIRONMENT
        and os.environ.get("GITHUB_WORKFLOW", "slow") in ("trunk", "pull"),
    )
    parser.add_argument(
        "--shard",
        nargs=2,
        type=int,
        help="runs a shard of the tests (taking into account other selections), e.g., "
        "--shard 2 3 will break up the selected tests into 3 shards and run the tests "
        "in the 2nd shard (the first number should not exceed the second)",
    )
    parser.add_argument(
        "--exclude-jit-executor",
        action="store_true",
        help="exclude tests that are run for a specific jit config",
    )
    parser.add_argument(
        "--exclude-torch-export-tests",
        action="store_true",
        help="exclude torch export tests",
    )
    parser.add_argument(
        "--exclude-distributed-tests",
        action="store_true",
        help="exclude distributed tests",
    )
    parser.add_argument(
        "--exclude-inductor-tests",
        action="store_true",
        help="exclude inductor tests",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only list the test that will run.",
    )
    parser.add_argument(
        "--xdoctest-command",
        default="all",
        help=(
            "Control the specific doctest action. "
            "Use 'list' to simply parse doctests and check syntax. "
            "Use 'all' to execute all doctests or specify a specific "
            "doctest to run"
        ),
    )
    parser.add_argument(
        "--no-translation-validation",
        action="store_false",
        help="Run tests without translation validation.",
    )

    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--dynamo",
        action="store_true",
        help="Run tests with TorchDynamo+EagerBackend turned on",
    )
    group.add_argument(
        "--inductor",
        action="store_true",
        help="Run tests with TorchInductor turned on",
    )

    args, extra = parser.parse_known_args()
    if "--" in extra:
        extra.remove("--")
    args.additional_args = extra
    return args

1353

1354
def exclude_tests(
1355
    exclude_list, selected_tests, exclude_message=None, exact_match=False
1356
):
1357
    for exclude_test in exclude_list:
1358
        tests_copy = selected_tests[:]
1359
        for test in tests_copy:
1360
            if (
1361
                not exact_match and test.startswith(exclude_test)
1362
            ) or test == exclude_test:
1363
                if exclude_message is not None:
1364
                    print_to_stderr(f"Excluding {test} {exclude_message}")
1365
                selected_tests.remove(test)
1366
    return selected_tests
1367

1368

1369
def must_serial(file: Union[str, ShardedTest]) -> bool:
    """Return True if the test file has to run alone instead of alongside other files.

    ``file`` may be a test name or a ShardedTest (its ``name`` is used).
    Serial execution is required when:
    - PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL=1,
    - the TEST_CONFIG env var or the name contains DISTRIBUTED_TEST_PREFIX,
    - the name appears in CUSTOM_HANDLERS or one of the serial/blocklist
      collections, or
    - NUM_PROCS == 1.
    """
    name = file.name if isinstance(file, ShardedTest) else file

    if os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1":
        return True
    if DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", ""):
        return True
    if DISTRIBUTED_TEST_PREFIX in name:
        return True
    if (
        name in CUSTOM_HANDLERS
        or name in RUN_PARALLEL_BLOCKLIST
        or name in CI_SERIAL_LIST
        or name in JIT_EXECUTOR_TESTS
        or name in ONNX_SERIAL_LIST
    ):
        return True
    return NUM_PROCS == 1

1384

1385
def can_run_in_pytest(test):
    """Return True unless PYTORCH_TEST_DO_NOT_USE_PYTEST is set to something other than "0".

    ``test`` is accepted for interface compatibility but is not consulted.
    """
    opt_out_flag = os.environ.get("PYTORCH_TEST_DO_NOT_USE_PYTEST", "0")
    return opt_out_flag == "0"

1388

1389
def get_selected_tests(options) -> List[str]:
    """Compute the test modules to run from the parsed command-line ``options``.

    Starts from ``options.include`` and applies, in order:
    - the --jit / --distributed-tests / --core / --functorch filters;
    - the --cpp, --mps, --xpu and --onnx selections (each otherwise adds its
      tests to ``options.exclude``);
    - the --exclude-* flags and version-specific exclusions;
    - the platform, distributed, LAPACK and slow-gradcheck blocklists.

    ``options.exclude`` is extended in place. ``options.include`` is copied
    first; it defaults to the module-level ``TESTS`` list, and the in-place
    removals done by ``exclude_tests`` used to mutate that global list.
    """
    selected_tests = list(options.include)

    # filter if there's JIT only and distributed only test options
    if options.jit:
        selected_tests = [name for name in selected_tests if "jit" in name]

    if options.distributed_tests:
        selected_tests = [name for name in selected_tests if name in DISTRIBUTED_TESTS]

    # Filter to only run core tests when --core option is specified
    if options.core:
        selected_tests = [name for name in selected_tests if name in CORE_TEST_LIST]

    # Filter to only run functorch tests when --functorch option is specified
    if options.functorch:
        selected_tests = [name for name in selected_tests if name in FUNCTORCH_TESTS]

    if options.cpp:
        selected_tests = [name for name in selected_tests if name in CPP_TESTS]
    else:
        # Exclude all C++ tests otherwise as they are still handled differently
        # than Python tests at the moment
        options.exclude.extend(CPP_TESTS)

    if options.mps:
        selected_tests = ["test_mps", "test_metal", "test_modules", "test_nn"]
    else:
        # Exclude all mps tests otherwise
        options.exclude.extend(["test_mps", "test_metal"])

    if options.xpu:
        selected_tests = exclude_tests(XPU_BLOCKLIST, selected_tests, "on XPU")
    else:
        # Exclude all xpu specific tests otherwise
        options.exclude.extend(XPU_TEST)

    # Filter to only run onnx tests when --onnx option is specified
    onnx_tests = [name for name in selected_tests if name in ONNX_TESTS]
    if options.onnx:
        selected_tests = onnx_tests
    else:
        # Exclude all onnx tests otherwise
        options.exclude.extend(onnx_tests)

    # process exclusion
    if options.exclude_jit_executor:
        options.exclude.extend(JIT_EXECUTOR_TESTS)

    if options.exclude_distributed_tests:
        options.exclude.extend(DISTRIBUTED_TESTS)

    if options.exclude_inductor_tests:
        options.exclude.extend(INDUCTOR_TESTS)

    if options.exclude_torch_export_tests:
        options.exclude.extend(TORCH_EXPORT_TESTS)

    # Temporarily disabled on every CUDA build (originally failing with
    # CUDA 11.6). Issue: https://github.com/pytorch/pytorch/issues/75375
    if torch.version.cuda is not None:
        options.exclude.extend(["distributions/test_constraints"])

    # these tests are failing in Python 3.12; temporarily disabled
    if sys.version_info >= (3, 12):
        options.exclude.extend(
            [
                "functorch/test_dims",
                "functorch/test_rearrange",
                "functorch/test_parsing",
                "functorch/test_memory_efficient_fusion",
                "torch_np/numpy_tests/core/test_multiarray",
            ]
        )

    selected_tests = exclude_tests(options.exclude, selected_tests)

    if sys.platform == "win32" and not options.ignore_win_blocklist:
        target_arch = os.environ.get("VSCMD_ARG_TGT_ARCH")
        if target_arch != "x64":
            # Append only missing entries so repeated calls don't keep growing
            # the module-level blocklist with duplicates.
            for blocked in (
                "cpp_extensions_aot_no_ninja",
                "cpp_extensions_aot_ninja",
                "cpp_extensions_jit",
                "jit",
                "jit_fuser",
            ):
                if blocked not in WINDOWS_BLOCKLIST:
                    WINDOWS_BLOCKLIST.append(blocked)

        selected_tests = exclude_tests(WINDOWS_BLOCKLIST, selected_tests, "on Windows")

    elif TEST_WITH_ROCM:
        selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm")

    # skip all distributed tests if distributed package is not available.
    if not dist.is_available():
        selected_tests = exclude_tests(
            DISTRIBUTED_TESTS,
            selected_tests,
            "PyTorch is built without distributed support.",
        )

    # skip tests that require LAPACK when it's not available
    if not torch._C.has_lapack:
        selected_tests = exclude_tests(
            TESTS_REQUIRING_LAPACK,
            selected_tests,
            "PyTorch is built without LAPACK support.",
        )

    if TEST_WITH_SLOW_GRADCHECK:
        selected_tests = exclude_tests(
            TESTS_NOT_USING_GRADCHECK,
            selected_tests,
            "Running in slow gradcheck mode, skipping tests "
            "that don't use gradcheck.",
            exact_match=True,
        )

    selected_tests = [parse_test_module(x) for x in selected_tests]
    return selected_tests

1513

1514
def load_test_times_from_file(file: str) -> Dict[str, Any]:
    """Load previously recorded test times, used to make sharding decisions.

    ``file`` is resolved relative to the repository root. The JSON is keyed by
    build environment, then by test config. Lookup tries, in order:
    ``[$BUILD_ENVIRONMENT][$TEST_CONFIG]``, ``["default"][$TEST_CONFIG]``, and
    ``["default"]["default"]``.

    Returns:
        The matching times mapping. Returns ``{}`` (round-robin sharding) if
        the file is missing or none of the lookups above exist, instead of
        raising KeyError.
    """
    path = os.path.join(str(REPO_ROOT), file)
    if not os.path.exists(path):
        print_to_stderr(
            f"::warning:: Failed to find test times file `{path}`. Using round robin sharding."
        )
        return {}

    with open(path, encoding="utf-8") as f:
        test_times_file = cast(Dict[str, Any], json.load(f))
    build_environment = os.environ.get("BUILD_ENVIRONMENT")
    test_config = os.environ.get("TEST_CONFIG")
    default_times = test_times_file.get("default", {})

    if test_config in test_times_file.get(build_environment, {}):
        print_to_stderr("Found test times from artifacts")
        return test_times_file[build_environment][test_config]
    if test_config in default_times:
        print_to_stderr(
            f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
            f" and {test_config} test config. Using default build env and {test_config} test config instead."
        )
        return default_times[test_config]
    if "default" in default_times:
        print_to_stderr(
            f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
            f" and {test_config} test config. Using default build env and default test config instead."
        )
        return default_times["default"]

    print_to_stderr(
        f"::warning:: No usable default entries in test times file `{path}`. Using round robin sharding."
    )
    return {}

1543

1544
def load_test_file_times(
    file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE,
) -> Dict[str, float]:
    """Load per-test-file durations through ``load_test_times_from_file``.

    NOTE(review): the default is built with ``/``, so it is presumably a
    path object despite the ``str`` annotation — confirm.
    """
    file_times = load_test_times_from_file(file)
    return cast(Dict[str, float], file_times)

1549

1550
def load_test_class_times(
    file: str = ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_TIMES_FILE,
) -> Dict[str, Dict[str, float]]:
    """Load per-test-class durations through ``load_test_times_from_file``.

    The result is typed as test file -> (test class -> duration).
    NOTE(review): the default is built with ``/``, so it is presumably a path
    object despite the ``str`` annotation — confirm.
    """
    class_times = load_test_times_from_file(file)
    return cast(Dict[str, Dict[str, float]], class_times)

1555

1556
def get_sharding_opts(options) -> Tuple[int, int]:
    """Return ``(which_shard, num_shards)`` from ``options.shard``.

    Without ``--shard`` this is ``(1, 1)``. Otherwise the pair must contain
    exactly two positive numbers with ``which_shard <= num_shards``; this is
    checked with asserts (AssertionError on violation).
    """
    if not options.shard:
        return (1, 1)

    assert len(options.shard) == 2, "Unexpected shard format"
    assert min(options.shard) > 0, "Shards must be positive numbers"
    which_shard, num_shards = options.shard
    assert (
        which_shard <= num_shards
    ), "Selected shard must be less than or equal to total number of shards"
    return (which_shard, num_shards)

1568

1569
def do_sharding(
    options,
    selected_tests: Sequence[TestRun],
    test_file_times: Dict[str, float],
    test_class_times: Dict[str, Dict[str, float]],
    sort_by_time: bool = True,
) -> Tuple[float, List[ShardedTest]]:
    """Split ``selected_tests`` into shards and return the shard for this job.

    The shard count and index come from ``get_sharding_opts(options)``; the
    partitioning itself is delegated to ``calculate_shards`` (with
    ``must_serial`` as the serial predicate). The ``--shard`` index is
    1-based. Per the annotation, the result is presumably (estimated time,
    tests) — NOTE(review): confirm against calculate_shards.
    """
    which_shard, num_shards = get_sharding_opts(options)

    all_shards = calculate_shards(
        num_shards,
        selected_tests,
        test_file_times,
        test_class_times=test_class_times,
        must_serial=must_serial,
        sort_by_time=sort_by_time,
    )
    shard_index = which_shard - 1
    return all_shards[shard_index]

1589

1590
class TestFailure(NamedTuple):
    """A failed test run together with the message reported for it."""

    # The run that failed.
    test: TestRun
    # Human-readable description, printed to stderr when the failure is recorded.
    message: str
1593

1594

1595
def run_test_module(
    test: ShardedTest, test_directory: str, options
) -> Optional[TestFailure]:
    """Run a single (sharded) test module and report whether it failed.

    The handler is looked up in ``CUSTOM_HANDLERS`` by test name, falling
    back to ``run_test``; it must return an ``int`` exit code.

    Args:
        test: the sharded test to run.
        test_directory: directory containing the test files.
        options: options forwarded unchanged to the handler.

    Returns:
        ``None`` when the handler returns 0. Otherwise a ``TestFailure`` whose
        message names the test and, for a negative return code, the signal
        that terminated it. Any exception raised while running is converted
        into a ``TestFailure`` as well instead of propagating.
    """
    try:
        maybe_set_hip_visible_devies()

        test_name = test.name

        # Printing the date here can help diagnose which tests are slow
        print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
        handler = CUSTOM_HANDLERS.get(test_name, run_test)
        return_code = handler(test, test_directory, options)
        assert isinstance(return_code, int) and not isinstance(
            return_code, bool
        ), f"While running {str(test)} got non integer return code {return_code}"
        if return_code == 0:
            return None

        message = f"{str(test)} failed!"
        if return_code < 0:
            # subprocess.Popen returns the child process' exit signal as
            # return code -N, where N is the signal number. Signals missing
            # from SIGNALS_TO_NAMES_DICT used to raise KeyError here, which the
            # except below turned into a message that no longer mentioned a
            # signal at all; report the raw number instead.
            signal_number = -return_code
            signal_name = SIGNALS_TO_NAMES_DICT.get(
                signal_number, f"signal {signal_number}"
            )
            message += f" Received signal: {signal_name}"
        return TestFailure(test.test, message)
    except Exception as e:
        return TestFailure(test.test, f"{str(test)} failed! {e}")
1622

1623

1624
def run_tests(
    selected_tests: List[ShardedTest],
    test_directory: str,
    options,
    failures: List[TestFailure],
) -> None:
    """Run ``selected_tests`` and append every failure to ``failures``.

    Execution happens in three phases:

    1. Tests for which ``must_serial`` is true run one at a time.
    2. Every other test runs one at a time, restricted to pytest's ``serial``
       marker (``-m serial``).
    3. The same tests then run concurrently in a ``spawn`` process pool of
       ``NUM_PROCS`` workers, restricted to ``-m "not serial"``, with
       ``NUM_PARALLEL_PROCS`` exported while the pool runs.

    Unless ``options.continue_through_error`` or ``RERUN_DISABLED_TESTS`` is
    set, the first failure in phases 1-2 raises ``RuntimeError`` and the first
    failure in phase 3 terminates the pool.

    Args:
        selected_tests: tests assigned to this shard.
        test_directory: directory containing the test files.
        options: parsed options; deep-copied per test, never mutated here.
        failures: output list each ``TestFailure`` is appended to.
    """
    if not selected_tests:
        return

    # parallel = in parallel with other files
    # serial = this file on its own. The file might still be run in parallel
    # with itself (ex test_ops).
    # must_serial() is evaluated per test rather than testing membership in
    # the parallel list, which was quadratic in the number of tests.
    selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
    selected_tests_serial = [x for x in selected_tests if must_serial(x)]

    # NB: This is a hack to make conftest.py and files it depends on available
    # on CPP_TESTS_DIR. We should see if the file could be turned into a
    # full-fledged pytest plugin instead.
    # This runs before the pool is created: the pool is only cleaned up by the
    # try/finally below, so a failing copy must not happen after Pool().
    if options.cpp and os.path.isdir(CPP_TESTS_DIR):
        for conftest_file in ("conftest.py", "pytest_shard_custom.py"):
            cpp_file = os.path.join(CPP_TESTS_DIR, conftest_file)
            if not os.path.exists(cpp_file):
                shutil.copy(os.path.join(test_directory, conftest_file), cpp_file)

    keep_going_message = (
        "\n\nTip: You can keep running tests even on failure by passing --keep-going to run_test.py.\n"
        "If running on CI, add the 'keep-going' label to your PR and rerun your jobs."
    )

    def should_stop_on_failure() -> bool:
        # Failures are fatal unless asked to keep going, or when rerunning
        # disabled tests (which are expected to fail).
        return not options.continue_through_error and not RERUN_DISABLED_TESTS

    def handle_error_messages(failure: Optional[TestFailure]) -> bool:
        """Record and print ``failure``; return whether the test failed."""
        if failure is None:
            return False
        failures.append(failure)
        print_to_stderr(failure.message)
        return True

    def clone_options(test: ShardedTest, *extra_args: str):
        """Deep-copy ``options`` for one test, enabling pytest when supported."""
        options_clone = copy.deepcopy(options)
        if can_run_in_pytest(test):
            options_clone.pytest = True
        if extra_args:
            options_clone.additional_args.extend(extra_args)
        return options_clone

    def run_sequentially(test: ShardedTest, *extra_args: str) -> None:
        """Run ``test`` to completion; raise ``RuntimeError`` on a fatal failure."""
        failure = run_test_module(
            test, test_directory, clone_options(test, *extra_args)
        )
        if handle_error_messages(failure) and should_stop_on_failure():
            raise RuntimeError(failure.message + keep_going_message)

    def parallel_test_completion_callback(failure: Optional[TestFailure]) -> None:
        # Invoked by the pool's result-handler thread when a parallel test ends.
        if handle_error_messages(failure) and should_stop_on_failure():
            pool.terminate()

    # See Note [ROCm parallel CI testing]
    pool = get_context("spawn").Pool(
        NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1
    )

    try:
        for test in selected_tests_serial:
            run_sequentially(test)

        # Run the pytest-"serial"-marked tests of the parallel files one at a
        # time before the pool runs the rest of those files.
        for test in selected_tests_parallel:
            run_sequentially(test, "-m", "serial")

        # NOTE(review): the workers started by Pool() above were spawned before
        # this variable is set, so only workers started later (e.g. the
        # maxtasksperchild=1 replacements) inherit it — confirm this is intended.
        os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
        try:
            for test in selected_tests_parallel:
                pool.apply_async(
                    run_test_module,
                    args=(test, test_directory, clone_options(test, "-m", "not serial")),
                    callback=parallel_test_completion_callback,
                )
            pool.close()
            pool.join()
        finally:
            # Also remove it on error/interrupt so it does not leak into the
            # rest of the process; pop() avoids masking an exception with KeyError.
            os.environ.pop("NUM_PARALLEL_PROCS", None)
    finally:
        pool.terminate()
        pool.join()
1732

1733

1734
def check_pip_packages() -> None:
    """Exit with status 1 if a pip package needed by run_test.py is missing.

    Every missing package is reported on stderr (one line each) before
    exiting, so all of them can be installed in a single pass instead of
    discovering them one rerun at a time.
    """
    packages = [
        "pytest-rerunfailures",
        "pytest-flakefinder",
        "pytest-xdist",
    ]
    # A set gives O(1) membership tests instead of scanning a list each time.
    installed_packages = {i.key for i in pkg_resources.working_set}
    missing_packages = [p for p in packages if p not in installed_packages]
    for package in missing_packages:
        print_to_stderr(
            f"Missing pip dependency: {package}, please run `pip install -r .ci/docker/requirements-ci.txt`"
        )
    if missing_packages:
        sys.exit(1)
1747

1748

1749
def main():
    """Select, shard and run the tests for this job, then report failures.

    Steps, in order: verify pip dependencies, parse options, tag metrics with
    the shard, select tests and let the TD prioritizations amend them, load
    historical timing data, and split the tests into a batch to run and an
    excluded batch (both sharded and printed). Unless ``options.dry_run`` is
    set, the included batch is then run. Coverage collection and CI failure
    stats are handled in a ``finally`` so they happen even if ``run_tests``
    raises. Exits with status 1 on any failure unless ``RERUN_DISABLED_TESTS``.
    """
    check_pip_packages()

    options = parse_args()

    # Include sharding info in all metrics
    shard_index, shard_total = get_sharding_opts(options)
    add_global_metric("shard", shard_index)
    add_global_metric("num_shards", shard_total)

    test_directory = str(REPO_ROOT / "test")
    selected_tests = get_selected_tests(options)

    test_prioritizations = import_results()
    test_prioritizations.amend_tests(selected_tests)

    os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)

    if options.coverage and not PYTORCH_COLLECT_COVERAGE:
        shell(["coverage", "erase"])

    if IS_CI:
        # Download the test case configurations into the local environment
        get_test_case_configs(dirpath=test_directory)

    file_times = load_test_file_times()
    class_times = load_test_class_times()

    class TestBatch:
        """Tests of similar priority that are sharded and run together.

        Construction immediately shards ``raw_tests`` for the current shard;
        ``time`` holds the shard's estimated duration and ``failures``
        collects the ``TestFailure``s produced when the batch is run.
        """

        name: str
        sharded_tests: List[ShardedTest]
        failures: List[TestFailure]

        def __init__(
            self, name: str, raw_tests: Sequence[TestRun], should_sort_shard: bool
        ):
            self.name = name
            self.failures = []
            self.time, self.sharded_tests = do_sharding(
                options,
                raw_tests,
                file_times,
                class_times,
                sort_by_time=should_sort_shard,
            )

        def __str__(self):
            serial_tests = [t for t in self.sharded_tests if must_serial(t)]
            parallel_tests = [t for t in self.sharded_tests if not must_serial(t)]
            lines = [
                f"Name: {self.name} (est. time: {round(self.time / 60, 2)}min)",
                f"  Serial tests ({len(serial_tests)}):",
                *(f"    {t}" for t in serial_tests),
                f"  Parallel tests ({len(parallel_tests)}):",
                *(f"    {t}" for t in parallel_tests),
            ]
            return "\n".join(lines).strip()

    if options.enable_td:
        percent_to_run = 25
        print_to_stderr(f"Running {percent_to_run}% of tests based on TD")
    else:
        percent_to_run = 100
        print_to_stderr("Running all tests")
    include, exclude = test_prioritizations.get_top_per_tests(percent_to_run)

    test_batch = TestBatch("tests to run", include, False)
    test_batch_exclude = TestBatch("excluded", exclude, True)
    if IS_CI:
        included_json = [x.to_json() for x in include]
        excluded_json = [x.to_json() for x in exclude]
        gen_ci_artifact(included_json, excluded_json)

    print_to_stderr(f"Running parallel tests on {NUM_PROCS} processes")
    print_to_stderr(test_batch)
    print_to_stderr(test_batch_exclude)

    if options.dry_run:
        return

    # dynamo takes precedence over inductor when both are requested
    if options.dynamo:
        os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"
    elif options.inductor:
        os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"

    if not options.no_translation_validation:
        os.environ["PYTORCH_TEST_WITH_TV"] = "1"

    try:
        # Actually run the tests
        started_at = time.time()
        run_tests(
            test_batch.sharded_tests, test_directory, options, test_batch.failures
        )
        elapsed = time.time() - started_at
        print_to_stderr(
            f"Running test batch '{test_batch.name}' cost {round(elapsed, 2)} seconds"
        )

    finally:
        if options.coverage:
            from coverage import Coverage

            with set_cwd(test_directory):
                cov = Coverage()
                if PYTORCH_COLLECT_COVERAGE:
                    cov.load()
                cov.combine(strict=False)
                cov.save()
                if not PYTORCH_COLLECT_COVERAGE:
                    cov.html_report()

        all_failures = test_batch.failures

        if IS_CI:
            for failure in all_failures:
                test_stats = test_prioritizations.get_test_stats(failure.test)
                print_to_stderr("Emiting td_test_failure_stats_v2")
                emit_metric(
                    "td_test_failure_stats_v2",
                    {
                        "selected_tests": selected_tests,
                        "failure": str(failure.test),
                        **test_stats,
                    },
                )
            gen_additional_test_failures_file(
                [failure.test.test_file for failure in all_failures]
            )

    if all_failures:
        for failure in all_failures:
            print_to_stderr(failure.message)

        # A disabled test is expected to fail, so there is no need to report a failure here
        if not RERUN_DISABLED_TESTS:
            sys.exit(1)
1885

1886

1887
if __name__ == "__main__":
    # Only start the test driver when executed as a script: "spawn" pool
    # workers re-import this module under a different __name__ and must not
    # rerun main().
    main()
1889

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.