# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.autograd import Function
from torch.distributed.pipelining import pipe_split, SplitPoint


class ExampleCode(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param0)
        x = torch.relu(x)
        # try passing a value that doesn't require_grad across skip boundaries
        a_constant = self.cval.clone()
        x = self.lin0(x)
        pipe_split()
        x = torch.relu(x) + a_constant
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


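# Usage sketch, not part of the registry: how a model annotated with
# pipe_split(), such as ExampleCode above, could be traced into stages.
# This assumes the pipeline() frontend of torch.distributed.pipelining;
# the sizes and the helper name are illustrative only.
def _example_trace_with_pipe_split():
    from torch.distributed.pipelining import pipeline

    d_hid, batch = 16, 4
    mod = ExampleCode(d_hid)
    # One pipe_split() call in forward -> two pipeline stages.
    return pipeline(mod, mb_args=(torch.randn(batch, d_hid),))

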
class ModelWithKwargs(torch.nn.Module):
    DEFAULT_DHID = 512
    DEFAULT_BATCH_SIZE = 256

    def __init__(self, d_hid: int = DEFAULT_DHID):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    # Note: the tensor default for `y` is evaluated once, at function
    # definition time, and is shared across calls; it only matches inputs
    # built with the default sizes above.
    def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


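# Usage sketch, not part of the registry: keyword inputs can be fed to the
# (assumed) pipeline() frontend through its mb_kwargs argument; values are
# illustrative only.
def _example_trace_with_kwargs():
    from torch.distributed.pipelining import pipeline

    mod = ModelWithKwargs()
    x = torch.randn(ModelWithKwargs.DEFAULT_BATCH_SIZE, ModelWithKwargs.DEFAULT_DHID)
    y = torch.randn_like(x)
    return pipeline(mod, mb_args=(x,), mb_kwargs={"y": y})

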
class ModelWithParamAlias(torch.nn.Module):
    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        super().__init__()
        # Intentional aliasing: both attribute names refer to the same
        # Parameter / Linear instance, to exercise parameter-alias handling
        # when the model is split.
        self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


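# Usage sketch, not part of the registry: the point of ModelWithParamAlias is
# that both attribute names resolve to the very same objects, which split
# tooling must deduplicate rather than copy.
def _example_param_alias():
    mod = ModelWithParamAlias(8)
    assert mod.mm_param0 is mod.mm_param1
    assert mod.lin0 is mod.lin1
    return mod

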
# MLP Layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # For testing purposes only; in practice the split spec should be
        # defined by the user.
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


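# Usage sketch, not part of the registry: models without pipe_split() calls
# can be split by passing a split_spec, mapping fully qualified submodule
# names to SplitPoint values, to the (assumed) pipeline() frontend.
def _example_trace_with_split_spec():
    from torch.distributed.pipelining import pipeline

    d_hid, batch = 16, 4
    mod = MultiMLP(d_hid, n_layers=4)
    x = torch.randn(batch, d_hid)
    # Three SplitPoint.BEGINNING entries -> four pipeline stages.
    return pipeline(mod, mb_args=(x,), split_spec=mod.split_spec)

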
class CustomLinearDx(Function):
    """Linear op whose backward computes only the input gradient (dx);
    grad_output and the layer input are cached on the owning module so the
    weight gradients (dW) can be computed later, outside autograd."""

    @staticmethod
    def forward(ctx, input_val, weight, bias, module, layer_idx):
        ctx.save_for_backward(input_val, weight, bias)
        ctx.module = module
        ctx.layer_idx = layer_idx
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
        ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
            input_val.clone()
        )
        # No gradients for weight/bias here; they are produced by compute_dW().
        return grad_input, None, None, None, None


class CustomLinearDxDw(Function):
    """Linear op with a full hand-written backward (dx, dW and db together)."""

    @staticmethod
    def forward(ctx, input_val, weight, bias):
        ctx.save_for_backward(input_val, weight, bias)
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input_val)
        grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


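# Usage sketch, not part of the registry: torch.autograd.gradcheck can verify
# that CustomLinearDxDw's hand-written backward matches numerical gradients
# (gradcheck wants double precision); sizes are illustrative only.
def _example_gradcheck_custom_linear():
    d = 4
    x = torch.randn(2, d, dtype=torch.double, requires_grad=True)
    w = torch.randn(d, d, dtype=torch.double, requires_grad=True)
    b = torch.randn(d, dtype=torch.double, requires_grad=True)
    return torch.autograd.gradcheck(CustomLinearDxDw.apply, (x, w, b))

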
class MLPModuleWithDw(torch.nn.Module):
    """MLP whose backward can be split: with use_custom_logic=True, backward
    only produces dx, and the weight gradients are materialized later via
    compute_dW()."""

    def __init__(self, d_hid: int):
        super().__init__()
        self.fc1_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc1_bias = torch.nn.Parameter(torch.randn(d_hid))
        self.fc2_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc2_bias = torch.nn.Parameter(torch.randn(d_hid))

        torch.nn.init.uniform_(self.fc1_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc1_bias, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_bias, -0.01, 0.01)

        self.cached_context = {}
        self.cached_context["fc1"] = []
        self.cached_context["fc2"] = []
        self.cached_context["fc1_input"] = []
        self.cached_context["fc2_input"] = []

        self.use_custom_logic = False

    def forward(self, x):
        if not self.use_custom_logic:
            self.hidden = CustomLinearDxDw.apply(x, self.fc1_weight, self.fc1_bias)
            self.hidden = torch.nn.functional.relu(self.hidden)
            output = CustomLinearDxDw.apply(self.hidden, self.fc2_weight, self.fc2_bias)
            return output

        self.hidden = CustomLinearDx.apply(
            x, self.fc1_weight, self.fc1_bias, self, "fc1"
        )
        self.hidden = torch.nn.functional.relu(self.hidden)
        output = CustomLinearDx.apply(
            self.hidden, self.fc2_weight, self.fc2_bias, self, "fc2"
        )
        return output

    def compute_dW(self):
        grad_output_fc1 = self.cached_context["fc1"].pop(0)
        grad_output_fc2 = self.cached_context["fc2"].pop(0)
        cached_input_fc1 = self.cached_context["fc1_input"].pop(0)
        cached_input_fc2 = self.cached_context["fc2_input"].pop(0)

        dW2 = grad_output_fc2.t().mm(cached_input_fc2)
        db2 = grad_output_fc2.sum(0)

        dW1 = grad_output_fc1.t().mm(cached_input_fc1)
        db1 = grad_output_fc1.sum(0)

        if self.fc1_weight.grad is not None:
            self.fc1_weight.grad += dW1
            self.fc1_bias.grad += db1
            self.fc2_weight.grad += dW2
            self.fc2_bias.grad += db2
        else:
            self.fc1_weight.grad = dW1
            self.fc1_bias.grad = db1
            self.fc2_weight.grad = dW2
            self.fc2_bias.grad = db2

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic


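# Usage sketch, not part of the registry: the two backward paths of
# MLPModuleWithDw should produce identical weight gradients; the helper name
# and sizes are illustrative only.
def _example_dw_equivalence():
    torch.manual_seed(0)
    mod = MLPModuleWithDw(8)
    x = torch.randn(2, 8)

    mod(x).sum().backward()  # full autograd path (CustomLinearDxDw)
    ref = mod.fc1_weight.grad.clone()

    mod.zero_grad()
    mod.toggle()  # switch to the dx-only path
    mod(x).sum().backward()  # caches grad_output/input, no weight grads yet
    mod.compute_dW()  # now materialize dW/db
    assert torch.allclose(mod.fc1_weight.grad, ref)
    return mod

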
# Multi-MLP model with dW
class MultiMLPWithDw(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MLPModuleWithDw(d_hid) for _ in range(n_layers)]
        )
        # For testing purposes only; in practice the split spec should be
        # defined by the user.
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }
        self.use_custom_logic = False

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
        for layer in self.layers:
            layer.toggle()

    def compute_dW(self):
        if not self.use_custom_logic:
            raise RuntimeError("Need to call toggle() to enable custom backward and dW")

        for i in reversed(range(len(self.layers))):
            self.layers[i].compute_dW()
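

# Usage sketch, not part of the registry: the intended call order for the
# deferred-dW model; compute_dW() walks the layers in reverse, mirroring the
# backward pass. Values are illustrative only.
def _example_multi_mlp_deferred_dw():
    d_hid, batch = 8, 2
    mod = MultiMLPWithDw(d_hid)
    mod.toggle()  # enable the dx-only backward; required before compute_dW()
    out = mod(torch.randn(batch, d_hid))
    out.sum().backward()  # fills each layer's cached_context
    mod.compute_dW()  # populate .grad on all weights and biases
    return mod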