# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16""" Configuration base class and utilities."""
17
18
19import copy20import json21import logging22import os23from typing import Any, Dict, Tuple24
25from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url26
27
28logger = logging.getLogger(__name__)29
30


class PretrainedConfig(object):
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as
    well as methods for loading/downloading/saving configurations.

    Note:
        A configuration file can be loaded and saved to disk. Loading the configuration file and using this file
        to initialize a model does **not** load the model weights. It only affects the model's configuration.

    Class attributes (overridden by derived classes)
        - **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used
          to recreate the correct object in :class:`~transformers.AutoConfig`.

    Args:
        output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return all hidden states.
        output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return all attentions.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        return_dict (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the model should return a :class:`~transformers.file_utils.ModelOutput` instead of a
            plain tuple.
        is_encoder_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the model is used as an encoder/decoder or not.
        is_decoder (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the model is used as a decoder or not (if not, it is used as an encoder).
        pruned_heads (:obj:`Dict[int, List[int]]`, `optional`, defaults to :obj:`{}`):
            Pruned heads of the model. The keys are the selected layer indices and the associated values the
            lists of heads to prune in said layers.

            For instance ``{1: [0, 2], 2: [2, 3]}`` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on
            layer 2.
        xla_device (:obj:`bool`, `optional`):
            A flag to indicate whether a TPU is available or not.

    Parameters for sequence generation
        - **max_length** (:obj:`int`, `optional`, defaults to 20) -- Maximum length that will be used by
          default in the :obj:`generate` method of the model.
        - **min_length** (:obj:`int`, `optional`, defaults to 0) -- Minimum length that will be used by
          default in the :obj:`generate` method of the model.
        - **do_sample** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Flag that will be used by default
          in the :obj:`generate` method of the model. Whether or not to use sampling; use greedy decoding
          otherwise.
        - **early_stopping** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Flag that will be used by
          default in the :obj:`generate` method of the model. Whether to stop the beam search when at least
          ``num_beams`` sentences are finished per batch or not.
        - **num_beams** (:obj:`int`, `optional`, defaults to 1) -- Number of beams for beam search that will be
          used by default in the :obj:`generate` method of the model. 1 means no beam search.
        - **temperature** (:obj:`float`, `optional`, defaults to 1) -- The value used to modulate the next token
          probabilities that will be used by default in the :obj:`generate` method of the model. Must be
          strictly positive.
        - **top_k** (:obj:`int`, `optional`, defaults to 50) -- Number of highest-probability vocabulary tokens
          to keep for top-k filtering that will be used by default in the :obj:`generate` method of the model.
        - **top_p** (:obj:`float`, `optional`, defaults to 1) -- Value that will be used by default in the
          :obj:`generate` method of the model for ``top_p``. If set to float < 1, only the most probable tokens
          with probabilities that add up to ``top_p`` or higher are kept for generation.
        - **repetition_penalty** (:obj:`float`, `optional`, defaults to 1) -- Parameter for repetition penalty
          that will be used by default in the :obj:`generate` method of the model. 1.0 means no penalty.
        - **length_penalty** (:obj:`float`, `optional`, defaults to 1) -- Exponential penalty to the length that
          will be used by default in the :obj:`generate` method of the model.
        - **no_repeat_ngram_size** (:obj:`int`, `optional`, defaults to 0) -- Value that will be used by default
          in the :obj:`generate` method of the model for ``no_repeat_ngram_size``. If set to int > 0, all ngrams
          of that size can only occur once.
        - **bad_words_ids** (:obj:`List[int]`, `optional`) -- List of token ids that are not allowed to be
          generated that will be used by default in the :obj:`generate` method of the model. In order to get the
          tokens of the words that should not appear in the generated text, use
          :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
        - **num_return_sequences** (:obj:`int`, `optional`, defaults to 1) -- Number of independently computed
          returned sequences for each element in the batch that will be used by default in the :obj:`generate`
          method of the model.

    Parameters for fine-tuning tasks
        - **architectures** (:obj:`List[str]`, `optional`) -- Model architectures that can be used with the
          model pretrained weights.
        - **finetuning_task** (:obj:`str`, `optional`) -- Name of the task used to fine-tune the model. This can
          be used when converting from an original (TensorFlow or PyTorch) checkpoint.
        - **id2label** (:obj:`Dict[int, str]`, `optional`) -- A map from index (for instance prediction index,
          or target index) to label.
        - **label2id** (:obj:`Dict[str, int]`, `optional`) -- A map from label to index for the model.
        - **num_labels** (:obj:`int`, `optional`) -- Number of labels to use in the last layer added to the
          model, typically for a classification task.
        - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store
          for the current task.

    Parameters linked to the tokenizer
        - **prefix** (:obj:`str`, `optional`) -- A specific prompt that should be added at the beginning of each
          text before calling the model.
        - **bos_token_id** (:obj:`int`, `optional`) -- The id of the `beginning-of-stream` token.
        - **pad_token_id** (:obj:`int`, `optional`) -- The id of the `padding` token.
        - **eos_token_id** (:obj:`int`, `optional`) -- The id of the `end-of-stream` token.
        - **decoder_start_token_id** (:obj:`int`, `optional`) -- If an encoder-decoder model starts decoding
          with a different token than `bos`, the id of that token.

    PyTorch specific parameters
        - **torchscript** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model should
          be used with TorchScript.

    TensorFlow specific parameters
        - **use_bfloat16** (:obj:`bool`, `optional`, defaults to :obj:`False`) -- Whether or not the model
          should use BFloat16 scalars (only used by some TensorFlow models).
    """

    model_type: str = ""

    def __init__(self, **kwargs):
        # Attributes with defaults
        self.return_dict = kwargs.pop("return_dict", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.use_cache = kwargs.pop("use_cache", True)  # Not used by all models
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})

        # `is_decoder` is used in encoder-decoder models to differentiate encoder from decoder
        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
        self.is_decoder = kwargs.pop("is_decoder", False)

        # Parameters for sequence generation
        self.max_length = kwargs.pop("max_length", 20)
        self.min_length = kwargs.pop("min_length", 0)
        self.do_sample = kwargs.pop("do_sample", False)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
        self.bad_words_ids = kwargs.pop("bad_words_ids", None)
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

        # Fine-tuning task arguments
        self.architectures = kwargs.pop("architectures", None)
        self.finetuning_task = kwargs.pop("finetuning_task", None)
        self.id2label = kwargs.pop("id2label", None)
        self.label2id = kwargs.pop("label2id", None)
        if self.id2label is not None:
            kwargs.pop("num_labels", None)
            # Keys are always strings in JSON so convert ids to int here.
            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
        else:
            self.num_labels = kwargs.pop("num_labels", 2)

        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.prefix = kwargs.pop("prefix", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        # Task-specific arguments
        self.task_specific_params = kwargs.pop("task_specific_params", None)

        # TPU arguments
        self.xla_device = kwargs.pop("xla_device", None)

        # Additional attributes without default values
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                raise err
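
    # A minimal usage sketch (hypothetical attribute `my_custom_flag`): any kwarg that is
    # not a known attribute is attached to the instance as-is by the loop above, e.g.:
    #
    #     config = PretrainedConfig(num_labels=3, my_custom_flag=True)
    #     config.num_labels      # 3, routed through the `num_labels` property below
    #     config.my_custom_flag  # True, stored by the generic setattr fallback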

    @property
    def use_return_dict(self) -> bool:
        """
        :obj:`bool`: Whether or not to return :class:`~transformers.file_utils.ModelOutput` instead of tuples.
        """
        # If torchscript is set, force `return_dict=False` to avoid jit errors
        return self.return_dict and not self.torchscript

    @property
    def num_labels(self) -> int:
        """
        :obj:`int`: The number of labels for classification models.
        """
        return len(self.id2label)

    @num_labels.setter
    def num_labels(self, num_labels: int):
        self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
        self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
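
    # Illustrative values showing the coupling maintained by the setter above:
    #
    #     config = PretrainedConfig(num_labels=2)
    #     config.id2label    # {0: 'LABEL_0', 1: 'LABEL_1'}
    #     config.label2id    # {'LABEL_0': 0, 'LABEL_1': 1}
    #     config.num_labels  # 2, computed as len(config.id2label)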

    def save_pretrained(self, save_directory: str):
        """
        Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
        :func:`~transformers.PretrainedConfig.from_pretrained` class method.

        Args:
            save_directory (:obj:`str`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
        """
        if os.path.isfile(save_directory):
            raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
        os.makedirs(save_directory, exist_ok=True)
        # If we save using the predefined names, we can load using `from_pretrained`
        output_config_file = os.path.join(save_directory, CONFIG_NAME)

        self.to_json_file(output_config_file, use_diff=True)
        logger.info("Configuration saved in {}".format(output_config_file))
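
    # Round-trip sketch, assuming a derived class such as BertConfig and a writable
    # `./my_model_directory/` (hypothetical path):
    #
    #     config.save_pretrained('./my_model_directory/')             # writes config.json
    #     config = BertConfig.from_pretrained('./my_model_directory/')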

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "PretrainedConfig":
        r"""
        Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pretrained model
        configuration.

        Args:
            pretrained_model_name_or_path (:obj:`str`):
                This can be either:

                - the `shortcut name` of a pretrained model configuration to load from cache or download, e.g.,
                  ``bert-base-uncased``.
                - the `identifier name` of a pretrained model configuration that was uploaded to our S3 by any
                  user, e.g., ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a configuration file saved using the
                  :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g., ``./my_model_directory/``.
                - a path or url to a saved configuration JSON `file`, e.g.,
                  ``./my_model_directory/configuration.json``.
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the configuration files and override the cached
                versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete an incompletely received file. Attempts to resume the download if such
                a file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.,
                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`.
                The proxies are used on each request.
            return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If :obj:`False`, then this function returns just the final configuration object.

                If :obj:`True`, then this function returns a :obj:`Tuple(config, unused_kwargs)` where
                `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not
                configuration attributes: i.e., the part of ``kwargs`` which has not been used to update
                ``config`` and is otherwise ignored.
            kwargs (:obj:`Dict[str, Any]`, `optional`):
                The values in kwargs of any keys which are configuration attributes will be used to override the
                loaded values. Behavior concerning key/value pairs whose keys are *not* configuration attributes
                is controlled by the ``return_unused_kwargs`` keyword parameter.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from this pretrained model.

        Examples::

            # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
            # derived class: BertConfig
            config = BertConfig.from_pretrained('bert-base-uncased')    # Download configuration from S3 and cache.
            config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
            config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
            config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
            assert config.output_attentions == True
            config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
                                                               foo=False, return_unused_kwargs=True)
            assert config.output_attentions == True
            assert unused_kwargs == {'foo': False}

        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for
        instantiating a :class:`~transformers.PretrainedConfig` using ``from_dict``.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`):
                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

        Returns:
            :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration
            object.

        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)

        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            # Load config dict
            if resolved_config_file is None:
                raise EnvironmentError
            config_dict = cls._dict_from_json_file(resolved_config_file)

        except EnvironmentError:
            msg = (
                f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
            )
            raise EnvironmentError(msg)

        except json.JSONDecodeError:
            msg = (
                "Couldn't reach server at '{}' to download configuration file or "
                "configuration file is not a valid JSON file. "
                "Please check network or file content here: {}.".format(config_file, resolved_config_file)
            )
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))

        return config_dict, kwargs
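
    # Resolution order sketch (illustrative inputs): a local directory is resolved to its
    # config.json, a local file or URL is used directly, and anything else is treated as a
    # hub identifier:
    #
    #     BertConfig.get_config_dict('./saved_model/')             # -> ./saved_model/config.json
    #     BertConfig.get_config_dict('./saved_model/config.json')  # used as-is
    #     BertConfig.get_config_dict('bert-base-uncased')          # -> hf_bucket_url(...)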

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
        """
        Instantiates a :class:`~transformers.PretrainedConfig` from a Python dictionary of parameters.

        Args:
            config_dict (:obj:`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
                retrieved from a pretrained checkpoint by leveraging the
                :func:`~transformers.PretrainedConfig.get_config_dict` method.
            kwargs (:obj:`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from those parameters.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**config_dict)

        if hasattr(config, "pruned_heads"):
            config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

        # Update config with kwargs if needed
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)

        logger.info("Model config %s", str(config))
        if return_unused_kwargs:
            return config, kwargs
        else:
            return config
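
    # Sketch of the kwargs-override behavior (hypothetical key `foo`): keys that already
    # exist on the config are consumed as overrides, the rest are handed back on request:
    #
    #     config, unused = BertConfig.from_dict(
    #         config_dict, num_beams=5, foo=1, return_unused_kwargs=True
    #     )
    #     # config.num_beams == 5; unused == {'foo': 1}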

    @classmethod
    def from_json_file(cls, json_file: str) -> "PretrainedConfig":
        """
        Instantiates a :class:`~transformers.PretrainedConfig` from the path to a JSON file of parameters.

        Args:
            json_file (:obj:`str`):
                Path to the JSON file containing the parameters.

        Returns:
            :class:`PretrainedConfig`: The configuration object instantiated from that JSON file.

        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)
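
    # Sketch (hypothetical path): load a configuration straight from a JSON file on disk:
    #
    #     config = BertConfig.from_json_file('./saved_model/config.json')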

    @classmethod
    def _dict_from_json_file(cls, json_file: str):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return "{} {}".format(self.__class__.__name__, self.to_json_string())

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better
        readability and serializes to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        config_dict = self.to_dict()

        # Get the default config dict
        default_config_dict = PretrainedConfig().to_dict()

        serializable_config_dict = {}

        # Only serialize values that differ from the default config
        for key, value in config_dict.items():
            if key not in default_config_dict or value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict
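
    # Illustrative expectation: only values that differ from a fresh default survive, e.g.:
    #
    #     PretrainedConfig(num_beams=4).to_diff_dict()  # -> {'num_beams': 4}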

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        if hasattr(self.__class__, "model_type"):
            output["model_type"] = self.__class__.model_type
        return output

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to JSON string.

        Returns:
            :obj:`str`: String containing all the attributes that make up this configuration instance in JSON
            format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path: str, use_diff: bool = True):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (:obj:`str`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (:obj:`bool`, `optional`, defaults to :obj:`True`):
                If set to ``True``, only the difference between the config instance and the default
                ``PretrainedConfig()`` is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    def update(self, config_dict: Dict[str, Any]):
        """
        Updates attributes of this class with attributes from ``config_dict``.

        Args:
            config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this class.
        """
        for key, value in config_dict.items():
            setattr(self, key, value)
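
# Sketch of `update` (hypothetical key `my_flag`): existing attributes are overwritten and
# unknown keys are added as new attributes, mirroring the setattr loop above:
#
#     config.update({'temperature': 0.7, 'my_flag': True})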