# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from typing import Optional

import torch
from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_utils import run_tests, TestCase


# Building block for model
9class Block(torch.nn.Module):
10def __init__(self) -> None:
11super().__init__()
12self.conv = torch.nn.Conv2d(
13in_channels=16, out_channels=16, kernel_size=3, padding=1
14)
15self.lin0 = torch.nn.Linear(256, 256)
16self.relu = torch.nn.ReLU()
17self.lin1 = torch.nn.Linear(256, 256)
18
19def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
20x = self.conv(x)
21x = self.lin0(x)
22pipe_split()
23x.add_(constant)
24x = self.lin1(x)
25return self.relu(x)


# Full model
class M(torch.nn.Module):
    """Full test model: two ``Block`` instances applied back to back.

    A ``pipe_split()`` sits between the two blocks. Each ``Block.forward`` has
    its own split as well, so ``forward`` contains three split points in total.
    """

    def __init__(self) -> None:
        super().__init__()
        # The attribute names form the ``block0.*`` / ``block1.*`` prefixes of
        # the parameter qualnames checked by the unflatten test.
        self.block0 = Block()
        self.block1 = Block()

    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
        """Apply ``block0``, mark a stage boundary, then apply ``block1``.

        ``constant`` is forwarded unchanged to both blocks.
        """
        first_half = self.block0(x, constant=constant)
        pipe_split()  # stage boundary between the two blocks
        return self.block1(first_half, constant=constant)


class UnflattenTests(TestCase):
    """Tests that a model split by ``pipeline`` keeps its qualnames and numerics."""

    def test_unflatten(self):
        """Split ``M`` into pipeline stages and validate the result.

        Checks that:
        1. the model is cut into four stages;
        2. every stage parameter name exists in the original ``state_dict``;
        3. running the pipeline matches eager execution of ``M``.
        """
        x = torch.randn(1, 16, 256, 256)
        constant = torch.ones(1, 16, 256, 256)

        mod = M()

        pipe = pipeline(
            mod,
            (x,),
            {"constant": constant},
        )

        # Three pipe_split() calls (one in M, one in each Block) -> 4 stages.
        # unittest assertions are used instead of bare ``assert`` so the
        # checks survive ``python -O`` and report both values on failure.
        self.assertEqual(pipe.num_stages, 4)
        orig_state_dict = mod.state_dict()

        # Check qualnames
        for stage_idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(stage_idx)
            for param_name, _ in stage_mod.named_parameters():
                self.assertIn(
                    param_name,
                    orig_state_dict,
                    f"{param_name} not in original state dict",
                )
        print("Param qualname test passed")

        # Check equivalence
        ref = mod(x, constant)
        out = pipe(x, constant)[0]
        torch.testing.assert_close(out, ref)
        print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


if __name__ == "__main__":
    # Executed as a script: run this module's tests with PyTorch's runner.
    run_tests()