# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
# This file is a model zoo for testing torch.distributed.pipelining.
import torch
from torch.autograd import Function
from torch.distributed.pipelining import pipe_split, SplitPoint


class ExampleCode(torch.nn.Module):
    def __init__(self, d_hid):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.cval = torch.nn.Buffer(torch.randn((d_hid,), requires_grad=False))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = torch.mm(x, self.mm_param0)
        x = torch.relu(x)
        # Try passing a value that doesn't require_grad across skip boundaries
        a_constant = self.cval.clone()
        x = self.lin0(x)
        pipe_split()
        x = torch.relu(x) + a_constant
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


class ModelWithKwargs(torch.nn.Module):
    DEFAULT_DHID = 512
    DEFAULT_BATCH_SIZE = 256

    def __init__(self, d_hid: int = DEFAULT_DHID):
        super().__init__()
        self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.mm_param1 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin0 = torch.nn.Linear(d_hid, d_hid)
        self.lin1 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y=torch.zeros(DEFAULT_BATCH_SIZE, DEFAULT_DHID)):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


class ModelWithParamAlias(torch.nn.Module):
    default_dhid = 512
    default_batch_size = 256

    def __init__(self, d_hid: int = default_dhid):
        super().__init__()
        self.mm_param1 = self.mm_param0 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.lin1 = self.lin0 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        x = torch.mm(x, self.mm_param0)
        x = x + y
        x = self.lin0(x)
        x = torch.relu(x)
        pipe_split()
        x = torch.mm(x, self.mm_param1)
        x = self.lin1(x)
        x = torch.relu(x)
        return x


# MLP layer
class MLPModule(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.net1 = torch.nn.Linear(d_hid, d_hid)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x):
        x = self.net1(x)
        x = self.relu(x)
        x = self.net2(x)
        return x


# Multi-MLP model
class MultiMLP(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList([MLPModule(d_hid) for _ in range(n_layers)])
        # For testing purposes only; in practice this should be defined by the user.
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class CustomLinearDx(Function):
    # Custom linear whose backward computes only the input gradient (dX).
    # grad_output and the input are cached on the owning module so the weight
    # and bias gradients can be computed later via compute_dW().
    @staticmethod
    def forward(ctx, input_val, weight, bias, module, layer_idx):
        ctx.save_for_backward(input_val, weight, bias)
        ctx.module = module
        ctx.layer_idx = layer_idx
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        ctx.module.cached_context[ctx.layer_idx].append(grad_output.clone())
        ctx.module.cached_context[str(ctx.layer_idx) + "_input"].append(
            input_val.clone()
        )
        return grad_input, None, None, None, None


class CustomLinearDxDw(Function):
    # Custom linear whose backward computes input, weight, and bias gradients
    # in a single pass (the standard behavior).
    @staticmethod
    def forward(ctx, input_val, weight, bias):
        ctx.save_for_backward(input_val, weight, bias)
        return input_val.mm(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        input_val, weight, bias = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input_val)
        grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


class MLPModuleWithDw(torch.nn.Module):
    def __init__(self, d_hid: int):
        super().__init__()
        self.fc1_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc1_bias = torch.nn.Parameter(torch.randn(d_hid))
        self.fc2_weight = torch.nn.Parameter(torch.randn(d_hid, d_hid))
        self.fc2_bias = torch.nn.Parameter(torch.randn(d_hid))

        torch.nn.init.uniform_(self.fc1_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_weight, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc1_bias, -0.01, 0.01)
        torch.nn.init.uniform_(self.fc2_bias, -0.01, 0.01)

        self.cached_context = {}
        self.cached_context["fc1"] = []
        self.cached_context["fc2"] = []
        self.cached_context["fc1_input"] = []
        self.cached_context["fc2_input"] = []

        self.use_custom_logic = False

    def forward(self, x):
        if not self.use_custom_logic:
            self.hidden = CustomLinearDxDw.apply(x, self.fc1_weight, self.fc1_bias)
            self.hidden = torch.nn.functional.relu(self.hidden)
            output = CustomLinearDxDw.apply(self.hidden, self.fc2_weight, self.fc2_bias)
            return output

        self.hidden = CustomLinearDx.apply(
            x, self.fc1_weight, self.fc1_bias, self, "fc1"
        )
        self.hidden = torch.nn.functional.relu(self.hidden)
        output = CustomLinearDx.apply(
            self.hidden, self.fc2_weight, self.fc2_bias, self, "fc2"
        )
        return output

    def compute_dW(self):
        grad_output_fc1 = self.cached_context["fc1"].pop(0)
        grad_output_fc2 = self.cached_context["fc2"].pop(0)
        cached_input_fc1 = self.cached_context["fc1_input"].pop(0)
        cached_input_fc2 = self.cached_context["fc2_input"].pop(0)

        dW2 = grad_output_fc2.t().mm(cached_input_fc2)
        db2 = grad_output_fc2.sum(0)

        dW1 = grad_output_fc1.t().mm(cached_input_fc1)
        db1 = grad_output_fc1.sum(0)

        if self.fc1_weight.grad is not None:
            self.fc1_weight.grad += dW1
            self.fc1_bias.grad += db1
            self.fc2_weight.grad += dW2
            self.fc2_bias.grad += db2
        else:
            self.fc1_weight.grad = dW1
            self.fc1_bias.grad = db1
            self.fc2_weight.grad = dW2
            self.fc2_bias.grad = db2

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic


# Multi-MLP model with deferred dW
class MultiMLPWithDw(torch.nn.Module):
    def __init__(self, d_hid: int, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [MLPModuleWithDw(d_hid) for _ in range(n_layers)]
        )
        # For testing purposes only; in practice this should be defined by the user.
        self.split_spec = {
            f"layers.{i}": SplitPoint.BEGINNING for i in range(1, n_layers)
        }
        self.use_custom_logic = False

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def toggle(self):
        self.use_custom_logic = not self.use_custom_logic
        for layer in self.layers:
            layer.toggle()

    def compute_dW(self):
        if not self.use_custom_logic:
            raise RuntimeError("Need to call toggle() to enable custom backward and dW")

        for i in reversed(range(len(self.layers))):
            self.layers[i].compute_dW()
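

# --- Illustrative usage sketch (not part of the original registry) ---
# A minimal sketch of how these fixtures are typically consumed: the
# pipeline() tracing frontend splits ExampleCode at its pipe_split() markers,
# and MultiMLP at module boundaries via its split_spec. This assumes the
# torch.distributed.pipelining pipeline(module, mb_args, ..., split_spec=...)
# API and the Pipe.num_stages attribute of recent PyTorch releases; exact
# signatures may differ across versions, so treat this as a sketch rather
# than a tested reference.
if __name__ == "__main__":
    from torch.distributed.pipelining import pipeline

    d_hid, batch_size = 512, 256

    # ExampleCode: split point is the pipe_split() call inside forward().
    ec = ExampleCode(d_hid)
    ec_pipe = pipeline(ec, mb_args=(torch.randn(batch_size, d_hid),))
    print("ExampleCode stages:", ec_pipe.num_stages)

    # MultiMLP: split points come from the registry's split_spec dict.
    mlp = MultiMLP(d_hid, n_layers=4)
    mlp_pipe = pipeline(
        mlp,
        mb_args=(torch.randn(batch_size, d_hid),),
        split_spec=mlp.split_spec,
    )
    print("MultiMLP stages:", mlp_pipe.num_stages)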