# coding=utf-8
# Copyright 2020 The Fairseq Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BART configuration """


import logging

from .configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "facebook/bart-base": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-base/config.json",
    "facebook/bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
    "facebook/bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
    "facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
    "facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
    "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
    "yjernite/bart_eli5": "https://s3.amazonaws.com/models.huggingface.co/bert/yjernite/bart_eli5/config.json",
}


class BartConfig(PretrainedConfig):
    r"""
    Configuration class for Bart. Parameters are renamed from the fairseq implementation.
    """
    model_type = "bart"

    def __init__(
        self,
        activation_dropout=0.0,
        extra_pos_embeddings=2,
        activation_function="gelu",
        vocab_size=50265,
        d_model=1024,
        encoder_ffn_dim=4096,
        encoder_layers=12,
        encoder_attention_heads=16,
        decoder_ffn_dim=4096,
        decoder_layers=12,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        attention_dropout=0.0,
        dropout=0.1,
        max_position_embeddings=1024,
        init_std=0.02,
        classifier_dropout=0.0,
        num_labels=3,
        is_encoder_decoder=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        normalize_before=False,
        add_final_layer_norm=False,
        scale_embedding=False,
        normalize_embedding=True,
        static_position_embeddings=False,
        add_bias_logits=False,
        **common_kwargs
    ):
        r"""
        :class:`~transformers.BartConfig` is the configuration class for `BartModel`.

        Examples::

            >>> from transformers import BartConfig, BartModel

            >>> config = BartConfig.from_pretrained('facebook/bart-large')
            >>> model = BartModel(config)
        """
85if "hidden_size" in common_kwargs:
86raise ValueError("hidden size is called d_model")
87super().__init__(
88num_labels=num_labels,
89pad_token_id=pad_token_id,
90bos_token_id=bos_token_id,
91eos_token_id=eos_token_id,
92is_encoder_decoder=is_encoder_decoder,
93**common_kwargs,
94)
95self.vocab_size = vocab_size
96self.d_model = d_model # encoder_embed_dim and decoder_embed_dim
97self.encoder_ffn_dim = encoder_ffn_dim
98self.encoder_layers = self.num_hidden_layers = encoder_layers
99self.encoder_attention_heads = encoder_attention_heads
100self.encoder_layerdrop = encoder_layerdrop
101self.decoder_layerdrop = decoder_layerdrop
102self.decoder_ffn_dim = decoder_ffn_dim
103self.decoder_layers = decoder_layers
104self.decoder_attention_heads = decoder_attention_heads
105self.max_position_embeddings = max_position_embeddings
106self.init_std = init_std # Normal(0, this parameter)
107self.activation_function = activation_function
108
109# Params introduced for Mbart
110self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
111self.normalize_embedding = normalize_embedding # True for mbart, False otherwise
112self.normalize_before = normalize_before # combo of fairseq's encoder_ and decoder_normalize_before
113self.add_final_layer_norm = add_final_layer_norm
114
115# Params introduced for Marian
116self.add_bias_logits = add_bias_logits
117self.static_position_embeddings = static_position_embeddings
118
119# 3 Types of Dropout
120self.attention_dropout = attention_dropout
121self.activation_dropout = activation_dropout
122self.dropout = dropout
123
124# Classifier stuff
125self.classif_dropout = classifier_dropout
126
127# pos embedding offset
128self.extra_pos_embeddings = self.pad_token_id + 1

    @property
    def num_attention_heads(self) -> int:
        return self.encoder_attention_heads

    @property
    def hidden_size(self) -> int:
        return self.d_model

    def is_valid_mbart(self) -> bool:
        """Is the configuration aligned with the MBART paper."""
        if self.normalize_before and self.add_final_layer_norm and self.scale_embedding:
            return True
        if self.normalize_before or self.add_final_layer_norm or self.scale_embedding:
            logger.info("This configuration is a mixture of MBART and BART settings")
        return False


class MBartConfig(BartConfig):
    model_type = "mbart"
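

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of constructing an MBART-style configuration and checking
# it against the three flags is_valid_mbart() inspects. Because of the relative
# import at the top, this only runs when the file is executed within its
# package, e.g. `python -m <package>.configuration_bart` (the module name here
# is an assumption).
if __name__ == "__main__":
    config = MBartConfig(
        normalize_before=True,
        add_final_layer_norm=True,
        scale_embedding=True,
    )
    assert config.model_type == "mbart"
    assert config.is_valid_mbart()
    # The BERT-style aliases resolve to the BART names (defaults shown).
    assert config.hidden_size == config.d_model == 1024
    assert config.num_attention_heads == config.encoder_attention_heads == 16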