# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OpenAI GPT model."""


import json
import logging
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .activations import gelu_new, swish
from .configuration_openai import OpenAIGPTConfig
from .file_utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_outputs import BaseModelOutput, CausalLMOutput
from .modeling_utils import (
    Conv1D,
    PreTrainedModel,
    SequenceSummary,
    find_pruneable_heads_and_indices,
    prune_conv1d_layer,
)


logger = logging.getLogger(__name__)

_CONFIG_FOR_DOC = "OpenAIGPTConfig"
_TOKENIZER_FOR_DOC = "OpenAIGPTTokenizer"

OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openai-gpt",
    # See all OpenAI GPT models at https://huggingface.co/models?filter=openai-gpt
]


def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
    """ Load tf pre-trained weights in a pytorch model (from NumPy arrays here)
    """
    import re
    import numpy as np

    if ".ckpt" in openai_checkpoint_folder_path:
        openai_checkpoint_folder_path = os.path.dirname(openai_checkpoint_folder_path)

    logger.info("Loading weights from {}".format(openai_checkpoint_folder_path))

    with open(openai_checkpoint_folder_path + "/parameters_names.json", "r", encoding="utf-8") as names_handle:
        names = json.load(names_handle)
    with open(openai_checkpoint_folder_path + "/params_shapes.json", "r", encoding="utf-8") as shapes_handle:
        shapes = json.load(shapes_handle)
    offsets = np.cumsum([np.prod(shape) for shape in shapes])
    init_params = [np.load(openai_checkpoint_folder_path + "/params_{}.npy".format(n)) for n in range(10)]
    init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
    init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]

    # This was used when we had a single embedding matrix for positions and tokens
    # init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
    # del init_params[1]
    init_params = [arr.squeeze() for arr in init_params]

    try:
        assert model.tokens_embed.weight.shape == init_params[1].shape
        assert model.positions_embed.weight.shape == init_params[0].shape
    except AssertionError as e:
        e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
        e.args += (model.positions_embed.weight.shape, init_params[0].shape)
        raise

    model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
    model.positions_embed.weight.data = torch.from_numpy(init_params[0])
    names.pop(0)
    # Pop position and token embedding arrays
    init_params.pop(0)
    init_params.pop(0)

    for name, array in zip(names, init_params):  # names[1:n_transfer], init_params[1:n_transfer]):
        name = name[6:]  # skip "model/"
        assert name[-2:] == ":0"
        name = name[:-2]
        name = name.split("/")
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "g":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "w":
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
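

# Illustrative sketch (not a definitive recipe) of calling the loader above
# directly. The checkpoint directory is a hypothetical placeholder; it must
# contain the parameters_names.json, params_shapes.json and
# params_0.npy ... params_9.npy files that load_tf_weights_in_openai_gpt reads:
#
#     config = OpenAIGPTConfig()
#     model = OpenAIGPTModel(config)
#     load_tf_weights_in_openai_gpt(model, config, "/path/to/openai_checkpoint")

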
# All entries are plain callables applied to a tensor in MLP.forward, so the
# functional form of ReLU is used rather than the nn.ReLU module class.
ACT_FNS = {"relu": nn.functional.relu, "swish": swish, "gelu": gelu_new}


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super().__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
        )
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        # w = w * self.bias + -1e9 * (1 - self.bias)  # TF implem method: mask_attn_weights
        # XD: self.b may be larger than w, so we need to crop it
        b = self.bias[:, :, : w.size(-2), : w.size(-1)]
        w = w * b + -1e4 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            w = w + attention_mask

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a] + attn_outputs[1:]
        return outputs  # a, (attentions)
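

# Shape walkthrough for Attention.forward (a sketch assuming the default
# OpenAIGPTConfig with n_embd=768 and n_head=12, so head_dim = 768 // 12 = 64):
#
#     x:                      (batch, seq, 768)
#     c_attn(x):              (batch, seq, 3 * 768)  -> split into query, key, value
#     split_heads(q):         (batch, 12, seq, 64)
#     split_heads(k, k=True): (batch, 12, 64, seq)   # pre-transposed for the matmul
#     w = q @ k:              (batch, 12, seq, seq)  # causal bias + softmax applied here
#     w @ v:                  (batch, 12, seq, 64)
#     merge_heads(...):       (batch, seq, 768)      -> c_proj -> resid_dropout

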
class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = ACT_FNS[config.afn]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super().__init__()
        nx = config.n_embd
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)

    def forward(self, x, attention_mask=None, head_mask=None, output_attentions=False):
        attn_outputs = self.attn(
            x, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions,
        )
        a = attn_outputs[0]

        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)

        outputs = [h] + attn_outputs[1:]
        return outputs
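

# Residual structure of a single Block (post-LayerNorm, as in the original GPT;
# every intermediate tensor keeps the shape (batch, seq, n_embd)):
#
#     a = attn(x)        # masked self-attention
#     n = ln_1(x + a)    # residual connection, then LayerNorm
#     m = mlp(n)         # two Conv1D layers with inner size 4 * n_embd
#     h = ln_2(n + m)    # second residual + LayerNorm = block output

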
class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = OpenAIGPTConfig
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"
    authorized_missing_keys = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@dataclass
class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
    """
    Base class for outputs of :class:`~transformers.OpenAIGPTDoubleHeadsModel`, which has a language modeling head
    and a multiple-choice classification head on top.

    Args:
        lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
            Language modeling loss.
        mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
            Multiple choice classification loss.
        lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    lm_loss: Optional[torch.FloatTensor] = None
    mc_loss: Optional[torch.FloatTensor] = None
    lm_logits: torch.FloatTensor = None
    mc_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


OPENAI_GPT_START_DOCSTRING = r"""

    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (:class:`~transformers.OpenAIGPTConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

OPENAI_GPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`transformers.OpenAIGPTTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.__call__` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`_
        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the attention tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
            If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
"""
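

# Input-preparation sketch for the arguments documented above, using the
# callable tokenizer API referenced there (``PreTrainedTokenizer.__call__``);
# the sentence is an arbitrary example:
#
#     from transformers import OpenAIGPTTokenizer
#
#     tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     inputs["input_ids"]       # (batch_size, sequence_length) token indices
#     inputs["attention_mask"]  # 1 for tokens to attend to, 0 for padding

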
@add_start_docstrings(
    "The bare OpenAI GPT transformer model outputting raw hidden-states without any specific head on top.",
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.tokens_embed = nn.Embedding(config.vocab_size, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

        self.register_buffer("position_ids", torch.arange(config.n_positions))
        self.init_weights()

    def get_input_embeddings(self):
        return self.tokens_embed

    def set_input_embeddings(self, new_embeddings):
        self.tokens_embed = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="openai-gpt",
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if position_ids is None:
            # Code is different from when we had a single embedding matrix for position and token embeddings
            position_ids = self.position_ids[None, : input_shape[-1]]

        # Attention mask.
        if attention_mask is not None:
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # This padding mask is simpler than the triangular causal mask applied inside the
            # attention layers; we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.tokens_embed(input_ids)
        position_embeds = self.positions_embed(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.tokens_embed(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions)
            hidden_states = outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (outputs[1],)

        hidden_states = hidden_states.view(*output_shape)
        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions,
        )
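

# Minimal usage sketch for the bare model (weights come from the "openai-gpt"
# checkpoint listed in OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST above):
#
#     from transformers import OpenAIGPTTokenizer, OpenAIGPTModel
#
#     tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
#     model = OpenAIGPTModel.from_pretrained("openai-gpt", return_dict=True)
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs)
#     outputs.last_hidden_state  # (batch_size, sequence_length, n_embd)

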
@add_start_docstrings(
    """OpenAI GPT Model transformer with a language modeling head on top
    (linear layer with weights tied to the input embeddings). """,
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = OpenAIGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="openai-gpt",
        output_type=CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``.
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``.
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
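

# Usage sketch mirroring the `labels` documentation above: feeding the inputs
# back in as labels yields the causal language-modeling loss (the one-token
# shift happens inside forward):
#
#     from transformers import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel
#
#     tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
#     model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt", return_dict=True)
#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     outputs = model(**inputs, labels=inputs["input_ids"])
#     outputs.loss    # scalar cross-entropy loss
#     outputs.logits  # (batch_size, sequence_length, vocab_size)

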
@add_start_docstrings(
    """OpenAI GPT Model transformer with a language modeling and a multiple-choice classification
    head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
    The language modeling head has its weights tied to the input embeddings; the classification head
    takes as input the hidden state of a specified classification token index in the input sequence.
    """,
    OPENAI_GPT_START_DOCSTRING,
)
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        config.num_labels = 1
        self.transformer = OpenAIGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.multiple_choice_head = SequenceSummary(config)

        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    @add_start_docstrings_to_callable(OPENAI_GPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=OpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        mc_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, defaults to index of the last token of the input):
            Index of the classification token in each input sequence.
            Selected in the range ``[0, input_ids.size(-1) - 1]``.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Labels for language modeling.
            Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``.
            Indices are selected in ``[-100, 0, ..., config.vocab_size]``.
            All labels set to ``-100`` are ignored (masked), the loss is only
            computed for labels in ``[0, ..., config.vocab_size]``.
        mc_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size)`, `optional`, defaults to :obj:`None`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
            of the input tensors. (see `input_ids` above)
        kwargs (:obj:`Dict[str, any]`, `optional`, defaults to `{}`):
            Used to hide legacy arguments that have been deprecated.

        Return:

        Examples::

            from transformers import OpenAIGPTTokenizer, OpenAIGPTDoubleHeadsModel
            import torch

            tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
            model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt', return_dict=True)
            tokenizer.add_special_tokens({'cls_token': '[CLS]'})  # Add a [CLS] to the vocabulary (we should train it also!)
            model.resize_token_embeddings(len(tokenizer))

            choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
            input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
            mc_token_ids = torch.tensor([input_ids.size(-1) - 1, input_ids.size(-1) - 1]).unsqueeze(0)  # Batch size 1

            outputs = model(input_ids, mc_token_ids=mc_token_ids)
            lm_logits = outputs.lm_logits
            mc_logits = outputs.mc_logits
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1)

        lm_loss, mc_loss = None, None
        if mc_labels is not None:
            loss_fct = CrossEntropyLoss()
            mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
        if labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits, mc_logits) + transformer_outputs[1:]
            if mc_loss is not None:
                output = (mc_loss,) + output
            return ((lm_loss,) + output) if lm_loss is not None else output

        return OpenAIGPTDoubleHeadsModelOutput(
            lm_loss=lm_loss,
            mc_loss=mc_loss,
            lm_logits=lm_logits,
            mc_logits=mc_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )