# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """

import logging
from typing import Optional

from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_utils import PretrainedConfig
from .modeling_utils import PreTrainedModel


logger = logging.getLogger(__name__)

class EncoderDecoderModel(PreTrainedModel):
    r"""
    :class:`~transformers.EncoderDecoder` is a generic model class that will be
    instantiated as a transformer architecture with one of the base model
    classes of the library as encoder and another one as decoder when created
    with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class
    method for the encoder and
    `AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)` class
    method for the decoder.
    """
    config_class = EncoderDecoderConfig
    base_model_prefix = "encoder_decoder"
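
    # The model can be built either from a single `EncoderDecoderConfig` (in which
    # case the sub-models are created from `config.encoder` / `config.decoder`) or
    # from already-instantiated `encoder` and `decoder` models.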
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        assert config is not None or (
            encoder is not None and decoder is not None
        ), "Either a configuration or an encoder and a decoder have to be provided"
        if config is None:
            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            assert isinstance(config, self.config_class), "config: {} has to be of type {}".format(
                config, self.config_class
            )
        # initialize with config
        super().__init__(config)

        if encoder is None:
            from .modeling_auto import AutoModel

            encoder = AutoModel.from_config(config.encoder)

        if decoder is None:
            from .modeling_auto import AutoModelForCausalLM

            decoder = AutoModelForCausalLM.from_config(config.decoder)

        self.encoder = encoder
        self.decoder = decoder
        assert (
            self.encoder.get_output_embeddings() is None
        ), "The encoder {} should not have a LM Head. Please use a model without LM Head".format(self.encoder)
    def tie_weights(self):
        # for now no weights tying in encoder-decoder
        pass

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder
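
    # The input embeddings live in the encoder and the output (LM head) embeddings in
    # the decoder, so the corresponding accessors simply delegate to the sub-models.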
    def get_input_embeddings(self):
        return self.encoder.get_input_embeddings()

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()
    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: str = None,
        decoder_pretrained_model_name_or_path: str = None,
        *model_args,
        **kwargs
    ) -> PreTrainedModel:
        r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
        To train the model, you need to first set it back in training mode with `model.train()`.

        Params:
            encoder_pretrained_model_name_or_path (:obj:`str`, `optional`, defaults to `None`):
                Information necessary to initiate the encoder. Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/encoder``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. ``./tf_model/model.ckpt.index``). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            decoder_pretrained_model_name_or_path (:obj:`str`, `optional`, defaults to `None`):
                Information necessary to initiate the decoder. Either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/decoder``.
                - a path or url to a `tensorflow index checkpoint file` (e.g. ``./tf_model/model.ckpt.index``). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.

            kwargs: (`optional`) Remaining dictionary of keyword arguments.
                Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded.

        Examples::

            >>> from transformers import EncoderDecoderModel
            >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
        """

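        # Keyword arguments prefixed with `encoder_` / `decoder_` are stripped of
        # their prefix and forwarded to the corresponding `from_pretrained` call.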
        kwargs_encoder = {
            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # Load and initialize the encoder and decoder
        # The distinction between encoder and decoder at the model level is made
        # by the value of the flag `is_decoder` that we need to set correctly.
        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            assert (
                encoder_pretrained_model_name_or_path is not None
            ), "If `model` is not defined as an argument, an `encoder_pretrained_model_name_or_path` has to be defined"
            from .modeling_auto import AutoModel

            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
        encoder.config.is_decoder = False

        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            assert (
                decoder_pretrained_model_name_or_path is not None
            ), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
            from .modeling_auto import AutoModelForCausalLM

            if "config" not in kwargs_decoder:
                from .configuration_auto import AutoConfig

                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
                if decoder_config.is_decoder is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
                    )
                    decoder_config.is_decoder = True

                kwargs_decoder["config"] = decoder_config

            if kwargs_decoder["config"].is_decoder is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )
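
            # Load the decoder with a causal-LM head; models whose architecture supports
            # it also get cross-attention layers because `is_decoder` is set on the config above.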
            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        return cls(encoder=encoder, decoder=decoder)
    def forward(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        head_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_head_mask=None,
        decoder_inputs_embeds=None,
        labels=None,
        **kwargs,
    ):
197"""198Args:
199input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
200Indices of input sequence tokens in the vocabulary for the encoder.
201Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
202See :func:`transformers.PreTrainedTokenizer.encode` and
203:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
204inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
205Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
206This is useful if you want more control over how to convert `input_ids` indices into associated vectors
207than the model's internal embedding lookup matrix.
208attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
209Mask to avoid performing attention on padding token indices for the encoder.
210Mask values selected in ``[0, 1]``:
211``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
212head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
213Mask to nullify selected heads of the self-attention modules for the encoder.
214Mask values selected in ``[0, 1]``:
215``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
216encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
217Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
218`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
219Used in the cross-attention of the decoder.
220decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
221Provide for sequence to sequence training to the decoder.
222Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
223See :func:`transformers.PreTrainedTokenizer.encode` and
224:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
225decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
226Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
227decoder_head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
228Mask to nullify selected heads of the self-attention modules for the decoder.
229Mask values selected in ``[0, 1]``:
230``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
231decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
232Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
233This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
234than the model's internal embedding lookup matrix.
235labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
236Labels for computing the masked language modeling loss for the decoder.
237Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
238Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
239in ``[0, ..., config.vocab_size]``
240kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
241- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
242- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
243
244Examples::
245
246>>> from transformers import EncoderDecoderModel, BertTokenizer
247>>> import torch
248
249>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
250>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
251
252>>> # forward
253>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
254>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
255
256>>> # training
257>>> loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)[:2]
258
259>>> # generation
260>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
261
262"""
263
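        # As in `from_encoder_decoder_pretrained`, extra keyword arguments are split:
        # un-prefixed ones are forwarded to the encoder, `decoder_`-prefixed ones
        # (with the prefix stripped) to the decoder.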
        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }
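
        # Run the encoder only when no precomputed `encoder_outputs` are passed
        # (during generation they are computed once and then reused at every step).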
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                return_dict=False,
                **kwargs_encoder,
            )

        hidden_states = encoder_outputs[0]
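
        # The encoder's last hidden state is fed to the decoder as
        # `encoder_hidden_states`, which its cross-attention layers attend to.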
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            inputs_embeds=decoder_inputs_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            labels=labels,
            return_dict=False,
            **kwargs_decoder,
        )
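
        # With `return_dict=False` both sub-models return plain tuples, so the overall
        # output is the decoder outputs (loss first when `labels` are given, then the
        # logits, ...) followed by the encoder outputs.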
        return decoder_outputs + encoder_outputs
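
    # Called by `generate()` at every decoding step: `past` carries the cached
    # encoder outputs so the encoder does not have to be re-run, and the decoder
    # prepares its own inputs for the current step.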
    def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"

        # first step
        if type(past) is tuple:
            encoder_outputs, _ = past
        else:
            encoder_outputs = (past,)
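
        # Let the decoder build its own inputs (e.g. its `attention_mask`) for this step.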
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)

        return {
            "attention_mask": attention_mask,
            "decoder_attention_mask": decoder_inputs["attention_mask"],
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
        }
    def _reorder_cache(self, past, beam_idx):
        # as a default encoder-decoder models do not re-order the past.
        # TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
        return past