pytorch

Форк
0
/
test_functional_autograd_benchmark.py 
86 строк · 2.7 Кб
1
# Owner(s): ["module: autograd"]
2

3
import os
4

5
import subprocess
6
import tempfile
7
import unittest
8

9
from torch.testing._internal.common_utils import (
10
    IS_WINDOWS,
11
    run_tests,
12
    slowTest,
13
    TestCase,
14
)
15

16
PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE"))
17

18

19
# This is a very simple smoke test for the functional autograd benchmarking script.
20
class TestFunctionalAutogradBenchmark(TestCase):
21
    def _test_runner(self, model, disable_gpu=False):
22
        # Note about windows:
23
        # The temporary file is exclusively open by this process and the child process
24
        # is not allowed to open it again. As this is a simple smoke test, we choose for now
25
        # not to run this on windows and keep the code here simple.
26
        with tempfile.NamedTemporaryFile() as out_file:
27
            cmd = [
28
                "python3",
29
                "../benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py",
30
            ]
31
            # Only run the warmup
32
            cmd += ["--num-iters", "0"]
33
            # Only run the vjp task (fastest one)
34
            cmd += ["--task-filter", "vjp"]
35
            # Only run the specified model
36
            cmd += ["--model-filter", model]
37
            # Output file
38
            cmd += ["--output", out_file.name]
39
            if disable_gpu:
40
                cmd += ["--gpu", "-1"]
41

42
            res = subprocess.run(cmd)
43

44
            self.assertTrue(res.returncode == 0)
45
            # Check that something was written to the file
46
            out_file.seek(0, os.SEEK_END)
47
            self.assertTrue(out_file.tell() > 0)
48

49
    @unittest.skipIf(
50
        IS_WINDOWS,
51
        "NamedTemporaryFile on windows does not have all the features we need.",
52
    )
53
    @unittest.skipIf(
54
        PYTORCH_COLLECT_COVERAGE,
55
        "Can deadlocks with gcov, see https://github.com/pytorch/pytorch/issues/49656",
56
    )
57
    def test_fast_tasks(self):
58
        fast_tasks = [
59
            "resnet18",
60
            "ppl_simple_reg",
61
            "ppl_robust_reg",
62
            "wav2letter",
63
            "transformer",
64
            "multiheadattn",
65
        ]
66

67
        for task in fast_tasks:
68
            self._test_runner(task)
69

70
    @slowTest
71
    @unittest.skipIf(
72
        IS_WINDOWS,
73
        "NamedTemporaryFile on windows does not have all the features we need.",
74
    )
75
    def test_slow_tasks(self):
76
        slow_tasks = ["fcn_resnet", "detr"]
77
        # deepspeech is voluntarily excluded as it takes too long to run without
78
        # proper tuning of the number of threads it should use.
79

80
        for task in slow_tasks:
81
            # Disable GPU for slow test as the CI GPU don't have enough memory
82
            self._test_runner(task, disable_gpu=True)
83

84

85
if __name__ == "__main__":
86
    run_tests()
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.