# pytorch
# (scraped-page residue: "124 lines · 3.4 KB" — not part of the module source)
1# Copyright (c) Meta Platforms, Inc. and affiliates
2# Owner(s): ["oncall: distributed"]
3from model_registry import MLPModule, ModelWithParamAlias4
5import torch6from torch.distributed.pipelining import pipe_split, pipeline7from torch.testing._internal.common_utils import (8instantiate_parametrized_tests,9parametrize,10run_tests,11TestCase,12)
13
14
# Hidden dimension shared by every test model and by the test inputs.
d_hid = 512
# Number of samples in a single microbatch fed to the pipeline.
microbatch_size = 16

# Fix the RNG so parameter initialization and test inputs are reproducible.
torch.manual_seed(0)
20
# Basic example
class ExampleCode(torch.nn.Module):
    """Four-stage toy model mixing raw Parameters and Linear layers.

    ``mm_param1`` is deliberately consumed in two different stages to
    exercise multi-use-parameter handling when the model is pipelined.
    """

    def __init__(self) -> None:
        super().__init__()
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        # Stage 0
        x = torch.mm(x, self.mm_param1)  # multi-use param (first use)
        residual = x  # saved for the skip connection consumed in stage 2
        x = torch.relu(x + y)
        pipe_split()
        # Stage 1
        x = torch.mm(x, self.mm_param1)  # multi-use param (second use)
        x = self.lin1(x)
        pipe_split()
        # Stage 2
        x = torch.relu(x)
        x = torch.mm(x + residual, self.mm_param2)
        pipe_split()
        # Stage 3
        return torch.relu(self.lin2(x))
47
class MultiMLP(torch.nn.Module):
    """Chain of four MLP blocks, each separated by a ``pipe_split()``.

    Splitting between every pair of MLPs yields exactly four pipeline stages.
    """

    def __init__(self) -> None:
        super().__init__()
        self.mlp0 = MLPModule(d_hid)
        self.mlp1 = MLPModule(d_hid)
        self.mlp2 = MLPModule(d_hid)
        self.mlp3 = MLPModule(d_hid)

    def forward(self, x, y):
        # Stage 0
        out = self.mlp0(x)
        pipe_split()
        # Stage 1
        out = self.mlp1(out)
        pipe_split()
        # Stage 2
        out = self.mlp2(out)
        pipe_split()
        # Stage 3
        out = self.mlp3(out)
        return out - y
66
67EXPECTED_N_STAGES = {68ExampleCode: 4,69MultiMLP: 4,70ModelWithParamAlias: 2,71}
72
# Currently, we don't enforce full set equality on the FQNs between the original
# and pipelined models, because in the multi-use param case, PP will deduplicate
# the FQNs from the state_dict.
# TODO: enable the strict equality check once FQN deduplication is reconciled
# with the per-stage state_dict comparison below.
CHECK_FQN_SET_EQUALITY = False
79
class PipeTests(TestCase):
    """Verifies that ``pipeline()`` splits each model into the expected stages."""

    @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias])
    def test_model_split(self, ModelClass):
        """Split ``ModelClass`` and check stage count, numerics, and FQNs."""
        mod = ModelClass()
        x = torch.randn(microbatch_size, d_hid)
        y = torch.randn(microbatch_size, d_hid)

        pipe = pipeline(
            mod,
            mb_args=(x, y),
        )

        # Use unittest assertions rather than bare `assert`: bare asserts are
        # stripped under `python -O`, which would silently disable these checks.
        self.assertEqual(
            pipe.num_stages,
            EXPECTED_N_STAGES[ModelClass],
            f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}",
        )

        # The pipelined model must be numerically equivalent to the original.
        ref_out = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref_out)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}")

        # Check qualname
        # state_dict.keys include both parameters and persistent buffers
        old_names = set(mod.state_dict().keys())
        new_names = set()
        for idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(idx)
            stage_fqns = set(stage_mod.state_dict().keys())
            # Every stage FQN must originate from the original model
            # (set <= set is the subset test, reported with a useful diff).
            self.assertLessEqual(stage_fqns, old_names)
            new_names.update(stage_fqns)

        if CHECK_FQN_SET_EQUALITY:
            self.assertEqual(
                old_names,
                new_names,
                f"""
old names {old_names}
new names {new_names}
""",
            )
        print("Qualname check passed")
120
121instantiate_parametrized_tests(PipeTests)122
123if __name__ == "__main__":124run_tests()125