CSS-LM

Форк
0
/
modelcard.py 
214 строк · 9.8 Кб
1
# coding=utf-8
2
# Copyright 2018 The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
""" Configuration base class and utilities."""
16

17

18
import copy
19
import json
20
import logging
21
import os
22

23
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
24
from .file_utils import (
25
    CONFIG_NAME,
26
    MODEL_CARD_NAME,
27
    TF2_WEIGHTS_NAME,
28
    WEIGHTS_NAME,
29
    cached_path,
30
    hf_bucket_url,
31
    is_remote_url,
32
)
33

34

35
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
36

37

38
class ModelCard:
    r""" Structured Model Card class.
        Store model card as well as methods for loading/downloading/saving model cards.

        Please read the following paper for details and explanation on the sections:
            "Model Cards for Model Reporting"
                by Margaret Mitchell, Simone Wu,
                Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
                Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards.
            Link: https://arxiv.org/abs/1810.03993

        Note:
            A model card can be loaded and saved to disk.

        Parameters:
            model_details, intended_use, factors, metrics, evaluation_data, training_data,
            quantitative_analyses, ethical_considerations, caveats_and_recommendations:
                (`optional`) the sections of the model card. Each one defaults to an empty dict.

            kwargs: (`optional`) any other keyword argument is stored on the instance as an
                attribute of the same name.
    """

    def __init__(self, **kwargs):
        # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
        self.model_details = kwargs.pop("model_details", {})
        self.intended_use = kwargs.pop("intended_use", {})
        self.factors = kwargs.pop("factors", {})
        self.metrics = kwargs.pop("metrics", {})
        self.evaluation_data = kwargs.pop("evaluation_data", {})
        self.training_data = kwargs.pop("training_data", {})
        self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
        self.ethical_considerations = kwargs.pop("ethical_considerations", {})
        self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})

        # Open additional attributes: whatever is left in kwargs is stored as-is.
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError:
                logger.error("Can't set {} with value {} for {}".format(key, value, self))
                # Bare raise re-raises the active exception with its original traceback.
                raise

    def save_pretrained(self, save_directory_or_file):
        """ Save a model card object to the directory or file `save_directory_or_file`.

        If the argument is an existing directory, the card is written inside it under the
        predefined name `MODEL_CARD_NAME`; otherwise the argument is used as the file path.
        """
        if os.path.isdir(save_directory_or_file):
            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
        else:
            output_model_card_file = save_directory_or_file

        self.to_json_file(output_model_card_file)
        logger.info("Model card saved in {}".format(output_model_card_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r""" Instantiate a :class:`~transformers.ModelCard` from a pre-trained model model card.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model card to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model card that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing a model card file saved using the :func:`~transformers.ModelCard.save_pretrained` method, e.g.: ``./my_model_directory/``.
                - a path or url to a saved model card JSON `file`, e.g.: ``./my_model_directory/modelcard.json``.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                card should be cached if the standard cache should not be used.

            kwargs: (`optional`) dict: key/value pairs with which to update the ModelCard object after loading.

                - The values in kwargs of any keys which are model card attributes will be used to override the loaded values.
                - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the `return_unused_kwargs` keyword parameter.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            find_from_standard_name: (`optional`) boolean, default True:
                If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them with our standard modelcard filename.
                Can be used to directly feed a model/config url and access the colocated modelcard.

            return_unused_kwargs: (`optional`) bool:

                - If False, then this function returns just the final model card object.
                - If True, then this function returns a tuple `(model card, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of kwargs which has not been used to update `ModelCard` and is otherwise ignored.

        Returns:
            The loaded model card, or an empty model card if the file could not be resolved,
            read or parsed as JSON. A tuple `(model card, unused_kwargs)` is returned instead
            when `return_unused_kwargs` is True.

        Examples::

            modelcard = ModelCard.from_pretrained('bert-base-uncased')    # Download model card from S3 and cache.
            modelcard = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
            modelcard = ModelCard.from_pretrained('./test/saved_model/modelcard.json')
            modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)

        """
        cache_dir = kwargs.pop("cache_dir", None)
        proxies = kwargs.pop("proxies", None)
        find_from_standard_name = kwargs.pop("find_from_standard_name", True)
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        is_known_shortcut = pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP

        if is_known_shortcut:
            # For simplicity we use the same pretrained url than the configuration files
            # but with a different suffix (modelcard.json). This suffix is replaced below.
            model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        elif os.path.isdir(pretrained_model_name_or_path):
            model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            model_card_file = pretrained_model_name_or_path
        else:
            model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False)

        if find_from_standard_name or is_known_shortcut:
            # Swap any standard config/weights filename for the model card filename
            # (same replacement order as before: config, PyTorch weights, TF2 weights).
            for standard_name in (CONFIG_NAME, WEIGHTS_NAME, TF2_WEIGHTS_NAME):
                model_card_file = model_card_file.replace(standard_name, MODEL_CARD_NAME)

        try:
            # Load from URL or cache if already cached
            resolved_model_card_file = cached_path(model_card_file, cache_dir=cache_dir, proxies=proxies)
            if resolved_model_card_file is None:
                raise EnvironmentError
            if resolved_model_card_file == model_card_file:
                logger.info("loading model card file {}".format(model_card_file))
            else:
                logger.info(
                    "loading model card file {} from cache at {}".format(model_card_file, resolved_model_card_file)
                )
            # Load model card
            modelcard = cls.from_json_file(resolved_model_card_file)

        except (EnvironmentError, json.JSONDecodeError):
            # We fall back on creating an empty model card. This is deliberate best-effort
            # behaviour, so it is only logged, never raised.
            logger.info(
                "Could not load a model card from {}; falling back to an empty model card".format(model_card_file)
            )
            modelcard = cls()

        # Update model card with kwargs if needed. Iterate over a snapshot so that consumed
        # keys can be removed from kwargs while looping.
        for key, value in list(kwargs.items()):
            if hasattr(modelcard, key):
                setattr(modelcard, key, value)
                del kwargs[key]

        logger.info("Model card: %s", str(modelcard))
        if return_unused_kwargs:
            return modelcard, kwargs
        return modelcard

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `ModelCard` from a Python dictionary of parameters."""
        return cls(**json_object)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `ModelCard` from a json file of parameters.

        The file is read as UTF-8. Errors from opening the file or from decoding the JSON
        are not caught here.
        """
        with open(json_file, "r", encoding="utf-8") as reader:
            dict_obj = json.load(reader)
        # Delegate to `from_dict` so both constructors share one code path.
        return cls.from_dict(dict_obj)

    def __eq__(self, other):
        """Two model cards are equal when all of their attributes are equal."""
        if not isinstance(other, ModelCard):
            # Let Python try the reflected comparison / fall back to identity instead of
            # failing with AttributeError on objects that have no `__dict__`.
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __repr__(self):
        """Return the JSON string representation of this model card."""
        return self.to_json_string()

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        return copy.deepcopy(self.__dict__)

    def to_json_string(self):
        """Serializes this instance to a JSON string (2-space indent, sorted keys, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

    def to_json_file(self, json_file_path):
        """ Save this instance to a json file, written as UTF-8."""
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string())
215

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.