# CSS-LM (214 lines, 9.8 KB)
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model card class (ModelCard) and utilities for loading, saving and serializing model cards."""


import copy
import json
import logging
import os

from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .file_utils import (
    CONFIG_NAME,
    MODEL_CARD_NAME,
    TF2_WEIGHTS_NAME,
    WEIGHTS_NAME,
    cached_path,
    hf_bucket_url,
    is_remote_url,
)


# Module-level logger, named after this module so its records can be filtered
# through the standard ``logging`` hierarchy.
logger = logging.getLogger(name=__name__)


class ModelCard:
    r"""Structured Model Card class.

    Stores a model card and provides methods for loading/downloading/saving model cards.

    Please read the following paper for details and explanation on the sections:
    "Model Cards for Model Reporting"
    by Margaret Mitchell, Simone Wu,
    Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
    Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards.
    Link: https://arxiv.org/abs/1810.03993

    Note:
        A model card can be loaded and saved to disk.

    Parameters:
        model_details, intended_use, factors, metrics, evaluation_data, training_data,
        quantitative_analyses, ethical_considerations, caveats_and_recommendations:
            (`optional`) The sections recommended by the paper above. Each one defaults
            to a new, empty dict when not given.
        **kwargs:
            Any other key/value pair is stored as an additional attribute of the card.
    """

    def __init__(self, **kwargs):
        # Recommended sections from https://arxiv.org/abs/1810.03993 (see paper).
        # Each ``pop`` default is a fresh ``{}`` literal, so instances never share a dict.
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})

        # Open additional attributes: every remaining key becomes an attribute.
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError:
                # Defensive: setattr raises AttributeError for names that cannot be assigned.
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                raise

    def save_pretrained(self, save_directory_or_file):
        """Save this model card as JSON.

        Args:
            save_directory_or_file: If this is an existing directory, the card is written to
                ``MODEL_CARD_NAME`` inside it, so that :meth:`from_pretrained` can reload it
                from that directory. Otherwise it is used as the output file path.
                An existing file at the target path is overwritten.
        """
        if os.path.isdir(save_directory_or_file):
            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        self.to_json_file(output_model_card_file)
        logger.info("Model card saved in {}".format(output_model_card_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""Instantiate a :class:`~transformers.ModelCard` from a pre-trained model card.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model card to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model card that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a model card file saved using the :func:`~transformers.ModelCard.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/modelcard.json``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                card should be cached if the standard cache should not be used.

            kwargs: (`optional`) dict: key/value pairs with which to update the ModelCard object after loading.

                - The values in kwargs of any keys which are model card attributes will be used to override the loaded values.
                - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the `return_unused_kwargs` keyword parameter.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            find_from_standard_name: (`optional`) boolean, default True:
                If the resolved path or url contains one of our standard config or weights file names, replace it with our standard model card file name.
                Can be used to directly feed a model/config url and access the colocated modelcard.

            return_unused_kwargs: (`optional`) bool:

                - If False, then this function returns just the final model card object.
                - If True, then this function returns a tuple `(model card, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of kwargs which has not been used to update `ModelCard` and is otherwise ignored.

        Returns:
            The loaded model card (optionally with `unused_kwargs`, see above). If the model card file
            cannot be resolved, read, or parsed as JSON, an empty model card is returned instead.

        Examples::

            modelcard = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
            modelcard = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
            modelcard = ModelCard.from_pretrained('./test/saved_model/modelcard.json')
            modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)

        """
        cache_dir = kwargs.pop("cache_dir", None)
        proxies = kwargs.pop("proxies", None)
        find_from_standard_name = kwargs.pop("find_from_standard_name", True)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        is_shortcut_name = pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
        if is_shortcut_name:
            # For simplicity we use the same pretrained url as the configuration files;
            # the config file name is swapped for MODEL_CARD_NAME below.
            model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            model_card_file = pretrained_model_name_or_path
        else:
            model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False)

        if find_from_standard_name or is_shortcut_name:
            # Plain substring replacement: every occurrence of a standard config/weights
            # file name anywhere in the path or url is replaced by the model card name.
            for standard_name in (CONFIG_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME):
                model_card_file = model_card_file.replace(standard_name, MODEL_CARD_NAME)

        try:
            # Load from URL or cache if already cached
            resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
            if resolved_model_card_file is None:
                raise EnvironmentError
            if resolved_model_card_file == model_card_file:
                logger.info("loading model card file {}".format(model_card_file))
            else:
                logger.info(
                    "loading model card file {} from cache at {}".format(model_card_file, resolved_model_card_file)
                )
            modelcard = cls.from_json_file(resolved_model_card_file)
        except (EnvironmentError, json.JSONDecodeError):
            # Deliberate best effort: a missing or malformed model card is not fatal and we
            # fall back on an empty one, but say so instead of failing silently.
            logger.info(
                "Could not load a model card from {}; creating an empty model card".format(model_card_file)
            )
            modelcard = cls()

        # Override loaded values with kwargs whose key is an attribute of the card
        # (``hasattr``), and collect everything else as unused.
        unused_kwargs = {}
        for key, value in kwargs.items():
            if hasattr(modelcard, key):
                setattr(modelcard, key, value)
            else:
                unused_kwargs[key] = value

        # Lazy %-formatting: the JSON dump only happens if the record is actually emitted.
        logger.info("Model card: %s", modelcard)
        if return_unused_kwargs:
            return modelcard, unused_kwargs
        return modelcard

    @classmethod
    def from_dict(cls, json_object):
        """Construct a `ModelCard` from a Python dictionary of parameters (passed as keyword arguments)."""
        return cls(**json_object)

    @classmethod
    def from_json_file(cls, json_file):
        """Construct a `ModelCard` from a UTF-8 encoded JSON file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __eq__(self, other):
        """Model cards are equal when all their instance attributes are equal."""
        if not isinstance(other, ModelCard):
            # Let Python try the reflected comparison (and fall back to identity) instead of
            # raising AttributeError on objects that have no ``__dict__``, e.g. ``None``.
            return NotImplemented
        # Note: defining __eq__ without __hash__ leaves instances unhashable.
        return self.__dict__ == other.__dict__

    def __repr__(self):
        return self.to_json_string()

    def to_dict(self):
        """Serialize this instance to a Python dictionary (a deep copy of its attributes)."""
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serialize this instance to a JSON string (2-space indent, sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """Save this instance to a UTF-8 JSON file, overwriting any existing file."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())