# Tests for torch.nn.parallel.replicate() over TorchScript and plain Python
# modules on multiple CUDA devices.

import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.parallel as dp

# Make the helper modules in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA_MULTI_GPU

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

class TestDataParallel(JitTestCase):
    class Mpy(torch.nn.Module):
        def __init__(self) -> None:
            super(TestDataParallel.Mpy, self).__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.ignore
        def forward(self, input):
            return self.m(input)

    class Mpy1(torch.nn.Module):
        def __init__(self, block):
            super(TestDataParallel.Mpy1, self).__init__()
            self.m = block

        @torch.jit.ignore
        def forward(self, input):
            return self.m.forward(input)

    class Mpy2(torch.nn.Module):
        def __init__(self, block1, block2):
            super(TestDataParallel.Mpy2, self).__init__()
            self.m1 = block1
            self.m2 = block2

        @torch.jit.ignore
        def forward(self, input):
            x = self.m1.forward(input)
            return self.m2(x)
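
    # The Msm* classes below are TorchScript ScriptModules, compiled eagerly,
    # in contrast to the plain Python Mpy* modules above.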

    class Msm(torch.jit.ScriptModule):
        __constants__ = ["m"]

        def __init__(self) -> None:
            super(TestDataParallel.Msm, self).__init__()
            self.m = nn.Sequential(
                nn.Linear(2, 2), nn.BatchNorm1d(2), nn.ReLU(), nn.Linear(2, 2)
            )

        @torch.jit.script_method
        def forward(self, input):
            return self.m(input)

    class Msm1(torch.jit.ScriptModule):
        def __init__(self, block):
            super(TestDataParallel.Msm1, self).__init__()
            self.block = block

        @torch.jit.script_method
        def forward(self, input):
            x = self.block(input)
            return x

    def check_replicas(self, module, replicas, input_shape=(2, 2)):
        input = torch.randn(input_shape).cuda()
        expected_output = module(input).data
        for i, replica in enumerate(replicas):
            # Every parameter and buffer of replica i must live on device i.
            for p in replica.parameters():
                self.assertEqual(p.get_device(), i)
            for b in replica.buffers():
                self.assertEqual(b.get_device(), i)
            # Each replica must compute the same output as the original module.
            replica_input = input.cuda(i)
            self.assertEqual(replica(replica_input).data, expected_output)
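
    # dp.replicate() is the same primitive nn.DataParallel uses to copy a
    # module onto each device before a parallel forward; the tests below check
    # that it handles Python modules, ScriptModules, and traced modules alike.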

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_python_submodule_script(self):
        module = self.Mpy1(self.Msm()).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_shared_module(self):
        s = self.Msm()
        p1 = self.Mpy1(s)
        module = self.Mpy2(p1, s).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_traced_module(self):
        module = torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda()
        replicas = dp.replicate(module, {0, 1})
        self.check_replicas(module, replicas)

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing(self):
        module = self.Msm1(self.Msm()).cuda()
        replica = dp.replicate(module, {0, 1})

        def assert_share_data(t1, t2):
            # Only checks that the two tensors point to the same memory on the
            # same device; it does not assert anything by itself.
            return (
                t1.device == t2.device
                and t1.storage().data_ptr() == t2.storage().data_ptr()
            )

        # The replica on the source device shares storage with the original.
        for p1, p2 in zip(module.parameters(), replica[0].parameters()):
            self.assertTrue(assert_share_data(p1, p2))

        for p1, p2 in zip(module.buffers(), replica[0].buffers()):
            self.assertTrue(assert_share_data(p1, p2))

        # The replica on the other device gets its own copy of every tensor.
        for p1, p2 in zip(module.parameters(), replica[1].parameters()):
            self.assertFalse(assert_share_data(p1, p2))

        for p1, p2 in zip(module.buffers(), replica[1].buffers()):
            self.assertFalse(assert_share_data(p1, p2))

    @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
    def test_tensor_sharing_with_forward(self):
        module = self.Msm1(self.Msm()).cuda()
        replica = dp.replicate(module, {0, 1})
        x = torch.ones(2, 2, requires_grad=True).cuda()
        first_forward = module(x)
        first_forward.sum().backward()
        with torch.no_grad():
            for p in module.parameters():
                # Use .data here to avoid version counter bumps. The autograd
                # graph recorded by the next forward will be wrong, but we
                # never backward through it, so that is fine.
                p.data -= 1.0 * p.grad
        second_forward = module(x)

        # The replica on the same GPU holds a shallow copy of the original
        # parameters and buffers, so it sees the in-place update.
        r0_forward = replica[0](x)
        self.assertEqual(second_forward, r0_forward)

        # The replica on a different GPU holds a deep copy of the original
        # parameters and buffers, so it still produces the pre-update output.
        x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
        r1_forward = replica[1](x1)
        self.assertEqual(first_forward, r1_forward)
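
# For reference only, a minimal sketch (not executed here; assumes two visible
# CUDA devices, and `batch`/`module` are placeholder names) of how replicas
# like these are driven in a DataParallel-style forward pass using the other
# torch.nn.parallel primitives:
#
#     inputs = torch.nn.parallel.scatter(batch, [0, 1])   # split batch per GPU
#     replicas = dp.replicate(module, [0, 1])             # one copy per GPU
#     outputs = torch.nn.parallel.parallel_apply(
#         replicas[: len(inputs)], inputs                 # run forwards in parallel
#     )
#     result = torch.nn.parallel.gather(outputs, 0)       # merge results on GPU 0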