import numbers
import warnings
from collections import namedtuple
from typing import List, Tuple

import torch
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter
12"""
13Some helper classes for writing custom TorchScript LSTMs.
14
15Goals:
16- Classes are easy to read, use, and extend
17- Performance of custom LSTMs approach fused-kernel-levels of speed.
18
19A few notes about features we could add to clean up the below code:
20- Support enumerate with nn.ModuleList:
21https://github.com/pytorch/pytorch/issues/14471
22- Support enumerate/zip with lists:
23https://github.com/pytorch/pytorch/issues/15952
24- Support overriding of class methods:
25https://github.com/pytorch/pytorch/issues/10733
26- Support passing around user-defined namedtuple types for readability
27- Support slicing w/ range. It enables reversing lists easily.
28https://github.com/pytorch/pytorch/issues/10774
29- Multiline type annotations. List[List[Tuple[Tensor,Tensor]]] is verbose
30https://github.com/pytorch/pytorch/pull/14922
31"""
32
33
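
# Example usage (a sketch; mirrors the tests at the bottom of this file):
#
#   rnn = script_lstm(input_size=3, hidden_size=7, num_layers=2)
#   inp = torch.randn(5, 2, 3)  # (seq_len, batch, input_size)
#   states = [LSTMState(torch.randn(2, 7), torch.randn(2, 7))
#             for _ in range(2)]  # one (hx, cx) state per layer
#   out, out_states = rnn(inp, states)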

def script_lstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM."""

    # The following are not implemented.
    assert bias
    assert not batch_first

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    elif dropout:
        stack_type = StackedLSTMWithDropout
        layer_type = LSTMLayer
        dirs = 1
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[LSTMCell, input_size, hidden_size],
        other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size],
    )

def script_lnlstm(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=False,
    dropout=False,
    bidirectional=False,
    decompose_layernorm=False,
):
    """Returns a ScriptModule that mimics a PyTorch native LSTM, with layer norm."""

    # The following are not implemented.
    assert bias
    assert not batch_first
    assert not dropout

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(
        num_layers,
        layer_type,
        first_layer_args=[
            LayerNormLSTMCell,
            input_size,
            hidden_size,
            decompose_layernorm,
        ],
        other_layer_args=[
            LayerNormLSTMCell,
            hidden_size * dirs,
            hidden_size,
            decompose_layernorm,
        ],
    )

LSTMState = namedtuple("LSTMState", ["hx", "cx"])

def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]
120
121class LSTMCell(jit.ScriptModule):122def __init__(self, input_size, hidden_size):123super().__init__()124self.input_size = input_size125self.hidden_size = hidden_size126self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))127self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))128self.bias_ih = Parameter(torch.randn(4 * hidden_size))129self.bias_hh = Parameter(torch.randn(4 * hidden_size))130
131@jit.script_method132def forward(133self, input: Tensor, state: Tuple[Tensor, Tensor]134) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:135hx, cx = state136gates = (137torch.mm(input, self.weight_ih.t())138+ self.bias_ih139+ torch.mm(hx, self.weight_hh.t())140+ self.bias_hh141)142ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)143
144ingate = torch.sigmoid(ingate)145forgetgate = torch.sigmoid(forgetgate)146cellgate = torch.tanh(cellgate)147outgate = torch.sigmoid(outgate)148
149cy = (forgetgate * cx) + (ingate * cellgate)150hy = outgate * torch.tanh(cy)151
152return hy, (hy, cy)153
154
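
# A usage sketch (illustrative; not part of the original test suite). A single
# cell step consumes one timestep, mirroring nn.LSTMCell's interface.
def example_lstm_cell_step():
    cell = LSTMCell(input_size=3, hidden_size=7)
    x = torch.randn(2, 3)  # (batch, input_size)
    state = LSTMState(torch.randn(2, 7), torch.randn(2, 7))
    hy, (hy_again, cy) = cell(x, state)
    assert hy.shape == (2, 7)  # (batch, hidden_size)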

class LayerNorm(jit.ScriptModule):
    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # XXX: This is true for our LSTM / NLP use case and helps simplify code
        assert len(normalized_shape) == 1

        self.weight = Parameter(torch.ones(normalized_shape))
        self.bias = Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    @jit.script_method
    def compute_layernorm_stats(self, input):
        mu = input.mean(-1, keepdim=True)
        sigma = input.std(-1, keepdim=True, unbiased=False)
        return mu, sigma

    @jit.script_method
    def forward(self, input):
        mu, sigma = self.compute_layernorm_stats(input)
        return (input - mu) / sigma * self.weight + self.bias

class LayerNormLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases

        if decompose_layernorm:
            ln = LayerNorm
        else:
            ln = nn.LayerNorm

        self.layernorm_i = ln(4 * hidden_size)
        self.layernorm_h = ln(4 * hidden_size)
        self.layernorm_c = ln(hidden_size)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)

class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state

class ReverseLSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super().__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(
        self, input: Tensor, state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state

class BidirLSTMLayer(jit.ScriptModule):
    __constants__ = ["directions"]

    def __init__(self, cell, *cell_args):
        super().__init__()
        self.directions = nn.ModuleList(
            [
                LSTMLayer(cell, *cell_args),
                ReverseLSTMLayer(cell, *cell_args),
            ]
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
        return torch.cat(outputs, -1), output_states
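
# A usage sketch (illustrative; not part of the original test suite).
# BidirLSTMLayer takes one state per direction and concatenates the forward
# and backward features along the last dimension.
def example_bidir_layer():
    layer = BidirLSTMLayer(LSTMCell, 3, 7)  # cell type, input_size, hidden_size
    inp = torch.randn(5, 2, 3)  # (seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(2, 7), torch.randn(2, 7)) for _ in range(2)
    ]  # [forward state, backward state]
    out, out_states = layer(inp, states)
    assert out.shape == (5, 2, 14)  # feature dim is 2 * hidden_size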

def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    layers = [layer(*first_layer_args)] + [
        layer(*other_layer_args) for _ in range(num_layers - 1)
    ]
    return nn.ModuleList(layers)

class StackedLSTM(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states

# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor, Tensor]]]. It would be nice to subclass StackedLSTM,
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
class StackedLSTM2(jit.ScriptModule):
    __constants__ = ["layers"]  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]
    ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
        # List[List[LSTMState]]: The outer list is for layers,
        # the inner list is for directions.
        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states

class StackedLSTMWithDropout(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ["layers", "num_layers"]

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super().__init__()
        self.layers = init_stacked_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )
        # Introduces a Dropout layer on the outputs of each LSTM layer except
        # the last layer, with dropout probability = 0.4.
        self.num_layers = num_layers

        if num_layers == 1:
            warnings.warn(
                "dropout lstm adds dropout layers after all but the last "
                "recurrent layer, so it expects num_layers greater than 1, "
                "but got num_layers = 1"
            )

        self.dropout_layer = nn.Dropout(0.4)

    @jit.script_method
    def forward(
        self, input: Tensor, states: List[Tuple[Tensor, Tensor]]
    ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply the dropout layer except on the last layer
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states += [out_state]
            i += 1
        return output, output_states

def flatten_states(states):
    states = list(zip(*states))
    assert len(states) == 2
    return [torch.stack(state) for state in states]

def double_flatten_states(states):
    # XXX: Can probably write this in a nicer way
    states = flatten_states([flatten_states(inner) for inner in states])
    return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states]
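
# A shape sketch (illustrative; not part of the original test suite):
# flatten_states converts the per-layer list of (hx, cx) pairs used by the
# modules above into the stacked (num_layers, batch, hidden_size) tensors
# that nn.LSTM expects for its hidden state.
def example_flatten_states():
    states = [LSTMState(torch.randn(2, 7), torch.randn(2, 7)) for _ in range(4)]
    hx, cx = flatten_states(states)
    assert hx.shape == (4, 2, 7)  # (num_layers, batch, hidden_size)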

def test_script_rnn_layer(seq_len, batch, input_size, hidden_size):
    inp = torch.randn(seq_len, batch, input_size)
    state = LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
    rnn = LSTMLayer(LSTMCell, input_size, hidden_size)
    out, out_state = rnn(inp, state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, 1)
    lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0))
    for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()):
        assert lstm_param.shape == custom_param.shape
        with torch.no_grad():
            lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5

def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers)
    out, out_state = rnn(inp, states)
    custom_state = flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers)
    lstm_state = flatten_states(states)
    for layer in range(num_layers):
        custom_params = list(rnn.parameters())[4 * layer : 4 * (layer + 1)]
        for lstm_param, custom_param in zip(lstm.all_weights[layer], custom_params):
            assert lstm_param.shape == custom_param.shape
            with torch.no_grad():
                lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5

def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        [
            LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
            for _ in range(2)
        ]
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
    out, out_state = rnn(inp, states)
    custom_state = double_flatten_states(out_state)

    # Control: pytorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
    lstm_state = double_flatten_states(states)
    for layer in range(num_layers):
        for direct in range(2):
            index = 2 * layer + direct
            custom_params = list(rnn.parameters())[4 * index : 4 * index + 4]
            for lstm_param, custom_param in zip(lstm.all_weights[index], custom_params):
                assert lstm_param.shape == custom_param.shape
                with torch.no_grad():
                    lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5

def test_script_stacked_lstm_dropout(
    seq_len, batch, input_size, hidden_size, num_layers
):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True)

    # just a smoke test
    out, out_state = rnn(inp, states)

def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size, num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [
        LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size))
        for _ in range(num_layers)
    ]
    rnn = script_lnlstm(input_size, hidden_size, num_layers)

    # just a smoke test
    out, out_state = rnn(inp, states)

test_script_rnn_layer(5, 2, 3, 7)
test_script_stacked_rnn(5, 2, 3, 7, 4)
test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
test_script_stacked_lstm_dropout(5, 2, 3, 7, 4)
test_script_stacked_lnlstm(5, 2, 3, 7, 4)