import numbers
import warnings
from collections import namedtuple
from typing import List, Tuple

import torch
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter

"""
Some helper classes for writing custom TorchScript LSTMs.

Goals:
- Classes are easy to read, use, and extend
- Performance of custom LSTMs approaches fused-kernel levels of speed.

A few notes about features we could add to clean up the below code:
- Support enumerate with nn.ModuleList:
  https://github.com/pytorch/pytorch/issues/14471
- Support enumerate/zip with lists:
  https://github.com/pytorch/pytorch/issues/15952
- Support overriding of class methods:
  https://github.com/pytorch/pytorch/issues/10733
- Support passing around user-defined namedtuple types for readability
- Support slicing w/ range. It enables reversing lists easily.
  https://github.com/pytorch/pytorch/issues/10774
- Multiline type annotations. List[List[Tuple[Tensor,Tensor]]] is verbose
  https://github.com/pytorch/pytorch/pull/14922
"""


def script_lstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM."""

    # The following are not implemented.
    assert bias
    assert not batch_first

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    elif dropout:
        stack_type = StackedLSTMWithDropout
        layer_type = LSTMLayer
        dirs = 1
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[LSTMCell, input_size, hidden_size],
        other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size],
    )

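# A minimal usage sketch (added for illustration; it mirrors the conventions of
# the test_* functions at the bottom of this file and is never called here).
def example_script_lstm_usage(
    seq_len=5, batch=2, input_size=3, hidden_size=7, num_layers=2
):
    inp = torch.randn(seq_len, batch, input_size)
    # One LSTMState per layer for the unidirectional, non-dropout stack.
    states = [
        LSTMState(torch.zeros(batch, hidden_size), torch.zeros(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers)
    out, out_states = rnn(inp, states)
    # out: (seq_len, batch, hidden_size); out_states: one (hy, cy) pair per layer.
    return out, out_states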

def script_lnlstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
    decompose_layernorm=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM."""

    # The following are not implemented.
    assert bias
    assert not batch_first
    assert not dropout

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[
            LayerNormLSTMCell,
            input_size,
            hidden_size,
            decompose_layernorm,
        ],
        other_layer_args=[
            LayerNormLSTMCell,
            hidden_size * dirs,
            hidden_size,
            decompose_layernorm,
        ],
    )

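# Note (added comment): script_lnlstm builds the same stacks as script_lstm but
# with LayerNormLSTMCell cells; decompose_layernorm=True swaps nn.LayerNorm for
# the hand-written LayerNorm module defined below.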

LSTMState = namedtuple("LSTMState", ["hx", "cx"])


def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]


class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        gates = (
            torch.mm(input, self.weight_ih.t())
            + self.bias_ih
            + torch.mm(hx, self.weight_hh.t())
            + self.bias_hh
        )
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)

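# For reference (added comment): the forward pass above computes the standard
# LSTM cell update,
#   i_t = sigmoid(W_ii x_t + b_ii + W_hi h_{t-1} + b_hi)
#   f_t = sigmoid(W_if x_t + b_if + W_hf h_{t-1} + b_hf)
#   g_t = tanh(W_ig x_t + b_ig + W_hg h_{t-1} + b_hg)
#   o_t = sigmoid(W_io x_t + b_io + W_ho h_{t-1} + b_ho)
#   c_t = f_t * c_{t-1} + i_t * g_t
#   h_t = o_t * tanh(c_t)
# with the four gates packed row-wise into weight_ih / weight_hh and split
# apart by gates.chunk(4, 1).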

class LayerNorm(jit.ScriptModule):
    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # XXX: This is true for our LSTM / NLP use case and helps simplify code
        assert len(normalized_shape) == 1

        self.weight = Parameter(torch.ones(normalized_shape))
        self.bias = Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    @jit.script_method
    def compute_layernorm_stats(self, input):
        mu = input.mean(-1, keepdim=True)
        sigma = input.std(-1, keepdim=True, unbiased=False)
        return mu, sigma

    @jit.script_method
    def forward(self, input):
        mu, sigma = self.compute_layernorm_stats(input)
        return (input - mu) / sigma * self.weight + self.bias

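# A quick equivalence sketch (added for illustration, not part of the original
# file): the decomposed LayerNorm above matches nn.LayerNorm except that it
# omits the eps term nn.LayerNorm adds to the variance for numerical stability,
# so comparing against eps=0.0 should agree to floating-point noise.
def example_layernorm_equivalence(hidden_size=7):
    x = torch.randn(2, hidden_size)
    custom = LayerNorm(hidden_size)
    native = nn.LayerNorm(hidden_size, eps=0.0)
    # Both modules start from weight=ones and bias=zeros, so no copying needed.
    return torch.allclose(custom(x), native(x), atol=1e-5)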

class LayerNormLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases

        if decompose_layernorm:
            ln = LayerNorm
        else:
            ln = nn.LayerNorm

        self.layernorm_i = ln(4 * hidden_size)
        self.layernorm_h = ln(4 * hidden_size)
        self.layernorm_c = ln(hidden_size)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)

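# Note (added comment): this cell differs from LSTMCell above in that layer
# normalization is applied to the input and hidden gate pre-activations and to
# the new cell state, and the explicit bias_ih / bias_hh parameters are dropped
# because the layer norms already carry learnable bias terms.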

class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


class ReverseLSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state


class BidirLSTMLayer(jit.ScriptModule):
    __constants__ = ["directions"]

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.directions = nn.ModuleList(
            [
                LSTMLayer(cell, *cell_args),
                ReverseLSTMLayer(cell, *cell_args),
            ]
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states

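# A minimal usage sketch (added for illustration): a single bidirectional layer
# takes one state per direction and concatenates the forward and reverse
# outputs along the feature dimension, so the output width is 2 * hidden_size.
def example_bidir_layer_usage(seq_len=5, batch=2, input_size=3, hidden_size=7):
    layer = BidirLSTMLayer(LSTMCell, input_size, hidden_size)
    states = [
        LSTMState(torch.zeros(batch, hidden_size), torch.zeros(batch, hidden_size))
        for _ in range(2)  # [forward state, backward state]
    ]
    out, out_states = layer(torch.randn(seq_len, batch, input_size), states)
    # out has shape (seq_len, batch, 2 * hidden_size)
    return out, out_states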

def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    layers = [layer(*first_layer_args)] + [
        layer(*other_layer_args) for _ in range(num_layers - 1)
    ]
    return nn.ModuleList(layers)


class StackedLSTM(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
class StackedLSTM2(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
    ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
        # List[List[LSTMState]]: The outer list is for layers,
        #                        inner list is for directions.
        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


class StackedLSTMWithDropout(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ["layers", "num_layers"]

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )
        # Introduces a Dropout layer on the outputs of each LSTM layer except
        # the last layer, with dropout probability = 0.4.
        self.num_layers = num_layers

        if num_layers == 1:
            warnings.warn(
                "dropout lstm adds dropout layers after all but last "
                "recurrent layer, it expects num_layers greater than "
                "1, but got num_layers = 1"
            )

        self.dropout_layer = nn.Dropout(0.4)

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply dropout to the output of every layer except the last one
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states += [out_state]
            i += 1
        return output, output_states

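# Note (added comment): unlike nn.LSTM, where dropout is a float probability,
# script_lstm takes dropout as a boolean flag, and this stack always applies a
# fixed nn.Dropout(0.4) between layers when that flag is set.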

def flatten_states(states):
    states = list(zip(*states))
    assert len(states) == 2
    return [torch.stack(state) for state in states]


def double_flatten_states(states):
    # XXX: Can probably write this in a nicer way
    states = flatten_states([flatten_states(inner) for inner in states])
    return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states]

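# Note (added comment): these helpers convert the per-layer list-of-LSTMState
# format used above into the stacked (h, c) tensors that nn.LSTM expects, i.e.
# two tensors of shape (num_layers * num_directions, batch, hidden_size), so
# the tests below can compare against the native implementation directly.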

def test_script_rnn_layer(seq_len, batch, input_size, hidden_size):
    inp = torch.randn(seq_len, batch, input_size)
    state = LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
    rnn = LSTMLayer(LSTMCell, input_size, hidden_size)
    out, out_state = rnn(inp, state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, 1)
    lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0))
    for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()):
        assert lstm_param.shape == custom_param.shape
        with torch.no_grad():
            lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers)
    out, out_state = rnn(inp, states)
    custom_state = flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers)
    lstm_state = flatten_states(states)
    for layer in range(num_layers):
        custom_params = list(rnn.parameters())[4 * layer : 4 * (layer + 1)]
        for lstm_param, custom_param in zip(lstm.all_weights[layer], custom_params):
            assert lstm_param.shape == custom_param.shape
            with torch.no_grad():
                lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        [
            LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
            for _ in range(2)
        ]
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
    out, out_state = rnn(inp, states)
    custom_state = double_flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
    lstm_state = double_flatten_states(states)
    for layer in range(num_layers):
        for direct in range(2):
            index = 2 * layer + direct
            custom_params = list(rnn.parameters())[4 * index : 4 * index + 4]
            for lstm_param, custom_param in zip(lstm.all_weights[index], custom_params):
                assert lstm_param.shape == custom_param.shape
                with torch.no_grad():
                    lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_lstm_dropout(
    seq_len, batch, input_size, hidden_size, num_layers
):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True)

    # just a smoke test
    out, out_state = rnn(inp, states)


def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lnlstm(input_size, hidden_size, num_layers)

    # just a smoke test
    out, out_state = rnn(inp, states)


test_script_rnn_layer(5, 2, 3, 7)
test_script_stacked_rnn(5, 2, 3, 7, 4)
test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
test_script_stacked_lstm_dropout(5, 2, 3, 7, 4)
test_script_stacked_lnlstm(5, 2, 3, 7, 4)
