pytorch

Форк
0
/
test_functional_autograd_benchmark.py 
63 строки · 2.5 Кб
1
# Owner(s): ["module: autograd"]
2

3
from torch.testing._internal.common_utils import TestCase, run_tests, slowTest, IS_WINDOWS
4

5
import subprocess
6
import tempfile
7
import os
8
import unittest
9

10
PYTORCH_COLLECT_COVERAGE = bool(os.environ.get("PYTORCH_COLLECT_COVERAGE"))
11

12
# This is a very simple smoke test for the functional autograd benchmarking script.
13
class TestFunctionalAutogradBenchmark(TestCase):
14
    def _test_runner(self, model, disable_gpu=False):
15
        # Note about windows:
16
        # The temporary file is exclusively open by this process and the child process
17
        # is not allowed to open it again. As this is a simple smoke test, we choose for now
18
        # not to run this on windows and keep the code here simple.
19
        with tempfile.NamedTemporaryFile() as out_file:
20
            cmd = ['python3',
21
                   '../benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py']
22
            # Only run the warmup
23
            cmd += ['--num-iters', '0']
24
            # Only run the vjp task (fastest one)
25
            cmd += ['--task-filter', 'vjp']
26
            # Only run the specified model
27
            cmd += ['--model-filter', model]
28
            # Output file
29
            cmd += ['--output', out_file.name]
30
            if disable_gpu:
31
                cmd += ['--gpu', '-1']
32

33
            res = subprocess.run(cmd)
34

35
            self.assertTrue(res.returncode == 0)
36
            # Check that something was written to the file
37
            out_file.seek(0, os.SEEK_END)
38
            self.assertTrue(out_file.tell() > 0)
39

40

41
    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows does not have all the features we need.")
42
    @unittest.skipIf(PYTORCH_COLLECT_COVERAGE, "Can deadlocks with gcov, see https://github.com/pytorch/pytorch/issues/49656")
43
    def test_fast_tasks(self):
44
        fast_tasks = ['resnet18', 'ppl_simple_reg', 'ppl_robust_reg', 'wav2letter',
45
                      'transformer', 'multiheadattn']
46

47
        for task in fast_tasks:
48
            self._test_runner(task)
49

50
    @slowTest
51
    @unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows does not have all the features we need.")
52
    def test_slow_tasks(self):
53
        slow_tasks = ['fcn_resnet', 'detr']
54
        # deepspeech is voluntarily excluded as it takes too long to run without
55
        # proper tuning of the number of threads it should use.
56

57
        for task in slow_tasks:
58
            # Disable GPU for slow test as the CI GPU don't have enough memory
59
            self._test_runner(task, disable_gpu=True)
60

61

62
if __name__ == '__main__':
63
    run_tests()
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.