pytorch

Форк
0
/
test_data_parallel.py 
160 строк · 5.5 Кб
1
# Owner(s): ["oncall: jit"]
2

3
import os
4
import sys
5
import unittest
6

7
import torch
8
import torch.nn as nn
9
import torch.nn.parallel as dp
10

11

12
# Make the helper files in test/ importable: resolve this file's real path,
# step up to the directory that contains it, then to that directory's parent
# (the test/ directory), and append the result to sys.path.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
15
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU
16

17

18
# This module only defines test cases meant to be run through
# test/test_jit.py; executing it as a standalone script is an error that
# tells the user which entry point to use instead.
if __name__ == "__main__":
    usage = "\n\n".join(
        (
            "This test file is not meant to be run directly, use:",
            "\tpython test/test_jit.py TESTNAME",
            "instead.",
        )
    )
    raise RuntimeError(usage)
24

25

26
class TestDataParallel(JitTestCase):
    """Tests for ``torch.nn.parallel.replicate`` applied to JIT modules.

    The nested ``Mpy*`` classes are eager ``nn.Module``s whose ``forward`` is
    decorated with ``torch.jit.ignore``; the ``Msm*`` classes are
    ``torch.jit.ScriptModule``s. The tests replicate combinations of them
    (a scripted submodule inside an eager module, one scripted module shared
    by two parents, a traced module) onto two CUDA devices and check device
    placement, outputs, and whether replica tensors alias the originals.
    Every test is skipped unless ``RUN_CUDA_MULTI_GPU`` is true.
    """

    # Device indices every test replicates onto. An ordered sequence (rather
    # than a set literal) makes "replicas[i] lives on cuda:i" explicit instead
    # of relying on set iteration order.
    _DEVICES = (0, 1)

    class Mpy(torch.nn.Module):
        """Eager module: Linear -> BatchNorm1d -> ReLU -> Linear on 2 features."""

        def __init__(self) -> None:
            super().__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        # torch.jit.ignore leaves forward as a plain Python function should
        # the TorchScript compiler encounter it.
        @torch.jit.ignore
        def forward(self, input):
            return self.m(input)

    class Mpy1(torch.nn.Module):
        """Eager wrapper that delegates to ``block.forward``."""

        def __init__(self, block):
            super().__init__()
            self.m = block

        @torch.jit.ignore
        def forward(self, input):
            # Calls ``forward`` directly (not ``self.m(input)``), as in the
            # original test.
            return self.m.forward(input)

    class Mpy2(torch.nn.Module):
        """Eager module chaining ``block1.forward`` and then ``block2``."""

        def __init__(self, block1, block2):
            super().__init__()
            self.m1 = block1
            self.m2 = block2

        @torch.jit.ignore
        def forward(self, input):
            x = self.m1.forward(input)
            return self.m2(x)

    class Msm(torch.jit.ScriptModule):
        """Scripted counterpart of ``Mpy``; ``m`` is declared a constant."""

        __constants__ = ["m"]

        def __init__(self) -> None:
            super().__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.script_method
        def forward(self, input):
            return self.m(input)

    class Msm1(torch.jit.ScriptModule):
        """Scripted wrapper that runs its input through ``block``."""

        def __init__(self, block):
            super().__init__()
            self.block = block

        @torch.jit.script_method
        def forward(self, input):
            x = self.block(input)
            return x

    def check_replicas(self, module, replicas, input_shape=(2, 2)):
        """Assert that ``replicas[i]`` is placed on ``cuda:i`` and matches ``module``.

        Args:
            module: the source module; it is called on a random input created
                with ``.cuda()`` (the current CUDA device).
            replicas: sequence of replicas, ``replicas[i]`` expected on
                device ``i``.
            input_shape: shape of the random input fed to every module.
        """
        # Named ``inp`` to avoid shadowing the builtin ``input``.
        inp = torch.randn(input_shape).cuda()
        expected_output = module(inp).detach()
        for device_idx, replica in enumerate(replicas):
            for p in replica.parameters():
                self.assertEqual(p.get_device(), device_idx)
            for b in replica.buffers():
                self.assertEqual(b.get_device(), device_idx)
            replica_input = inp.cuda(device_idx)
            self.assertEqual(replica(replica_input).detach(), expected_output)

    # A scripted module nested inside an eager module.
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_python_submodule_script(self):
        module = self.Mpy1(self.Msm()).cuda()
        replicas = dp.replicate(module, self._DEVICES)
        self.check_replicas(module, replicas)

    # The same scripted module reachable through two different parents.
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_shared_module(self):
        s = self.Msm()
        p1 = self.Mpy1(s)
        module = self.Mpy2(p1, s).cuda()
        replicas = dp.replicate(module, self._DEVICES)
        self.check_replicas(module, replicas)

    # A module produced by torch.jit.trace.
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_traced_module(self):
        module = torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda()
        replicas = dp.replicate(module, self._DEVICES)
        self.check_replicas(module, replicas)

    # The replica on the source device aliases the original tensors; the
    # replica on the other device holds copies.
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing(self):
        module = self.Msm1(self.Msm()).cuda()
        replicas = dp.replicate(module, self._DEVICES)

        def shares_data(t1, t2):
            # Only checks that they point to the same memory on the same device.
            # untyped_storage() replaces the deprecated TypedStorage accessor
            # Tensor.storage().
            return (
                t1.device == t2.device
                and t1.untyped_storage().data_ptr()
                == t2.untyped_storage().data_ptr()
            )

        def check_pairs(originals, copies, shared):
            # Materialize both sides first: zip() silently stops at the
            # shorter iterable, which would let a missing tensor go unchecked.
            originals, copies = list(originals), list(copies)
            self.assertEqual(len(originals), len(copies))
            check = self.assertTrue if shared else self.assertFalse
            for t1, t2 in zip(originals, copies):
                check(shares_data(t1, t2))

        check_pairs(module.parameters(), replicas[0].parameters(), shared=True)
        check_pairs(module.buffers(), replicas[0].buffers(), shared=True)
        check_pairs(module.parameters(), replicas[1].parameters(), shared=False)
        check_pairs(module.buffers(), replicas[1].buffers(), shared=False)

    # In-place updates to the original parameters are visible through the
    # same-device replica but not through the other-device replica.
    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing_with_forward(self):
        module = self.Msm1(self.Msm()).cuda()
        replicas = dp.replicate(module, self._DEVICES)
        x = torch.ones(2, 2, requires_grad=True).cuda()
        first_forward = module(x)
        first_forward.sum().backward()
        with torch.no_grad():
            for p in module.parameters():
                # Use .data here to avoid version counter bump.
                # The graph created by the following forward will be wrong but
                # we never backward through them so it's fine
                p.data -= 1.0 * p.grad
        second_forward = module(x)

        # replica which is on the same GPU has a shallow copy of the original
        # params and buffers
        r0_forward = replicas[0](x)
        self.assertEqual(second_forward, r0_forward)

        # replica which is on a different GPU has a deep copy of the original
        # params and buffers
        x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
        r1_forward = replicas[1](x1)
        self.assertEqual(first_forward, r1_forward)
161

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.