pytorch

Форк
0
/
convnet_benchmarks_test.py 
23 строки · 839.0 Байт
1
import unittest
2
from caffe2.python import convnet_benchmarks as cb
3
from caffe2.python import test_util, workspace
4

5

6
# TODO: investigate why this randomly core dump in ROCM CI
7
@unittest.skipIf(not workspace.has_cuda_support, "no cuda gpu")
8
class TestConvnetBenchmarks(test_util.TestCase):
9
    def testConvnetBenchmarks(self):
10
        all_args = [
11
            '--batch_size 16 --order NCHW --iterations 1 '
12
            '--warmup_iterations 1',
13
            '--batch_size 16 --order NCHW --iterations 1 '
14
            '--warmup_iterations 1 --forward_only',
15
        ]
16
        for model in [cb.AlexNet, cb.OverFeat, cb.VGGA, cb.Inception]:
17
            for arg_str in all_args:
18
                args = cb.GetArgumentParser().parse_args(arg_str.split(' '))
19
                cb.Benchmark(model, args)
20

21

22
if __name__ == '__main__':
23
    unittest.main()
24

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.