pytorch

Форк
0
/
test_throughput_benchmark.py 
83 строки · 2.4 Кб
1
# Owner(s): ["module: unknown"]
2

3
import torch
4
from torch.utils import ThroughputBenchmark
5

6
from torch.testing._internal.common_utils import run_tests, TestCase, TemporaryFileName
7

8
class TwoLayerNet(torch.jit.ScriptModule):
9
    def __init__(self, D_in, H, D_out):
10
        super().__init__()
11
        self.linear1 = torch.nn.Linear(D_in, H)
12
        self.linear2 = torch.nn.Linear(2 * H, D_out)
13

14
    @torch.jit.script_method
15
    def forward(self, x1, x2):
16
        h1_relu = self.linear1(x1).clamp(min=0)
17
        h2_relu = self.linear1(x2).clamp(min=0)
18
        cat = torch.cat((h1_relu, h2_relu), 1)
19
        y_pred = self.linear2(cat)
20
        return y_pred
21

22
class TwoLayerNetModule(torch.nn.Module):
23
    def __init__(self, D_in, H, D_out):
24
        super().__init__()
25
        self.linear1 = torch.nn.Linear(D_in, H)
26
        self.linear2 = torch.nn.Linear(2 * H, D_out)
27

28
    def forward(self, x1, x2):
29
        h1_relu = self.linear1(x1).clamp(min=0)
30
        h2_relu = self.linear1(x2).clamp(min=0)
31
        cat = torch.cat((h1_relu, h2_relu), 1)
32
        y_pred = self.linear2(cat)
33
        return y_pred
34

35
class TestThroughputBenchmark(TestCase):
36
    def linear_test(self, Module, profiler_output_path=""):
37
        D_in = 10
38
        H = 5
39
        D_out = 15
40
        B = 8
41
        NUM_INPUTS = 2
42

43
        module = Module(D_in, H, D_out)
44

45
        inputs = []
46

47
        for i in range(NUM_INPUTS):
48
            inputs.append([torch.randn(B, D_in), torch.randn(B, D_in)])
49
        bench = ThroughputBenchmark(module)
50

51
        for input in inputs:
52
            # can do both args and kwargs here
53
            bench.add_input(input[0], x2=input[1])
54

55
        for i in range(NUM_INPUTS):
56
            # or just unpack the list of inputs
57
            module_result = module(*inputs[i])
58
            bench_result = bench.run_once(*inputs[i])
59
            torch.testing.assert_close(bench_result, module_result)
60

61
        stats = bench.benchmark(
62
            num_calling_threads=4,
63
            num_warmup_iters=100,
64
            num_iters=1000,
65
            profiler_output_path=profiler_output_path,
66
        )
67

68
        print(stats)
69

70

71
    def test_script_module(self):
72
        self.linear_test(TwoLayerNet)
73

74
    def test_module(self):
75
        self.linear_test(TwoLayerNetModule)
76

77
    def test_profiling(self):
78
        with TemporaryFileName() as fname:
79
            self.linear_test(TwoLayerNetModule, profiler_output_path=fname)
80

81

82
if __name__ == '__main__':
83
    run_tests()
84

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.