pytorch

Форк
0
/
run_test.py 
1750 строк · 61.0 Кб
1
#!/usr/bin/env python3
2

3
import argparse
4
import copy
5
import json
6
import os
7
import pathlib
8
import re
9
import shutil
10
import signal
11
import subprocess
12
import sys
13
import tempfile
14
import time
15
from collections import defaultdict
16
from contextlib import ExitStack
17
from datetime import datetime
18
from typing import Any, cast, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
19

20
import pkg_resources
21

22
import torch
23
import torch.distributed as dist
24
from torch.multiprocessing import current_process, get_context
25
from torch.testing._internal.common_utils import (
26
    FILE_SCHEMA,
27
    get_report_path,
28
    IS_CI,
29
    IS_MACOS,
30
    parser as common_parser,
31
    retry_shell,
32
    set_cwd,
33
    shell,
34
    TEST_WITH_ASAN,
35
    TEST_WITH_CROSSREF,
36
    TEST_WITH_ROCM,
37
    TEST_WITH_SLOW_GRADCHECK,
38
)
39

40
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
41

42
# using tools/ to optimize test run.
43
sys.path.insert(0, str(REPO_ROOT))
44
from tools.stats.import_test_stats import (
45
    ADDITIONAL_CI_FILES_FOLDER,
46
    TEST_CLASS_TIMES_FILE,
47
    TEST_TIMES_FILE,
48
)
49
from tools.stats.upload_metrics import add_global_metric, emit_metric
50
from tools.testing.discover_tests import (
51
    CPP_TEST_PATH,
52
    CPP_TEST_PREFIX,
53
    CPP_TESTS_DIR,
54
    parse_test_module,
55
    TESTS,
56
)
57
from tools.testing.do_target_determination_for_s3 import import_results
58

59
from tools.testing.test_run import TestRun
60
from tools.testing.test_selections import (
61
    calculate_shards,
62
    get_test_case_configs,
63
    NUM_PROCS,
64
    ShardedTest,
65
    THRESHOLD,
66
)
67

68
HAVE_TEST_SELECTION_TOOLS = True
69
# Make sure to remove REPO_ROOT after import is done
70
sys.path.remove(str(REPO_ROOT))
71

72

73
# Rerun-disabled-tests mode is on only when PYTORCH_TEST_RERUN_DISABLED_TESTS
# is exactly "1"; any other value, or leaving it unset, keeps it off.
RERUN_DISABLED_TESTS = os.environ.get("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"

# Leading path components that identify families of test modules.
DISTRIBUTED_TEST_PREFIX = "distributed"
INDUCTOR_TEST_PREFIX = "inductor"
DYNAMO_TEST_PREFIX = "dynamo"
77

78

79
# Note [ROCm parallel CI testing]
# https://github.com/pytorch/pytorch/pull/85770 added file-granularity parallel
# testing. In .ci/pytorch/test.sh, when TEST_CONFIG == "default", the CUDA and
# HIP_VISIBLE_DEVICES settings are 0, so multiple test files share one GPU.
# ROCm should support that, but it exposed kernel-driver issues that resulted
# in hangs; see https://github.com/pytorch/pytorch/issues/90940.
#
# In addition, ROCm self-hosted runners have up to 4 GPUs. Pinning visibility
# to device 0 (to match CUDA test behavior) wasted the remaining GPUs. Giving
# each Pool worker its own dedicated GPU avoids the ROCm oversubscription
# issues and should also improve overall wall-clock time, because every GPU
# gets used.
def maybe_set_hip_visible_devies():
    """Pin a parallel-pool worker process to a single GPU on ROCm builds.

    Does nothing unless ``torch.version.hip`` is set, and does nothing in the
    main process. Inside a pool worker it sets HIP_VISIBLE_DEVICES to the
    worker's identity index modulo NUM_PROCS.
    """
    if not torch.version.hip:
        return
    proc = current_process()
    # Only processes created by a parallel Pool are pinned; the MainProcess is left alone.
    if proc.name == "MainProcess":
        return
    # `_identity` is a multiprocessing-internal tuple; its first element is
    # presumably the worker's creation counter within the pool.
    device_index = proc._identity[0] % NUM_PROCS
    os.environ["HIP_VISIBLE_DEVICES"] = str(device_index)
97

98

99
def strtobool(s):
    """Interpret a string flag as a boolean.

    Returns False for "", "0", "false" and "off" (case-insensitive) and True
    for every other string.
    """
    falsy_values = ("", "0", "false", "off")
    return s.lower() not in falsy_values
103

104

105
class TestChoices(list):
    """A list of test module names whose ``in`` check normalizes the candidate.

    ``item in choices`` looks up ``parse_test_module(item)`` instead of
    ``item`` itself.
    """

    def __init__(self, *args, **kwargs):
        # Only the first positional argument (the iterable of choices) is kept;
        # any further positional or keyword arguments are ignored.
        initial_choices = args[0]
        super().__init__(initial_choices)

    def __contains__(self, item):
        normalized_name = parse_test_module(item)
        return list.__contains__(self, normalized_name)
111

112

113
# All discovered test modules under distributed/fsdp.
FSDP_TEST = [test for test in TESTS if test.startswith("distributed/fsdp")]

# Test modules blocklisted on Windows (the filtering logic lives outside this
# part of the file).
# NB: every entry needs its trailing comma. Adjacent string literals are
# implicitly concatenated, so a missing comma silently merges two entries
# into one path that matches nothing.
WINDOWS_BLOCKLIST = [
    "distributed/nn/jit/test_instantiator",
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/test_share_memory",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/pipeline/sync/skip/test_api",
    "distributed/pipeline/sync/skip/test_gpipe",
    "distributed/pipeline/sync/skip/test_inspect_skip_layout",
    "distributed/pipeline/sync/skip/test_leak",
    "distributed/pipeline/sync/skip/test_portal",
    "distributed/pipeline/sync/skip/test_stash_pop",
    "distributed/pipeline/sync/skip/test_tracker",
    "distributed/pipeline/sync/skip/test_verify_skippables",
    "distributed/pipeline/sync/test_balance",
    "distributed/pipeline/sync/test_bugs",
    "distributed/pipeline/sync/test_checkpoint",
    "distributed/pipeline/sync/test_copy",
    "distributed/pipeline/sync/test_deferred_batch_norm",
    "distributed/pipeline/sync/test_dependency",
    "distributed/pipeline/sync/test_inplace",
    "distributed/pipeline/sync/test_microbatch",
    "distributed/pipeline/sync/test_phony",
    "distributed/pipeline/sync/test_pipe",
    "distributed/pipeline/sync/test_pipeline",
    "distributed/pipeline/sync/test_stream",
    "distributed/pipeline/sync/test_transparency",
    "distributed/pipeline/sync/test_worker",
    "distributed/elastic/agent/server/test/api_test",
    "distributed/elastic/multiprocessing/api_test",
    "distributed/_shard/checkpoint/test_checkpoint",
    "distributed/_shard/checkpoint/test_file_system_checkpoint",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharding_plan/test_sharding_plan",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
    "distributed/_shard/sharded_tensor/ops/test_init",
    "distributed/_shard/sharded_optim/test_sharded_optim",
] + FSDP_TEST
157

158
# Test modules blocklisted on ROCm (the filtering logic lives outside this
# part of the file).
# NB: every entry needs its trailing comma. Adjacent string literals are
# implicitly concatenated, so a missing comma silently merges two entries
# into one path that matches nothing.
ROCM_BLOCKLIST = [
    "distributed/rpc/test_faulty_agent",
    "distributed/rpc/test_tensorpipe_agent",
    "distributed/rpc/test_share_memory",
    "distributed/rpc/cuda/test_tensorpipe_agent",
    "distributed/_shard/checkpoint/test_checkpoint",
    "distributed/_shard/checkpoint/test_file_system_checkpoint",
    "distributed/_shard/sharding_spec/test_sharding_spec",
    "distributed/_shard/sharding_plan/test_sharding_plan",
    "distributed/_shard/sharded_tensor/test_sharded_tensor",
    "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard",
    "distributed/_shard/sharded_tensor/ops/test_embedding",
    "distributed/_shard/sharded_tensor/ops/test_embedding_bag",
    "distributed/_shard/sharded_tensor/ops/test_binary_cmp",
    "distributed/_shard/sharded_tensor/ops/test_init",
    "distributed/_shard/sharded_optim/test_sharded_optim",
    "test_determination",
    "test_jit_legacy",
    "test_cuda_nvml_based_avail",
    "test_jit_cuda_fuser",
]
179

180
# Test modules blocklisted for XPU runs (the consumer is outside this part of the file).
XPU_BLOCKLIST = ["test_autograd"]

# XPU-specific test modules.
XPU_TEST = ["test_xpu"]
187

188
# Test files whose tests must never run in parallel with each other;
# run_test strips any "--run-parallel..." argument for these files.
RUN_PARALLEL_BLOCKLIST = [
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_jit_disabled",
    "test_mobile_optimizer",
    "test_multiprocessing",
    "test_multiprocessing_spawn",
    "test_namedtuple_return_api",
    "test_overrides",
    "test_show_pickle",
    "test_tensorexpr",
    "test_cuda_primary_ctx",
    "test_cuda_trace",
    "test_cuda_nvml_based_avail",
    # Temporarily sets a global config.
    "test_autograd_fallback",
    *FSDP_TEST,
]
206

207
# Test files that must run serially with respect to other test files, although
# the tests inside each file may still run in parallel with each other.
# Comments record why an entry was added.
CI_SERIAL_LIST = [
    "test_nn",
    "test_fake_tensor",
    "test_cpp_api_parity",
    "test_reductions",
    "test_cuda",
    "test_cuda_expandable_segments",
    "test_indexing",
    "test_fx_backends",
    "test_linalg",
    "test_cpp_extensions_jit",
    "test_torch",
    "test_tensor_creation_ops",
    "test_sparse_csr",
    "test_dispatch",
    # torch.library creation and deletion must be serialized.
    "test_python_dispatch",
    # Causes CUDA illegal memory access: https://github.com/pytorch/pytorch/issues/88916
    "test_spectral_ops",
    "nn/test_pooling",
    # Doesn't respect set_per_process_memory_fraction; OOMs other tests in slow gradcheck.
    "nn/test_convolution",
    "distributions/test_distributions",
    # Slow gradcheck runs a test that checks the CUDA memory allocator.
    "test_autograd",
    # Slow gradcheck runs a test that checks the CUDA memory allocator.
    "test_prims",
    # Failed test due to mismatched elements.
    "test_modules",
    # OOM.
    "functorch/test_vmap",
    # Gets SIGKILL.
    "test_fx",
    # Frequently hangs on ROCm.
    "test_dataloader",
    # test_serialization_2gb_file allocates a 2GB tensor and could cause OOM.
    "test_serialization",
    # Causes CUDA illegal memory access: https://github.com/pytorch/pytorch/issues/95749
    "test_schema_check",
    # Causes CUDA OOM on ROCm.
    "functorch/test_memory_efficient_fusion",
    # The following five entries were added for OOMs.
    "test_utils",
    "test_sort_and_select",
    "test_backward_compatible_arguments",
    "test_autocast",
    "test_native_mha",
    # OOM.
    "test_module_hooks",
    # Testing; probably revert later.
    "inductor/test_max_autotune",
    # The next three OOM on test_large_block_sizes.
    "inductor/test_torchinductor",
    "inductor/test_torchinductor_dynamic_shapes",
    "inductor/test_torchinductor_codegen_dynamic_shapes",
]

# ONNX test files that cannot run in parallel because of high memory usage.
ONNX_SERIAL_LIST = [
    "onnx/test_models",
    "onnx/test_models_quantized_onnxruntime",
    "onnx/test_models_onnxruntime",
    "onnx/test_custom_ops",
    "onnx/test_utility_funs",
]

# Core subset of the test files, validating that PyTorch's ops, modules and
# autograd functions behave as expected.
CORE_TEST_LIST = [
    "test_autograd",
    "test_autograd_fallback",
    "test_modules",
    "test_nn",
    "test_ops",
    "test_ops_gradients",
    "test_ops_fwd_gradients",
    "test_ops_jit",
    "test_torch",
]
270

271

272
# Runtime threshold in seconds (5 minutes). The original note says test files
# slower than this are added to TARGET_DET_LIST.
# NOTE(review): TARGET_DET_LIST is not defined in this part of the file; confirm.
SLOW_TEST_THRESHOLD = 300

# Environment variables to set for each distributed backend run, keyed by
# backend name. Backends are added only when torch.distributed reports them
# as available.
DISTRIBUTED_TESTS_CONFIG = {}

if dist.is_available():
    DISTRIBUTED_TESTS_CONFIG["test"] = {"WORLD_SIZE": "1"}
    if not TEST_WITH_ROCM and dist.is_mpi_available():
        DISTRIBUTED_TESTS_CONFIG["mpi"] = {
            "WORLD_SIZE": "3",
            "TEST_REPORT_SOURCE_OVERRIDE": "dist-mpi",
        }
    # nccl, gloo and ucc share one rule: world size 2 when exactly two CUDA
    # devices are visible, 3 otherwise. ucc additionally gets transport tuning.
    for _backend, _is_available, _extra_env in (
        ("nccl", dist.is_nccl_available, {}),
        ("gloo", dist.is_gloo_available, {}),
        (
            "ucc",
            dist.is_ucc_available,
            {
                "UCX_TLS": "tcp,cuda",
                "UCC_TLS": "nccl,ucp,cuda",
                # Don't use UCP TL on CUDA, as it is not well supported.
                "UCC_TL_UCP_TUNE": "cuda:0",
                # CI nodes (M60) fail if this is on.
                "UCC_EC_CUDA_USE_COOPERATIVE_LAUNCH": "n",
            },
        ),
    ):
        if _is_available():
            DISTRIBUTED_TESTS_CONFIG[_backend] = {
                "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3",
                "TEST_REPORT_SOURCE_OVERRIDE": f"dist-{_backend}",
                **_extra_env,
            }
    # Keep the module namespace identical to the unrolled version.
    del _backend, _is_available, _extra_env
304

305
# Map signal numbers to their SIG* names, e.g. for reporting a negative exit code.
# Names containing "_" (SIG_DFL, SIG_IGN, ...) are skipped. When two names share
# a number, the alphabetically later one from dir(signal) wins.
# See https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python
SIGNALS_TO_NAMES_DICT = dict(
    (getattr(signal, signame), signame)
    for signame in dir(signal)
    if signame.startswith("SIG") and "_" not in signame
)
309

310
# Message printed when a C++ extension test needs ninja but it is not installed.
# (The pasted source had line-number gutter text embedded inside this string,
# which would have been printed to users; the message text is restored here.)
CPP_EXTENSIONS_ERROR = """
Ninja (https://ninja-build.org) is required for some of the C++ extensions
tests, but it could not be found. Install ninja with `pip install ninja`
or `conda install ninja`. Alternatively, disable said tests with
`run_test.py --exclude test_cpp_extensions_aot_ninja test_cpp_extensions_jit`.
"""

# True whenever PYTORCH_COLLECT_COVERAGE is set to a non-empty value.
# NOTE(review): any non-empty value, including "0" or "false", counts as enabled.
PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE"))
318

319
# JIT test modules associated with the JIT executor variants.
JIT_EXECUTOR_TESTS = [
    "test_jit_profiling",
    "test_jit_legacy",
    "test_jit_fuser_legacy",
]


def _tests_with_prefix(prefix):
    """Return the entries of TESTS that start with *prefix*, in TESTS order."""
    return [name for name in TESTS if name.startswith(prefix)]


# Discovered test modules grouped by their leading path prefix.
INDUCTOR_TESTS = _tests_with_prefix(INDUCTOR_TEST_PREFIX)
DYNAMO_TESTS = _tests_with_prefix(DYNAMO_TEST_PREFIX)
DISTRIBUTED_TESTS = _tests_with_prefix(DISTRIBUTED_TEST_PREFIX)
TORCH_EXPORT_TESTS = _tests_with_prefix("export")
FUNCTORCH_TESTS = _tests_with_prefix("functorch")
ONNX_TESTS = _tests_with_prefix("onnx")
CPP_TESTS = _tests_with_prefix(CPP_TEST_PREFIX)
332

333
# Test modules that need LAPACK support.
TESTS_REQUIRING_LAPACK = [
    "distributions/test_constraints",
    "distributions/test_distributions",
]

# Slow test modules that don't use gradcheck. This is not an exhaustive list;
# it covers only the slowest ones. If you don't want to skip every test in a
# file (e.g. test_mps), use skipIfSlowGradcheckEnv on individual tests instead.
TESTS_NOT_USING_GRADCHECK = [
    "doctests",
    "test_meta",
    "test_hub",
    "test_fx",
    "test_decomp",
    "test_cpp_extensions_jit",
    "test_jit",
    "test_ops",
    "test_ops_jit",
    "dynamo/test_recompile_ux",
    "inductor/test_smoke",
    "test_quantization",
]
355

356

357
def print_to_stderr(message):
    """Print *message*, followed by a newline, to the current ``sys.stderr``."""
    stream = sys.stderr
    print(message, file=stream)
359

360

361
def get_executable_command(options, disable_coverage=False, is_cpp_test=False):
    """Return the command prefix used to launch a test.

    Args:
        options: parsed options; only ``options.coverage`` is read.
        disable_coverage: force a run without coverage even if requested.
        is_cpp_test: whether the test is a C++ test binary.

    Returns:
        ``[sys.executable, "-bb"]`` or a ``coverage run`` prefix for Python
        tests, and ``["pytest"]`` for C++ tests. C++ tests with coverage get
        an empty list, because coverage is not supported for them yet.
    """
    with_coverage = options.coverage and not disable_coverage
    if is_cpp_test:
        # TODO: C++ with coverage is not yet supported
        return [] if with_coverage else ["pytest"]
    if with_coverage:
        return ["coverage", "run", "--parallel-mode", "--source=torch"]
    return [sys.executable, "-bb"]
375

376

377
def run_test(
    test_module: ShardedTest,
    test_directory,
    options,
    launcher_cmd=None,
    extra_unittest_args=None,
    env=None,
) -> int:
    """Run one (possibly sharded) test file in a subprocess and return its exit code.

    Builds the argument list from ``options`` (sharding, verbosity, pytest
    flags, CI-only flags) and picks the launcher via ``get_executable_command``.
    If ``--subprocess`` is absent from the command and RERUN_DISABLED_TESTS is
    off, the test runs through ``run_test_retries``; otherwise it runs through
    ``retry_shell`` directly.

    Args:
        test_module: the test file and shard to run.
        test_directory: passed to the shell helpers. Presumably the working
            directory, since Python tests are launched by their bare
            ``<name>.py`` path. Also used to locate C++ binaries when
            CPP_TESTS_DIR is unset.
        options: parsed command-line options of this script.
        launcher_cmd: optional command prefix, e.g. an ``mpiexec`` invocation.
        extra_unittest_args: optional list of extra arguments for the test.
        env: subprocess environment; defaults to a copy of ``os.environ``.

    Returns:
        The subprocess exit code. Returns 0 without running anything for C++
        tests under RERUN_DISABLED_TESTS, or when no executable is eligible.

    Raises:
        RuntimeError: if the file is in PYTEST_SKIP_RETRIES but pytest is off.
    """
    env = env or os.environ.copy()
    maybe_set_hip_visible_devies()
    unittest_args = options.additional_unittest_args.copy()
    test_file = test_module.name
    stepcurrent_key = test_file

    is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX)
    is_cpp_test = test_file.startswith(CPP_TEST_PREFIX)
    # NB: Rerun disabled tests depends on pytest-flakefinder and it doesn't work with
    # pytest-cpp atm. We also don't have support to disable C++ test yet, so it's ok
    # to just return successfully here
    if is_cpp_test and RERUN_DISABLED_TESTS:
        print_to_stderr(
            "Skipping C++ tests when running under RERUN_DISABLED_TESTS mode"
        )
        return 0

    # The random suffix makes the stepcurrent key unique per invocation.
    if is_cpp_test:
        stepcurrent_key = f"{test_file}_{os.urandom(8).hex()}"
    else:
        unittest_args.extend(
            [
                f"--shard-id={test_module.shard}",
                f"--num-shards={test_module.num_shards}",
            ]
        )
        stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}"

    if options.verbose:
        unittest_args.append(f'-{"v" * options.verbose}')  # in case of pytest

    if test_file in RUN_PARALLEL_BLOCKLIST:
        unittest_args = [
            arg for arg in unittest_args if not arg.startswith("--run-parallel")
        ]

    if extra_unittest_args:
        assert isinstance(extra_unittest_args, list)
        unittest_args.extend(extra_unittest_args)

    # If using pytest, replace -f with equivalent -x
    if options.pytest:
        unittest_args.extend(
            get_pytest_args(
                options,
                is_cpp_test=is_cpp_test,
                is_distributed_test=is_distributed_test,
            )
        )
        unittest_args.extend(test_module.get_pytest_args())
        unittest_args = ["-x" if arg == "-f" else arg for arg in unittest_args]

    # NB: These features are not available for C++ tests, but there is little incentive
    # to implement it because we have never seen a flaky C++ test before.
    if IS_CI and not is_cpp_test:
        ci_args = ["--import-slow-tests", "--import-disabled-tests"]
        if RERUN_DISABLED_TESTS:
            ci_args.append("--rerun-disabled-tests")
        # use the downloaded test cases configuration, not supported in pytest
        unittest_args.extend(ci_args)

    if test_file in PYTEST_SKIP_RETRIES:
        if not options.pytest:
            raise RuntimeError(
                "A test running without pytest cannot skip retries using "
                "the PYTEST_SKIP_RETRIES set."
            )
        unittest_args = [arg for arg in unittest_args if "--reruns" not in arg]

    executable = get_executable_command(options, is_cpp_test=is_cpp_test)
    if not executable:
        # If there is no eligible executable returning here, it means an unsupported
        # case such as coverage for C++ test. So just returning ok makes sense
        return 0

    if is_cpp_test:
        # C++ tests are not in the regular test directory.
        cpp_binary = test_file.replace(f"{CPP_TEST_PREFIX}/", "")
        if CPP_TESTS_DIR:
            cpp_test = os.path.join(CPP_TESTS_DIR, cpp_binary)
        else:
            cpp_test = os.path.join(
                pathlib.Path(test_directory).parent, CPP_TEST_PATH, cpp_binary
            )
        if sys.platform == "win32":
            cpp_test += ".exe"
        argv = [cpp_test] + unittest_args
    else:
        # Can't call `python -m unittest test_*` here because it doesn't run code
        # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
        argv = [test_file + ".py"] + unittest_args

    report_dir = REPO_ROOT / "test" / "test-reports"
    os.makedirs(report_dir, exist_ok=True)
    if options.pipe_logs:
        log_fd, log_path = tempfile.mkstemp(
            dir=report_dir,
            prefix=f"{sanitize_file_name(str(test_module))}_",
            suffix="_toprint.log",
        )
        os.close(log_fd)

    command = (launcher_cmd or []) + executable + argv
    should_retry = "--subprocess" not in command and not RERUN_DISABLED_TESTS
    is_slow = "slow" in os.environ.get("TEST_CONFIG", "") or "slow" in os.environ.get(
        "BUILD_ENVIRONMENT", ""
    )
    # Slow configurations get twice the timeout of a retried run; the timeout
    # applies only when enabled, and in the non-slow case only to retried
    # sharded tests with a known time.
    if not options.enable_timeout:
        timeout = None
    elif is_slow:
        timeout = THRESHOLD * 6
    elif (
        should_retry
        and isinstance(test_module, ShardedTest)
        and test_module.time is not None
    ):
        timeout = THRESHOLD * 3
    else:
        timeout = None
    print_to_stderr(f"Executing {command} ... [{datetime.now()}]")

    with ExitStack() as stack:
        output = None
        if options.pipe_logs:
            output = stack.enter_context(open(log_path, "w"))

        if should_retry:
            ret_code, was_rerun = run_test_retries(
                command,
                test_directory,
                env,
                timeout,
                stepcurrent_key,
                output,
                options.continue_through_error,
            )
        else:
            command.extend([f"--sc={stepcurrent_key}", "--print-items"])
            ret_code, was_rerun = retry_shell(
                command,
                test_directory,
                stdout=output,
                stderr=output,
                env=env,
                timeout=timeout,
            )

            # Pytest return code 5 means no test is collected. This is needed
            # here as we use pytest directly when running C++ tests. Return
            # code 4 is ok too as this happens when the binary is not a C++
            # test executable. All binary files under build/bin that are not
            # C++ test at the time of this writing have been excluded, but we
            # can accept code 4 too just in case a new non-test binary file
            # comes up in the future.
            if ret_code in (4, 5):
                ret_code = 0

    if options.pipe_logs:
        handle_log_file(
            test_module, log_path, failed=(ret_code != 0), was_rerun=was_rerun
        )
    return ret_code
550

551

552
def run_test_retries(
    command,
    test_directory,
    env,
    timeout,
    stepcurrent_key,
    output,
    continue_through_error,
):
    """Run a pytest command repeatedly, resuming after failures, until it passes or gives up.

    The test is run with -x so it stops at the first failure, then run again
    skipping the previously run tests. This repeats until some test fails 3
    times, the same number of retries we typically give.

    If ``continue_through_error`` is false, we fail fast at that point. If it
    is true, that test is skipped on the next run and we keep going, but the
    final result is still a failure. Progress is tracked through the value
    pytest saves in its stepcurrent cache under ``stepcurrent_key``: the most
    recently run test, which is the one that failed if there was a failure.

    Args:
        command: base command line; stepcurrent and ``--print-items`` flags
            are appended on each run.
        test_directory: forwarded to ``retry_shell``.
        env: subprocess environment, or None. The caller's dict is never mutated.
        timeout: forwarded to ``retry_shell``.
        stepcurrent_key: name of the stepcurrent cache entry for this run.
        output: file receiving subprocess output and progress messages;
            progress messages go to stdout when None.
        continue_through_error: skip consistently failing tests instead of stopping.

    Returns:
        ``(ret_code, was_rerun)``. Returns ``(1, True)`` if any test failed 3
        times. Otherwise ``ret_code`` is the exit code of the last run (pytest
        codes 4 and 5 mapped to 0) and ``was_rerun`` is whether any failure
        triggered a rerun.
    """

    def print_to_file(s):
        print(s, file=output, flush=True)

    num_failures = defaultdict(int)

    print_items = ["--print-items"]
    sc_command = f"--sc={stepcurrent_key}"
    while True:
        ret_code, _ = retry_shell(
            command + [sc_command] + print_items,
            test_directory,
            stdout=output,
            stderr=output,
            env=env,
            timeout=timeout,
            retries=0,  # no retries here, we do it ourselves, this is because it handles timeout exceptions well
        )
        if ret_code in (4, 5):
            ret_code = 0
        if ret_code == 0:
            break  # Got to the end of the test suite successfully
        if ret_code < 0:
            # Negative codes are signals; don't crash on a number we can't name.
            signal_label = SIGNALS_TO_NAMES_DICT.get(-ret_code, f"signal {-ret_code}")
            signal_name = f" ({signal_label})"
        else:
            signal_name = ""
        print_to_file(f"Got exit code {ret_code}{signal_name}")

        # Read what just failed
        try:
            with open(
                REPO_ROOT / ".pytest_cache/v/cache/stepcurrent" / stepcurrent_key
            ) as f:
                current_failure = f.read()
        except FileNotFoundError:
            print_to_file(
                "No stepcurrent file found. Either pytest didn't get to run (e.g. import error)"
                + " or file got deleted (contact dev infra)"
            )
            break

        num_failures[current_failure] += 1
        if num_failures[current_failure] >= 3:
            if not continue_through_error:
                print_to_file("Stopping at first consistent failure")
                break
            sc_command = f"--scs={stepcurrent_key}"
        else:
            sc_command = f"--sc={stepcurrent_key}"
        print_to_file("Retrying...")
        # Print full c++ stack traces during retries.
        # Don't do it for macos inductor tests as it makes them segfault for
        # some reason. command[2] is inspected, so at least 3 elements are needed.
        if not (
            IS_MACOS
            and len(command) > 2
            and command[2].startswith(INDUCTOR_TEST_PREFIX)
        ):
            # Build a new dict: never mutate the caller's env, and when no env
            # was given start from the current environment instead of an empty
            # one (which would drop PATH etc.).
            base_env = os.environ if env is None else env
            env = {**base_env, "TORCH_SHOW_CPP_STACKTRACES": "1"}
        print_items = []  # do not continue printing them, massive waste of space

    # [1:-1] strips the surrounding quote characters from the stored value
    # (presumably a JSON-encoded string; confirm against the pytest plugin).
    consistent_failures = [key[1:-1] for key, n in num_failures.items() if n >= 3]
    flaky_failures = [key[1:-1] for key, n in num_failures.items() if 0 < n < 3]
    if flaky_failures:
        print_to_file(
            "The following tests failed and then succeeded when run in a new process: "
            + f"{flaky_failures}",
        )
    if consistent_failures:
        print_to_file(f"The following tests failed consistently: {consistent_failures}")
        return 1, True
    return ret_code, any(n > 0 for n in num_failures.values())
641

642

643
def run_test_with_subprocess(test_module, test_directory, options):
    """Run ``test_module`` through run_test with ``--subprocess`` appended.

    The presence of ``--subprocess`` also makes run_test skip its
    stepcurrent-based retry loop.
    """
    subprocess_flag = ["--subprocess"]
    return run_test(
        test_module,
        test_directory,
        options,
        extra_unittest_args=subprocess_flag,
    )
647

648

649
def _test_cpp_extensions_aot(test_directory, options, use_ninja):
    """Build the C++ extension test modules ahead of time, then run their tests.

    Wipes and rebuilds ``cpp_extensions`` (and, off Windows,
    ``cpp_extensions/no_python_abi_suffix_test``) with ``setup.py install
    --root ./install``. It then copies ``test_cpp_extensions_aot.py`` to a
    ninja-specific name and runs it with the installed packages on PYTHONPATH.

    Args:
        test_directory: directory containing the C++ extension tests.
        options: parsed options forwarded to run_test.
        use_ninja: build with ninja (USE_NINJA=1) or without it (USE_NINJA=0).

    Returns:
        0 on success; 1 if ninja was requested but is unavailable; otherwise
        the first non-zero build or test exit code.

    PYTHONPATH and USE_NINJA in ``os.environ`` are restored to their previous
    values (or removed if they were unset) before returning.
    """
    if use_ninja:
        try:
            from torch.utils import cpp_extension

            cpp_extension.verify_ninja_availability()
        except RuntimeError:
            print_to_stderr(CPP_EXTENSIONS_ERROR)
            return 1

    # Wipe the build folder, if it exists already
    cpp_extensions_test_dir = os.path.join(test_directory, "cpp_extensions")
    cpp_extensions_test_build_dir = os.path.join(cpp_extensions_test_dir, "build")
    if os.path.exists(cpp_extensions_test_build_dir):
        shutil.rmtree(cpp_extensions_test_build_dir)

    # Build the test cpp extensions modules
    shell_env = os.environ.copy()
    shell_env["USE_NINJA"] = str(1 if use_ninja else 0)
    cmd = [sys.executable, "setup.py", "install", "--root", "./install"]
    return_code = shell(cmd, cwd=cpp_extensions_test_dir, env=shell_env)
    if return_code != 0:
        return return_code
    if sys.platform != "win32":
        return_code = shell(
            cmd,
            cwd=os.path.join(cpp_extensions_test_dir, "no_python_abi_suffix_test"),
            env=shell_env,
        )
        if return_code != 0:
            return return_code

    # "install" the test modules and run tests
    test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja")
    test_copy = os.path.join(test_directory, test_module + ".py")
    # Remember the caller's values so they can be restored exactly afterwards.
    saved_env = {key: os.environ.get(key) for key in ("PYTHONPATH", "USE_NINJA")}
    python_path = saved_env["PYTHONPATH"] or ""
    os.environ["USE_NINJA"] = shell_env["USE_NINJA"]
    try:
        shutil.copyfile(
            os.path.join(test_directory, "test_cpp_extensions_aot.py"), test_copy
        )
        install_directory = ""
        # install directory is the one that is named site-packages
        for root, directories, _ in os.walk(
            os.path.join(cpp_extensions_test_dir, "install")
        ):
            for directory in directories:
                if "-packages" in directory:
                    install_directory = os.path.join(root, directory)

        assert install_directory, "install_directory must not be empty"
        # Skip an empty previous PYTHONPATH: an empty path component would put
        # the current directory on sys.path.
        os.environ["PYTHONPATH"] = os.pathsep.join(
            entry for entry in (install_directory, python_path) if entry
        )
        return run_test(ShardedTest(test_module, 1, 1), test_directory, options)
    finally:
        if os.path.exists(test_copy):
            os.remove(test_copy)
        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
708

709

710
def test_cpp_extensions_aot_ninja(test_module, test_directory, options):
    """Build and run the ahead-of-time C++ extension tests using ninja.

    ``test_module`` is accepted for signature compatibility and ignored.
    """
    use_ninja = True
    return _test_cpp_extensions_aot(test_directory, options, use_ninja)


def test_cpp_extensions_aot_no_ninja(test_module, test_directory, options):
    """Build and run the ahead-of-time C++ extension tests without ninja.

    ``test_module`` is accepted for signature compatibility and ignored.
    """
    use_ninja = False
    return _test_cpp_extensions_aot(test_directory, options, use_ninja)
716

717

718
def _mpiexec_launcher():
    """Build the ``mpiexec`` command prefix used to launch MPI test runs.

    Probes whether this mpiexec accepts ``--allow-run-as-root`` and
    ``--noprefix``. Accepted flags are included; unsupported ones are omitted
    entirely rather than passed as empty-string arguments.
    """
    allowrunasroot_opt = (
        "--allow-run-as-root"
        if subprocess.call(
            'mpiexec --allow-run-as-root -n 1 bash -c ""',
            shell=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
        )
        == 0
        else ""
    )
    # test mpiexec for --noprefix option
    noprefix_opt = (
        "--noprefix"
        if subprocess.call(
            f'mpiexec {allowrunasroot_opt} -n 1 --noprefix bash -c ""',
            shell=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
        )
        == 0
        else ""
    )
    return [
        arg
        for arg in ("mpiexec", "-n", "3", noprefix_opt, allowrunasroot_opt)
        if arg
    ]


def test_distributed(test_module, test_directory, options):
    """Run a distributed test file once per available backend and init method.

    Backends come from DISTRIBUTED_TESTS_CONFIG. Only gloo runs on Windows,
    and mpi runs only when ``mpiexec`` is on PATH and Python is older than 3.9.
    Each backend runs with the env:// init method and then with a file:// init
    method (file only on Windows). Every run gets a fresh temporary directory
    plus backend-specific environment variables, and ``os.environ`` is restored
    afterwards.

    Returns:
        0 if every run passed, otherwise the first non-zero exit code.
    """
    # MPI tests are broken with Python-3.9
    mpi_available = (
        sys.version_info < (3, 9) and shutil.which("mpiexec") is not None
    )
    if options.verbose and not mpi_available:
        print_to_stderr("MPI not available -- MPI backend tests will be skipped")

    config = DISTRIBUTED_TESTS_CONFIG
    for backend, env_vars in config.items():
        if sys.platform == "win32" and backend != "gloo":
            continue
        if backend == "mpi" and not mpi_available:
            continue
        # env:// first, then file:// (explicit, ordered iteration).
        for with_init_file in (False, True):
            if sys.platform == "win32" and not with_init_file:
                continue
            tmp_dir = tempfile.mkdtemp()
            if options.verbose:
                init_name = "file" if with_init_file else "env"
                print_to_stderr(
                    f"Running distributed tests for the {backend} backend "
                    f"with {init_name} init_method"
                )
            old_environ = dict(os.environ)
            # All os.environ mutations happen inside the try so the finally
            # block always restores the environment and removes tmp_dir.
            try:
                os.environ["TEMP_DIR"] = tmp_dir
                os.environ["BACKEND"] = backend
                os.environ["INIT_METHOD"] = "env://"
                os.environ.update(env_vars)
                if with_init_file:
                    if test_module.name == "test_distributed_spawn":
                        init_method = f"{FILE_SCHEMA}{tmp_dir}/"
                    else:
                        init_method = f"{FILE_SCHEMA}{tmp_dir}/shared_init_file"
                    os.environ["INIT_METHOD"] = init_method
                os.mkdir(os.path.join(tmp_dir, "barrier"))
                os.mkdir(os.path.join(tmp_dir, "test_dir"))
                if backend == "mpi":
                    return_code = run_test(
                        test_module,
                        test_directory,
                        options,
                        launcher_cmd=_mpiexec_launcher(),
                    )
                else:
                    return_code = run_test(
                        test_module,
                        test_directory,
                        options,
                        extra_unittest_args=["--subprocess"],
                    )
                if return_code != 0:
                    return return_code
            finally:
                shutil.rmtree(tmp_dir)
                os.environ.clear()
                os.environ.update(old_environ)
    return 0
801

802

803
def run_doctests(test_module, test_directory, options):
    """
    Assumes the incoming test module is called doctest, and simply executes the
    xdoctest runner on the torch library itself.

    ``test_module`` and ``test_directory`` are not used; they are accepted so
    this function has the same calling convention as the other
    CUSTOM_HANDLERS entries. Enabled doctest features are exported as
    ``TORCH_DOCTEST_*=1`` environment variables (this mutates ``os.environ``
    of the current process) before xdoctest runs.

    Returns:
        1 if xdoctest reports any failed doctest, otherwise 0.
    """
    import xdoctest

    pkgpath = os.path.dirname(torch.__file__)

    exclude_module_list = ["torch._vendor.*"]
    enabled = {
        # TODO: expose these options to the user
        # For now disable all feature-conditional tests
        # 'lapack': 'auto',
        # 'cuda': 'auto',
        # 'cuda1': 'auto',
        # 'qengine': 'auto',
        "lapack": 0,
        "cuda": 0,
        "cuda1": 0,
        "qengine": 0,
        "autograd_profiler": 0,
        "cpp_ext": 0,
        "monitor": 0,
        "onnx": "auto",
    }

    # Resolve every "auto" entry to a concrete bool by probing for the feature.
    # An unresolved "auto" is a truthy string and would enable the feature
    # below even when the probe failed.
    if enabled["cuda"] == "auto":
        enabled["cuda"] = torch.cuda.is_available()

    if enabled["cuda1"] == "auto":
        enabled["cuda1"] = (
            torch.cuda.is_available() and torch.cuda.device_count() > 1
        )

    if enabled["lapack"] == "auto":
        enabled["lapack"] = bool(torch._C.has_lapack)

    if enabled["qengine"] == "auto":
        try:
            # Is there a better check if quantization is enabled?
            import torch.ao.nn.quantized as nnq  # NOQA: F401

            torch.backends.quantized.engine = "qnnpack"
            torch.backends.quantized.engine = "fbgemm"
        except (ImportError, RuntimeError):
            enabled["qengine"] = False
        else:
            enabled["qengine"] = True

    if enabled["onnx"] == "auto":
        try:
            import onnx  # NOQA: F401
            import onnxruntime  # NOQA: F401
            import onnxscript  # NOQA: F401
        except ImportError:
            exclude_module_list.append("torch.onnx.*")
            enabled["onnx"] = False
        else:
            enabled["onnx"] = True

    # Set doctest environment variables for every enabled feature.
    # TODO: could try to also enable TORCH_DOCTEST_QUANTIZED_DYNAMIC,
    # TORCH_DOCTEST_ANOMALY, TORCH_DOCTEST_AUTOGRAD, TORCH_DOCTEST_HUB,
    # TORCH_DOCTEST_DATALOADER and TORCH_DOCTEST_FUTURES.
    feature_env_vars = {
        "cuda": "TORCH_DOCTEST_CUDA",
        "cuda1": "TORCH_DOCTEST_CUDA1",
        "lapack": "TORCH_DOCTEST_LAPACK",
        "qengine": "TORCH_DOCTEST_QENGINE",
        "autograd_profiler": "TORCH_DOCTEST_AUTOGRAD_PROFILER",
        "cpp_ext": "TORCH_DOCTEST_CPP_EXT",
        "monitor": "TORCH_DOCTEST_MONITOR",
        "onnx": "TORCH_DOCTEST_ONNX",
    }
    for feature, env_var in feature_env_vars.items():
        if enabled[feature]:
            os.environ[env_var] = "1"

    xdoctest_config = {
        "global_exec": r"\n".join(
            [
                "from torch import nn",
                "import torch.nn.functional as F",
                "import torch",
            ]
        ),
        "analysis": "static",  # set to "auto" to test doctests in compiled modules
        "style": "google",
        "options": "+IGNORE_WHITESPACE",
    }
    xdoctest_verbose = max(1, options.verbose)
    run_summary = xdoctest.runner.doctest_module(
        os.fspath(pkgpath),
        config=xdoctest_config,
        verbose=xdoctest_verbose,
        command=options.xdoctest_command,
        argv=[],
        exclude=exclude_module_list,
    )
    return 1 if run_summary.get("n_failed", 0) else 0
928

929

930
def sanitize_file_name(file: str):
    """Turn a test description into a flat, space-free file name.

    Both path separators become ``.`` and spaces become ``_``.
    """
    replacements = str.maketrans({"\\": ".", "/": ".", " ": "_"})
    return file.translate(replacements)
932

933

934
def handle_log_file(
    test: ShardedTest, file_path: str, failed: bool, was_rerun: bool
) -> None:
    """Move a test's log into test/test-reports and echo it to stderr.

    The log at ``file_path`` is renamed to a uniquely suffixed file under
    ``REPO_ROOT/test/test-reports``. If the test failed, was rerun, or the log
    contains an ``=== RERUNS ===`` section, the whole log is printed. Otherwise
    only the "Running ... items in this shard:" lines are printed.
    """
    label = str(test)
    with open(file_path, errors="ignore") as log:
        log_text = log.read()

    archived_name = "test/test-reports/" + sanitize_file_name(
        f"{label}_{os.urandom(8).hex()}_.log"
    )
    os.rename(file_path, REPO_ROOT / archived_name)

    # Test-level retries are only detectable from the log text (short of
    # re-parsing the XML report), hence the "=== RERUNS ===" check.
    needs_full_log = failed or was_rerun or "=== RERUNS ===" in log_text
    if needs_full_log:
        print_to_stderr(f"\nPRINTING LOG FILE of {label} ({archived_name})")
        print_to_stderr(log_text)
        print_to_stderr(f"FINISHED PRINTING LOG FILE of {label} ({archived_name})\n")
        return

    print_to_stderr(
        f"\n{label} was successful, full logs can be found in artifacts with path {archived_name}"
    )
    shard_summary = (
        line.rstrip()
        for line in log_text.splitlines()
        if re.search("Running .* items in this shard:", line)
    )
    for summary_line in shard_summary:
        print_to_stderr(summary_line)
    print_to_stderr("")
962

963

964
def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False):
    """Assemble the pytest command-line arguments for a single test module.

    Args:
        options: parsed run_test.py options (only ``pytest_k_expr`` is read).
        is_cpp_test: C++ tests are run in parallel via pytest-xdist.
        is_distributed_test: lowers the flake-finder repeat count.

    Returns:
        A new list of argument strings. The rerun/flake options come last.
    """
    if not RERUN_DISABLED_TESTS:
        # Normal mode: stop at the first failure (-x) and retry a failing test
        # up to two more times.
        rerun_options = ["-x", "--reruns=2"]
    else:
        # Rerun-disabled-tests mode runs each test repeatedly to judge its
        # flakiness (50 runs by default). Distributed and ASAN jobs are too slow
        # for that and would time out after 3+ hours, so they get 15 runs. That
        # still yields enough samples every 2 weeks for the Rockset query
        # (15 x 14 = 210 >= 150).
        flake_runs = 15 if (is_distributed_test or TEST_WITH_ASAN) else 50
        rerun_options = ["--flake-finder", f"--flake-runs={flake_runs}"]

    pytest_args = ["-vv", "-rfEX"]
    if is_cpp_test:
        # Run C++ tests in parallel with pytest-xdist; running them one by one
        # through run_test is much slower.
        pytest_args += ["-n", str(NUM_PROCS)]
        if IS_CI:
            # C++ tests don't go through common_utils, so request the XML
            # report explicitly here.
            pytest_args += ["--junit-xml-reruns", get_report_path(pytest=True)]
    else:
        # Python tests: a custom pytest shard conflicts with the xdist plugin,
        # so disable it.
        pytest_args += ["-p", "no:xdist", "--use-pytest"]

    if options.pytest_k_expr:
        pytest_args += ["-k", options.pytest_k_expr]

    return pytest_args + rerun_options
1003

1004

1005
# Test modules that need a dedicated runner instead of the generic run_test.
# run_test_module looks handlers up with CUSTOM_HANDLERS.get(name, run_test),
# and must_serial forces every module listed here to run serially.
CUSTOM_HANDLERS = dict(
    [
        # CUDA tests that must run in a fresh subprocess
        ("test_cuda_primary_ctx", run_test_with_subprocess),
        ("test_cuda_nvml_based_avail", run_test_with_subprocess),
        ("test_cuda_trace", run_test_with_subprocess),
        # C++ extension ahead-of-time build tests
        ("test_cpp_extensions_aot_no_ninja", test_cpp_extensions_aot_no_ninja),
        ("test_cpp_extensions_aot_ninja", test_cpp_extensions_aot_ninja),
        # Distributed tests driven across backends
        ("distributed/test_distributed_spawn", test_distributed),
        ("distributed/algorithms/quantization/test_quantization", test_distributed),
        # Distributed tests that run in a subprocess
        ("distributed/test_c10d_nccl", run_test_with_subprocess),
        ("distributed/test_c10d_gloo", run_test_with_subprocess),
        ("distributed/test_c10d_ucc", run_test_with_subprocess),
        ("distributed/test_c10d_common", run_test_with_subprocess),
        ("distributed/test_c10d_spawn_gloo", run_test_with_subprocess),
        ("distributed/test_c10d_spawn_nccl", run_test_with_subprocess),
        ("distributed/test_c10d_spawn_ucc", run_test_with_subprocess),
        ("distributed/test_store", run_test_with_subprocess),
        ("distributed/test_pg_wrapper", run_test_with_subprocess),
        ("distributed/rpc/test_faulty_agent", run_test_with_subprocess),
        ("distributed/rpc/test_tensorpipe_agent", run_test_with_subprocess),
        ("distributed/rpc/test_share_memory", run_test_with_subprocess),
        ("distributed/rpc/cuda/test_tensorpipe_agent", run_test_with_subprocess),
        # xdoctest over the installed torch package
        ("doctests", run_doctests),
    ]
)
1028

1029

1030
# NOTE(review): not referenced in the code visible here; presumably the test
# modules that must not get pytest's automatic retries — confirm at the use site.
PYTEST_SKIP_RETRIES = {"test_public_bindings"}
1031

1032

1033
def parse_args():
    """Build the run_test.py argument parser and parse the command line.

    The parser inherits the shared options of ``common_parser``. Several
    defaults depend on environment variables (CONTINUE_THROUGH_ERROR,
    VERBOSE_TEST_LOGS, NO_TEST_TIMEOUT, NO_TD, BRANCH) and on the IS_CI /
    TEST_WITH_CROSSREF flags.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Run the PyTorch unit test suite",
        epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
        formatter_class=argparse.RawTextHelpFormatter,
        parents=[common_parser],
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=0,
        help="Print verbose information and test-by-test results",
    )
    parser.add_argument("--jit", action="store_true", help="run all jit tests")
    parser.add_argument(
        "--distributed-tests",
        action="store_true",
        help="Run all distributed tests",
    )
    parser.add_argument(
        "--functorch",
        action="store_true",
        help=(
            "If this flag is present, we will only run functorch tests. "
            "If this flag is not present, we will run all tests "
            "(including functorch tests)."
        ),
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help=(
            "If this flag is present, we will only run test_mps, test_metal "
            "and test_modules"
        ),
    )
    parser.add_argument(
        "--xpu",
        action="store_true",
        help=(
            "If this flag is present, we will run xpu tests except those in "
            "XPU_BLOCKLIST"
        ),
    )
    parser.add_argument(
        "--cpp",
        action="store_true",
        help=("If this flag is present, we will only run C++ tests"),
    )
    parser.add_argument(
        "-core",
        "--core",
        action="store_true",
        help="Only run core tests, or tests that validate PyTorch's ops, modules, "
        "and autograd. They are defined by CORE_TEST_LIST.",
    )
    parser.add_argument(
        "--onnx",
        action="store_true",
        help=(
            "Only run ONNX tests, or tests that validate PyTorch's ONNX export. "
            "If this flag is not present, we will exclude ONNX tests."
        ),
    )
    parser.add_argument(
        "-k",
        "--pytest-k-expr",
        default="",
        help="Pass to pytest as its -k expr argument",
    )
    parser.add_argument(
        "-c",
        "--coverage",
        action="store_true",
        help="enable coverage",
        default=PYTORCH_COLLECT_COVERAGE,
    )
    parser.add_argument(
        "-i",
        "--include",
        nargs="+",
        choices=TestChoices(TESTS),
        default=TESTS,
        metavar="TESTS",
        help="select a set of tests to include (defaults to ALL tests)."
        " tests must be a part of the TESTS list defined in run_test.py",
    )
    parser.add_argument(
        "-x",
        "--exclude",
        nargs="+",
        choices=TESTS,
        metavar="TESTS",
        default=[],
        help="select a set of tests to exclude",
    )
    parser.add_argument(
        "--ignore-win-blocklist",
        action="store_true",
        help="always run blocklisted windows tests",
    )
    # NS: Disable target determination until it can be made more reliable
    # parser.add_argument(
    #     "--determine-from",
    #     help="File of affected source filenames to determine which tests to run.",
    # )
    parser.add_argument(
        "--continue-through-error",
        "--keep-going",
        action="store_true",
        help="Runs the full test suite despite one of the tests failing",
        default=strtobool(os.environ.get("CONTINUE_THROUGH_ERROR", "False")),
    )
    parser.add_argument(
        "--pipe-logs",
        action="store_true",
        help="Print logs to output file while running tests.  True if in CI and env var is not set",
        default=IS_CI and not strtobool(os.environ.get("VERBOSE_TEST_LOGS", "False")),
    )
    parser.add_argument(
        "--enable-timeout",
        action="store_true",
        help="Set a timeout based on the test times json file.  Only works if there are test times available",
        default=IS_CI and not strtobool(os.environ.get("NO_TEST_TIMEOUT", "False")),
    )
    parser.add_argument(
        "--enable-td",
        action="store_true",
        help="Enables removing tests based on TD",
        default=IS_CI
        and TEST_WITH_CROSSREF
        and os.getenv("BRANCH", "") != "main"
        and not strtobool(os.environ.get("NO_TD", "False")),
    )
    parser.add_argument(
        "additional_unittest_args",
        nargs="*",
        help="additional arguments passed through to unittest, e.g., "
        "python run_test.py -i sparse -- TestSparse.test_factory_size_check",
    )
    parser.add_argument(
        "--shard",
        nargs=2,
        type=int,
        help="runs a shard of the tests (taking into account other selections), e.g., "
        "--shard 2 3 will break up the selected tests into 3 shards and run the tests "
        "in the 2nd shard (the first number should not exceed the second)",
    )
    parser.add_argument(
        "--exclude-jit-executor",
        action="store_true",
        help="exclude tests that are run for a specific jit config",
    )
    parser.add_argument(
        "--exclude-torch-export-tests",
        action="store_true",
        help="exclude torch export tests",
    )
    parser.add_argument(
        "--exclude-distributed-tests",
        action="store_true",
        help="exclude distributed tests",
    )
    parser.add_argument(
        "--exclude-inductor-tests",
        action="store_true",
        help="exclude inductor tests",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Only list the test that will run.",
    )
    parser.add_argument(
        "--xdoctest-command",
        default="all",
        help=(
            "Control the specific doctest action. "
            "Use 'list' to simply parse doctests and check syntax. "
            "Use 'all' to execute all doctests or specify a specific "
            "doctest to run"
        ),
    )
    # store_false: options.no_translation_validation defaults to True and
    # becomes False when the flag is passed.
    parser.add_argument(
        "--no-translation-validation",
        action="store_false",
        help="Run tests without translation validation.",
    )

    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--dynamo",
        action="store_true",
        help="Run tests with TorchDynamo+EagerBackend turned on",
    )
    group.add_argument(
        "--inductor",
        action="store_true",
        help="Run tests with TorchInductor turned on",
    )

    return parser.parse_args()
1236

1237

1238
def exclude_tests(
1239
    exclude_list, selected_tests, exclude_message=None, exact_match=False
1240
):
1241
    for exclude_test in exclude_list:
1242
        tests_copy = selected_tests[:]
1243
        for test in tests_copy:
1244
            if (
1245
                not exact_match and test.startswith(exclude_test)
1246
            ) or test == exclude_test:
1247
                if exclude_message is not None:
1248
                    print_to_stderr(f"Excluding {test} {exclude_message}")
1249
                selected_tests.remove(test)
1250
    return selected_tests
1251

1252

1253
def must_serial(file: Union[str, ShardedTest]) -> bool:
    """Return True when ``file`` must run alone rather than in the worker pool.

    Accepts a test module name or a ShardedTest (its ``name`` is used).
    """
    name = file.name if isinstance(file, ShardedTest) else file

    if os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1":
        return True
    # Whole distributed CI configs, and any distributed module, run serially.
    if DISTRIBUTED_TEST_PREFIX in os.getenv("TEST_CONFIG", ""):
        return True
    if DISTRIBUTED_TEST_PREFIX in name:
        return True

    serial_groups = (
        CUSTOM_HANDLERS,
        RUN_PARALLEL_BLOCKLIST,
        CI_SERIAL_LIST,
        JIT_EXECUTOR_TESTS,
        ONNX_SERIAL_LIST,
    )
    if any(name in group for group in serial_groups):
        return True

    # With a single process there is nothing to parallelize.
    return NUM_PROCS == 1
1267

1268

1269
def can_run_in_pytest(test):
    """Return True unless PYTORCH_TEST_DO_NOT_USE_PYTEST is set to something other than "0".

    ``test`` is accepted for interface compatibility but is not inspected.
    """
    opt_out = os.environ.get("PYTORCH_TEST_DO_NOT_USE_PYTEST", "0")
    return opt_out == "0"
1271

1272

1273
def get_selected_tests(options) -> List[str]:
    """Compute which test modules to run from the parsed CLI options.

    The steps are:
    - Start from ``options.include``.
    - Apply the "only" filters (--jit, --distributed-tests, --core,
      --functorch, --cpp, --mps, --xpu, --onnx).
    - Drop everything accumulated in ``options.exclude``, plus the
      platform/build-specific blocklists.
    - Map each remaining name through ``parse_test_module``.

    Note: ``options.exclude`` is extended in place.
    """
    # Work on a copy: exclude_tests() removes entries in place, and
    # options.include defaults to the shared module-level TESTS list.
    selected_tests = list(options.include)

    # filter if there's JIT only and distributed only test options
    if options.jit:
        selected_tests = list(
            filter(lambda test_name: "jit" in test_name, selected_tests)
        )

    if options.distributed_tests:
        selected_tests = list(
            filter(lambda test_name: test_name in DISTRIBUTED_TESTS, selected_tests)
        )

    # Filter to only run core tests when --core option is specified
    if options.core:
        selected_tests = list(
            filter(lambda test_name: test_name in CORE_TEST_LIST, selected_tests)
        )

    # Filter to only run functorch tests when --functorch option is specified
    if options.functorch:
        selected_tests = [tname for tname in selected_tests if tname in FUNCTORCH_TESTS]

    if options.cpp:
        selected_tests = [tname for tname in selected_tests if tname in CPP_TESTS]
    else:
        # Exclude all C++ tests otherwise as they are still handled differently
        # than Python tests at the moment
        options.exclude.extend(CPP_TESTS)

    if options.mps:
        selected_tests = ["test_mps", "test_metal", "test_modules"]
    else:
        # Exclude all mps tests otherwise
        options.exclude.extend(["test_mps", "test_metal"])

    if options.xpu:
        selected_tests = exclude_tests(XPU_BLOCKLIST, selected_tests, "on XPU")
    else:
        # Exclude all xpu specific tests otherwise
        options.exclude.extend(XPU_TEST)

    # Filter to only run onnx tests when --onnx option is specified
    onnx_tests = [tname for tname in selected_tests if tname in ONNX_TESTS]
    if options.onnx:
        selected_tests = onnx_tests
    else:
        # Exclude all onnx tests otherwise
        options.exclude.extend(onnx_tests)

    # process exclusion
    if options.exclude_jit_executor:
        options.exclude.extend(JIT_EXECUTOR_TESTS)

    if options.exclude_distributed_tests:
        options.exclude.extend(DISTRIBUTED_TESTS)

    if options.exclude_inductor_tests:
        options.exclude.extend(INDUCTOR_TESTS)

    if options.exclude_torch_export_tests:
        options.exclude.extend(TORCH_EXPORT_TESTS)

    # these tests fail with CUDA 11.6; temporarily disabled.
    # issue https://github.com/pytorch/pytorch/issues/75375
    if torch.version.cuda is not None:
        options.exclude.extend(["distributions/test_constraints"])

    # these tests fail on Python 3.12; temporarily disabled
    if sys.version_info >= (3, 12):
        options.exclude.extend(INDUCTOR_TESTS)
        options.exclude.extend(DYNAMO_TESTS)
        options.exclude.extend(
            [
                "functorch/test_dims",
                "functorch/test_rearrange",
                "functorch/test_parsing",
                "functorch/test_memory_efficient_fusion",
                "torch_np/numpy_tests/core/test_multiarray",
            ]
        )

    selected_tests = exclude_tests(options.exclude, selected_tests)

    if sys.platform == "win32" and not options.ignore_win_blocklist:
        # Local copy so repeated calls don't keep appending to the global list.
        windows_blocklist = list(WINDOWS_BLOCKLIST)
        target_arch = os.environ.get("VSCMD_ARG_TGT_ARCH")
        if target_arch != "x64":
            windows_blocklist.extend(
                [
                    "cpp_extensions_aot_no_ninja",
                    "cpp_extensions_aot_ninja",
                    "cpp_extensions_jit",
                    "jit",
                    "jit_fuser",
                ]
            )

        selected_tests = exclude_tests(windows_blocklist, selected_tests, "on Windows")

    elif TEST_WITH_ROCM:
        selected_tests = exclude_tests(ROCM_BLOCKLIST, selected_tests, "on ROCm")

    # skip all distributed tests if distributed package is not available.
    if not dist.is_available():
        selected_tests = exclude_tests(
            DISTRIBUTED_TESTS,
            selected_tests,
            "PyTorch is built without distributed support.",
        )

    # skip tests that require LAPACK when it's not available
    if not torch._C.has_lapack:
        selected_tests = exclude_tests(
            TESTS_REQUIRING_LAPACK,
            selected_tests,
            "PyTorch is built without LAPACK support.",
        )

    if TEST_WITH_SLOW_GRADCHECK:
        selected_tests = exclude_tests(
            TESTS_NOT_USING_GRADCHECK,
            selected_tests,
            "Running in slow gradcheck mode, skipping tests "
            "that don't use gradcheck.",
            exact_match=True,
        )

    selected_tests = [parse_test_module(x) for x in selected_tests]
    return selected_tests
1398

1399

1400
def load_test_times_from_file(
    file: Union[str, os.PathLike],
) -> Dict[str, Any]:
    """Load previously recorded test times used to make sharding decisions.

    ``file`` is joined onto REPO_ROOT. Based on the lookups below, the JSON
    looks like ``{build_environment: {test_config: times}}`` with a
    ``"default"`` build environment as fallback. The entry is picked in this
    order:
    1. The BUILD_ENVIRONMENT / TEST_CONFIG env vars.
    2. ``default`` / TEST_CONFIG.
    3. ``default`` / ``default``.

    Returns an empty dict (round-robin sharding) when the file does not exist
    or has no usable default entry.
    """
    path = os.path.join(str(REPO_ROOT), file)
    if not os.path.exists(path):
        print_to_stderr(
            f"::warning:: Failed to find test times file `{path}`. Using round robin sharding."
        )
        return {}

    with open(path) as f:
        test_times_file = cast(Dict[str, Any], json.load(f))
    build_environment = os.environ.get("BUILD_ENVIRONMENT")
    test_config = os.environ.get("TEST_CONFIG")
    # A missing "default" section used to raise KeyError; degrade gracefully instead.
    default_times = test_times_file.get("default", {})
    if test_config in test_times_file.get(build_environment, {}):
        print_to_stderr("Found test times from artifacts")
        return test_times_file[build_environment][test_config]
    elif test_config in default_times:
        print_to_stderr(
            f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
            f" and {test_config} test config. Using default build env and {test_config} test config instead."
        )
        return default_times[test_config]
    elif "default" in default_times:
        print_to_stderr(
            f"::warning:: Gathered no stats from artifacts for {build_environment} build env"
            f" and {test_config} test config. Using default build env and default test config instead."
        )
        return default_times["default"]
    else:
        print_to_stderr(
            f"::warning:: No default test times found in `{path}`. Using round robin sharding."
        )
        return {}
1430

1431

1432
def load_test_file_times(
    file: Union[str, os.PathLike] = ADDITIONAL_CI_FILES_FOLDER / TEST_TIMES_FILE,
) -> Dict[str, float]:
    """Load the per-test-file times used for sharding.

    The result maps str to float (presumably test file name to duration;
    verify against the producer of the JSON). ``file`` is resolved relative to
    REPO_ROOT by load_test_times_from_file, which returns an empty dict when the
    file is missing. The default argument is built with the ``/`` operator and
    is therefore a path object, not a str, so the annotation accepts both.
    """
    return cast(Dict[str, float], load_test_times_from_file(file))
1436

1437

1438
def load_test_class_times(
    file: Union[str, os.PathLike] = ADDITIONAL_CI_FILES_FOLDER / TEST_CLASS_TIMES_FILE,
) -> Dict[str, Dict[str, float]]:
    """Load the per-test-class times used for sharding.

    The result is a two-level str -> str -> float mapping (presumably test
    file -> test class -> duration; verify against the producer of the JSON).
    ``file`` is resolved relative to REPO_ROOT by load_test_times_from_file,
    which returns an empty dict when the file is missing. The default argument
    is built with the ``/`` operator and is therefore a path object, not a str,
    so the annotation accepts both.
    """
    return cast(Dict[str, Dict[str, float]], load_test_times_from_file(file))
1442

1443

1444
def get_sharding_opts(options) -> Tuple[int, int]:
    """Return ``(which_shard, num_shards)`` from ``options.shard``.

    Both values are 1-based. When no ``--shard`` was given, everything runs in
    a single shard, ``(1, 1)``. Malformed values raise AssertionError.
    """
    if not options.shard:
        return 1, 1

    assert len(options.shard) == 2, "Unexpected shard format"
    assert min(options.shard) > 0, "Shards must be positive numbers"
    shard_index, shard_total = options.shard
    assert (
        shard_index <= shard_total
    ), "Selected shard must be less than or equal to total number of shards"
    return shard_index, shard_total
1455

1456

1457
def do_sharding(
    options,
    selected_tests: Sequence[TestRun],
    test_file_times: Dict[str, float],
    test_class_times: Dict[str, Dict[str, float]],
    sort_by_time: bool = True,
) -> Tuple[float, List[ShardedTest]]:
    """Split ``selected_tests`` into shards and return the current shard's entry.

    The shard count and the (1-based) current shard come from ``--shard``
    via get_sharding_opts. The returned value is the element of
    ``calculate_shards`` for that shard; callers unpack it as
    ``(estimated_time, sharded_tests)``.
    """
    current_shard, total_shards = get_sharding_opts(options)

    all_shards = calculate_shards(
        total_shards,
        selected_tests,
        test_file_times,
        test_class_times=test_class_times,
        must_serial=must_serial,
        sort_by_time=sort_by_time,
    )
    # Shard numbers are 1-based, the list of shards is 0-based.
    shard_position = current_shard - 1
    return all_shards[shard_position]
1476

1477

1478
# A failed test run: the TestRun that failed plus a human-readable message.
# The message is printed by run_tests and is also used as the RuntimeError text
# when a serial test fails without --keep-going.
TestFailure = NamedTuple("TestFailure", [("test", TestRun), ("message", str)])
1481

1482

1483
def run_test_module(
    test: ShardedTest, test_directory: str, options
) -> Optional[TestFailure]:
    """Run one sharded test module and report its outcome.

    The handler is CUSTOM_HANDLERS[test.name] when present, otherwise run_test.
    Any exception is converted into a TestFailure rather than propagated.

    Returns:
        None on success (return code 0). Otherwise a TestFailure whose message
        includes the signal name when the child was killed by a signal.
    """
    try:
        maybe_set_hip_visible_devies()

        test_name = test.name

        # Printing the date here can help diagnose which tests are slow
        print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]")
        handler = CUSTOM_HANDLERS.get(test_name, run_test)
        return_code = handler(test, test_directory, options)
        assert isinstance(return_code, int) and not isinstance(
            return_code, bool
        ), f"While running {str(test)} got non integer return code {return_code}"
        if return_code == 0:
            return None

        message = f"{str(test)} failed!"
        if return_code < 0:
            # subprocess.Popen returns the child process' exit signal as
            # return code -N, where N is the signal number.
            signal_number = -return_code
            # Fall back gracefully for signals missing from the lookup table
            # instead of raising KeyError and masking the real failure.
            signal_name = SIGNALS_TO_NAMES_DICT.get(
                signal_number, f"unknown signal {signal_number}"
            )
            message += f" Received signal: {signal_name}"
        return TestFailure(test.test, message)
    except Exception as e:
        return TestFailure(test.test, f"{str(test)} failed! {e}")
1510

1511

1512
def run_tests(
    selected_tests: List[ShardedTest],
    test_directory: str,
    options,
    failures: List[TestFailure],
) -> None:
    """Run ``selected_tests``, appending every failure to ``failures``.

    Tests for which ``must_serial`` is true run one after another in this
    process. The rest are dispatched to a ``spawn`` multiprocessing pool of
    NUM_PROCS workers, with NUM_PARALLEL_PROCS exported while they run.

    Unless ``options.continue_through_error`` or RERUN_DISABLED_TESTS is set:
    - The first serial failure raises RuntimeError.
    - The first parallel failure terminates the pool.
    """
    if not selected_tests:
        return

    # parallel = in parallel with other files
    # serial = this file on its own.  The file might still be run in parallel with itself (ex test_ops)
    # Single-pass partition (the old "x not in parallel_list" scan was O(n^2)).
    selected_tests_parallel: List[ShardedTest] = []
    selected_tests_serial: List[ShardedTest] = []
    for sharded_test in selected_tests:
        if must_serial(sharded_test):
            selected_tests_serial.append(sharded_test)
        else:
            selected_tests_parallel.append(sharded_test)

    # NB: This is a hack to make conftest.py available on CPP_TESTS_DIR. We should
    # see if the file could be turned into a full-fledged pytest plugin instead.
    # Done before the pool exists so a failing copy cannot leak worker processes.
    cpp_conftest_file = os.path.join(CPP_TESTS_DIR, "conftest.py")
    if (
        options.cpp
        and os.path.isdir(CPP_TESTS_DIR)
        and not os.path.exists(cpp_conftest_file)
    ):
        # Take the conftest file from the test directory
        shutil.copy(os.path.join(test_directory, "conftest.py"), cpp_conftest_file)

    # See Note [ROCm parallel CI testing]
    pool = get_context("spawn").Pool(
        NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1
    )

    def handle_error_messages(failure: Optional[TestFailure]) -> bool:
        """Record and print ``failure``; return True if there was one."""
        if failure is None:
            return False
        failures.append(failure)
        print_to_stderr(failure.message)
        return True

    def parallel_test_completion_callback(failure):
        test_failed = handle_error_messages(failure)
        if (
            test_failed
            and not options.continue_through_error
            and not RERUN_DISABLED_TESTS
        ):
            pool.terminate()

    parallel_env_set = False
    try:
        for test in selected_tests_serial:
            options_clone = copy.deepcopy(options)
            if can_run_in_pytest(test):
                options_clone.pytest = True
            failure = run_test_module(test, test_directory, options_clone)
            test_failed = handle_error_messages(failure)
            if (
                test_failed
                and not options.continue_through_error
                and not RERUN_DISABLED_TESTS
            ):
                raise RuntimeError(
                    failure.message
                    + "\n\nTip: You can keep running tests even on failure by "
                    "passing --keep-going to run_test.py.\n"
                    "If running on CI, add the 'keep-going' label to "
                    "your PR and rerun your jobs."
                )

        os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
        parallel_env_set = True
        for test in selected_tests_parallel:
            options_clone = copy.deepcopy(options)
            if can_run_in_pytest(test):
                options_clone.pytest = True
            pool.apply_async(
                run_test_module,
                args=(test, test_directory, options_clone),
                callback=parallel_test_completion_callback,
            )
        pool.close()
        pool.join()

    finally:
        pool.terminate()
        pool.join()
        # Clean up on the error path too, not only after a successful join.
        if parallel_env_set:
            os.environ.pop("NUM_PARALLEL_PROCS", None)
1600

1601

1602
def check_pip_packages() -> None:
    """Exit with status 1 if a required pytest plugin is not installed.

    Only the first missing package (in the order listed) is reported.
    """
    required = (
        "pytest-rerunfailures",
        "pytest-flakefinder",
        "pytest-xdist",
    )
    installed = {pkg.key for pkg in pkg_resources.working_set}
    missing = [name for name in required if name not in installed]
    if not missing:
        return
    print_to_stderr(
        f"Missing pip dependency: {missing[0]}, please run `pip install -r .ci/docker/requirements-ci.txt`"
    )
    sys.exit(1)
1615

1616

1617
def main():
    """Entry point of run_test.py: select, shard, run and report tests.

    In order:
      1. Check required pytest plugins and parse command-line options.
      2. Tag all emitted metrics with this job's shard index and shard count.
      3. Select tests, apply target-determination (TD) prioritizations, and
         split them into a batch to run (``include``) and an ``exclude`` batch;
         both are sharded and printed. Only the ``include`` batch is run.
      4. Return early on ``--dry-run``; otherwise set the test-mode environment
         variables and run the batch.
      5. Always (even if running raised) combine coverage data when
         ``--coverage`` is set and, on CI, emit a metric per failed test.
      6. Print every failure and exit with status 1, unless disabled tests are
         being rerun (in which case failures are expected).
    """
    # Sampled before any setup work so the "seconds after initiating testing"
    # message below reports how long selection/sharding/setup actually took.
    start_time = time.time()

    check_pip_packages()

    options = parse_args()

    # Include sharding info in all metrics
    which_shard, num_shards = get_sharding_opts(options)
    add_global_metric("shard", which_shard)
    add_global_metric("num_shards", num_shards)

    test_directory = str(REPO_ROOT / "test")
    selected_tests = get_selected_tests(options)

    # Load target-determination (TD) prioritizations and apply them to the
    # tests selected above.
    test_prioritizations = import_results()
    test_prioritizations.amend_tests(selected_tests)

    os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)

    # Wipe stale coverage data unless PYTORCH_COLLECT_COVERAGE is set, in which
    # case existing data is loaded and merged in the `finally` block below.
    if options.coverage and not PYTORCH_COLLECT_COVERAGE:
        shell(["coverage", "erase"])

    if IS_CI:
        # downloading test cases configuration to local environment
        get_test_case_configs(dirpath=test_directory)

    # Recorded per-file / per-class test durations, passed to do_sharding.
    test_file_times_dict = load_test_file_times()
    test_class_times_dict = load_test_class_times()

    class TestBatch:
        """Defines a set of tests with similar priority that should be run together on the current shard"""

        name: str
        time: float  # estimated duration in seconds (rendered as minutes by __str__)
        sharded_tests: List[ShardedTest]
        failures: List[TestFailure]

        def __init__(
            self, name: str, raw_tests: Sequence[TestRun], should_sort_shard: bool
        ):
            self.name = name
            self.failures = []
            self.time, self.sharded_tests = do_sharding(
                options,
                raw_tests,
                test_file_times_dict,
                test_class_times_dict,
                sort_by_time=should_sort_shard,
            )

        def __str__(self):
            s = f"Name: {self.name} (est. time: {round(self.time / 60, 2)}min)\n"
            serial = [test for test in self.sharded_tests if must_serial(test)]
            parallel = [test for test in self.sharded_tests if not must_serial(test)]
            s += f"  Serial tests ({len(serial)}):\n"
            s += "".join(f"    {test}\n" for test in serial)
            s += f"  Parallel tests ({len(parallel)}):\n"
            s += "".join(f"    {test}\n" for test in parallel)
            return s.strip()

    percent_to_run = 25 if options.enable_td else 100
    print_to_stderr(
        f"Running {percent_to_run}% of tests based on TD"
        if options.enable_td
        else "Running all tests"
    )
    include, exclude = test_prioritizations.get_top_per_tests(percent_to_run)

    test_batch = TestBatch("tests to run", include, False)
    # Sharded only so it can be printed; the excluded batch is never run.
    test_batch_exclude = TestBatch("excluded", exclude, True)

    print_to_stderr(test_batch)
    print_to_stderr(test_batch_exclude)

    if options.dry_run:
        return

    # Dynamo and inductor modes are mutually exclusive; dynamo wins if both are set.
    if options.dynamo:
        os.environ["PYTORCH_TEST_WITH_DYNAMO"] = "1"

    elif options.inductor:
        os.environ["PYTORCH_TEST_WITH_INDUCTOR"] = "1"

    if not options.no_translation_validation:
        os.environ["PYTORCH_TEST_WITH_TV"] = "1"

    try:
        # Actually run the tests
        elapsed_time = time.time() - start_time
        print_to_stderr(
            f"Starting test batch '{test_batch.name}' {round(elapsed_time, 2)} seconds after initiating testing"
        )
        run_tests(
            test_batch.sharded_tests, test_directory, options, test_batch.failures
        )

    finally:
        # Runs even if run_tests raised, so coverage and failure metrics are
        # still produced for partial runs.
        if options.coverage:
            from coverage import Coverage

            with set_cwd(test_directory):
                cov = Coverage()
                if PYTORCH_COLLECT_COVERAGE:
                    cov.load()
                cov.combine(strict=False)
                cov.save()
                if not PYTORCH_COLLECT_COVERAGE:
                    cov.html_report()

        # run_tests was handed this same list object, so it holds whatever
        # failures were recorded during the run.
        all_failures = test_batch.failures

        if IS_CI:
            for test, _ in all_failures:
                test_stats = test_prioritizations.get_test_stats(test)
                print_to_stderr("Emiting td_test_failure_stats_v2")
                emit_metric(
                    "td_test_failure_stats_v2",
                    {
                        "selected_tests": selected_tests,
                        "failure": str(test),
                        **test_stats,
                    },
                )

    if len(all_failures):
        for _, err in all_failures:
            print_to_stderr(err)

        # A disabled test is expected to fail, so there is no need to report a failure here
        if not RERUN_DISABLED_TESTS:
            sys.exit(1)
1747

1748

1749
if __name__ == "__main__":
    # Script entry point: run the test driver only when executed directly,
    # never as a side effect of importing this module.
    main()
1751

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.