3
from torch.testing._internal.common_utils import TestCase, run_tests, slowTest, IS_WINDOWS
10
# True when the PYTORCH_COLLECT_COVERAGE environment variable is set to any
# non-empty string; False when it is unset or empty.
PYTORCH_COLLECT_COVERAGE = bool(os.getenv("PYTORCH_COLLECT_COVERAGE"))
13
# Test case that launches the functional autograd benchmark script in a
# subprocess, once per model name, and checks that the script exits with
# status 0 and leaves a non-empty output file.
# NOTE(review): in this view the indentation is missing, bare integer lines are
# interleaved with the code, and some statements (e.g. the opening of the
# `cmd` list) are not visible -- confirm against the complete file before
# changing any logic.
class TestFunctionalAutogradBenchmark(TestCase):
14
# Run the benchmark script for a single model and verify its exit status and
# that it produced output.
#   model: value passed to the script's `--model-filter` option.
#   disable_gpu: defaults to False; presumably gates the `--gpu -1` arguments
#     below, but the guarding condition is not visible here -- TODO confirm.
def _test_runner(self, model, disable_gpu=False):
19
# The temporary file's path is handed to the script via `--output`; the same
# handle is inspected afterwards to see whether anything was written to it.
with tempfile.NamedTemporaryFile() as out_file:
21
# Tail of the initial `cmd` list (its opening line is not visible here). The
# script path is relative -- presumably to the directory the tests are run
# from; TODO confirm.
'../benchmarks/functional_autograd_benchmark/functional_autograd_benchmark.py']
23
# Request zero iterations -- presumably to keep the run as short as possible;
# verify against the script's handling of `--num-iters`.
cmd += ['--num-iters', '0']
25
# Pass 'vjp' as the task filter.
cmd += ['--task-filter', 'vjp']
27
# Pass the caller-supplied model name as the model filter.
cmd += ['--model-filter', model]
29
# Point the script's output at the temporary file created above.
cmd += ['--output', out_file.name]
31
# Pass -1 as the GPU selection -- presumably meaning "do not use a GPU" and
# presumably only added when disable_gpu is true; TODO confirm both.
cmd += ['--gpu', '-1']
33
# Run the command and wait for it to finish.
res = subprocess.run(cmd)
35
# The script must exit with status 0.
self.assertTrue(res.returncode == 0)
37
# Move to the end of the file so that tell() reports the end-of-file offset.
out_file.seek(0, os.SEEK_END)
38
# A positive offset means something was written to the output file.
self.assertTrue(out_file.tell() > 0)
41
# Runs _test_runner for each model in `fast_tasks`, leaving disable_gpu at its
# default. Skipped on Windows and when PYTORCH_COLLECT_COVERAGE is truthy (the
# skip message links the related issue).
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows does not have all the features we need.")
42
@unittest.skipIf(PYTORCH_COLLECT_COVERAGE, "Can deadlocks with gcov, see https://github.com/pytorch/pytorch/issues/49656")
43
def test_fast_tasks(self):
44
fast_tasks = ['resnet18', 'ppl_simple_reg', 'ppl_robust_reg', 'wav2letter',
45
'transformer', 'multiheadattn']
47
# Each model is checked with a separate subprocess run.
for task in fast_tasks:
48
self._test_runner(task)
51
# Runs _test_runner for each model in `slow_tasks`, passing disable_gpu=True.
# Skipped on Windows.
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows does not have all the features we need.")
52
def test_slow_tasks(self):
53
slow_tasks = ['fcn_resnet', 'detr']
57
# Each model is checked with a separate subprocess run.
for task in slow_tasks:
59
self._test_runner(task, disable_gpu=True)
62
if __name__ == '__main__':