OpenDelta

auto_delta.py · 388 lines · 15.6 KB

from copy import deepcopy
from typing import Any, Dict, OrderedDict
import importlib

import torch.nn as nn
from bigmodelvis import Visualization

from opendelta.utils.logging import get_logger
from opendelta.delta_configs import BaseDeltaConfig
from opendelta.basemodel import DeltaBase

logger = get_logger(__name__)


DELTA_CONFIG_MAPPING = {
    "lora": "LoraConfig",
    "low_rank_adapter": "LowRankAdapterConfig",
    "bitfit": "BitFitConfig",
    "adapter": "AdapterConfig",
    "compacter": "CompacterConfig",
    "prefix": "PrefixConfig",
    "soft_prompt": "SoftPromptConfig",
    "parallel_adapter": "ParallelAdapterConfig",
}

DELTA_MODEL_MAPPING = {
    "lora": "LoraModel",
    "low_rank_adapter": "LowRankAdapterModel",
    "bitfit": "BitFitModel",
    "adapter": "AdapterModel",
    "compacter": "CompacterModel",
    "prefix": "PrefixModel",
    "soft_prompt": "SoftPromptModel",
    "parallel_adapter": "ParallelAdapterModel",
}


class _LazyConfigMapping(OrderedDict):
    """
    A dictionary that lazily loads its values when they are requested.
    """

    def __init__(self, mapping):
        self._mapping = mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key):
        if key in self._extra_content:
            return self._extra_content[key]
        if key not in self._mapping:
            raise KeyError(key)
        value = self._mapping[key]
        # The delta type doubles as the module name under opendelta.delta_models.
        module_name = key
        self._modules[module_name] = importlib.import_module(f".{module_name}", "opendelta.delta_models")
        return getattr(self._modules[module_name], value)

    def keys(self):
        return list(self._mapping.keys()) + list(self._extra_content.keys())

    def values(self):
        return [self[k] for k in self._mapping.keys()] + list(self._extra_content.values())

    def items(self):
        return [(k, self[k]) for k in self._mapping.keys()] + list(self._extra_content.items())

    def __iter__(self):
        return iter(list(self._mapping.keys()) + list(self._extra_content.keys()))

    def __contains__(self, item):
        return item in self._mapping or item in self._extra_content

    def register(self, key, value):
        """
        Register a new configuration in this mapping.
        """
        if key in self._mapping.keys():
            raise ValueError(f"'{key}' is already used by a built-in delta config, pick another name.")
        self._extra_content[key] = value


LAZY_CONFIG_MAPPING = _LazyConfigMapping(DELTA_CONFIG_MAPPING)
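
# Usage sketch (illustrative, not part of the module's public API): values are
# imported only on first access, and `register` can add a delta type that is
# not shipped with the library (assuming a user-defined `MyDeltaConfig`):
#
#   >>> LAZY_CONFIG_MAPPING["lora"]   # first access imports opendelta.delta_models.lora
#   <class 'opendelta.delta_models.lora.LoraConfig'>
#   >>> LAZY_CONFIG_MAPPING.register("my_delta", MyDeltaConfig)
#   >>> "my_delta" in LAZY_CONFIG_MAPPING
#   True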


class AutoDeltaConfig:
    r"""
    This is a generic configuration class that will be instantiated as one of the configuration classes of the library
    when created with the :meth:`~AutoDeltaConfig.from_finetuned` or :meth:`~AutoDeltaConfig.from_dict` class method.
    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self, *args, **kwargs):
        raise AttributeError(
            f"{self.__class__.__name__} is designed to be instantiated using\n\t(1) `{self.__class__.__name__}.from_finetuned(finetuned_model_name_or_path)`\nor\t(2) `{self.__class__.__name__}.from_dict(config_dict, **kwargs)`"
        )

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        r"""Instantiate a DeltaConfig from a dict. The concrete config class is selected automatically
        according to the :obj:`delta_type` key.

        Args:
            config_dict (:obj:`dict`): The dict of configs of the delta model.
            kwargs: Other keyword arguments passed to initialize the config.

        Examples:

        .. code-block:: python

            config = AutoDeltaConfig.from_dict({"delta_type":"lora"}) # This will load the default lora config.
            config = AutoDeltaConfig.from_dict({"delta_type":"lora", "lora_r":5}) # Will load the default lora config, with lora_r = 5

        """
        config_dict = deepcopy(config_dict)
        delta_type = config_dict.pop("delta_type", None)
        if delta_type is None:
            raise RuntimeError("No `delta_type` was specified in the config dict, so the config class cannot be determined.")
        config_class = LAZY_CONFIG_MAPPING[delta_type]
        return config_class.from_dict(config_dict, **kwargs)

    @classmethod
    def from_finetuned(cls, finetuned_delta_path, **kwargs):
        r"""
        Instantiate one of the configuration classes of the library from a finetuned delta model configuration.
        The configuration class to instantiate is selected based on the ``delta_type`` property of the config object
        that is loaded.

        Parameters:

            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`, *optional*): Can be either:

                - A string, the model id of a finetuned delta model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like ``Davin/lora``, or namespaced under a user or organization name, like ``DeltaHub/lora_t5-base_mrpc``.
                - A path to a *directory* containing a configuration file saved using the :py:meth:`~opendelta.basemodel.DeltaBase.save_finetuned` method, e.g., ``./my_model_directory/``.
                - A path or url to a saved configuration JSON *file*, e.g., ``./my_model_directory/configuration.json``.

            cache_dir (:obj:`str` or :obj:`os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.

        Examples:

        .. code-block:: python

            from opendelta import AutoDeltaConfig
            delta_config = AutoDeltaConfig.from_finetuned("thunlp/FactQA_T5-large_Adapter")

        """
        config_dict, kwargs = BaseDeltaConfig.get_config_dict(finetuned_delta_path, **kwargs)
        if "delta_type" in config_dict:
            config_class = LAZY_CONFIG_MAPPING[config_dict["delta_type"]]
            return config_class.from_dict(config_dict, **kwargs)
        else:
            # Fallback: match a known delta type against the path or model id itself.
            for pattern, config_class in LAZY_CONFIG_MAPPING.items():
                if pattern in str(finetuned_delta_path):
                    return config_class.from_dict(config_dict, **kwargs)

        raise ValueError(
            f"Unrecognized model in {finetuned_delta_path}. "
            f"Should have a `delta_type` key in the loaded config, or contain one of the following strings "
            f"in its name: {', '.join(LAZY_CONFIG_MAPPING.keys())}"
        )
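
# Illustrative note on the name-based fallback above: when the saved config has
# no `delta_type` key, the path or model id itself is scanned for a known type
# name. For example, the repo id ``DeltaHub/lora_t5-base_mrpc`` contains the
# substring "lora", so `LoraConfig` would be selected for such a repo:
#
#   >>> delta_config = AutoDeltaConfig.from_finetuned("DeltaHub/lora_t5-base_mrpc")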


### AutoModels below

class _LazyAutoMapping(OrderedDict):
    """
    A mapping from config classes to objects (a model or tokenizer class, for instance) that loads its keys and
    values lazily when accessed.

    Args:

        - config_mapping: The mapping from model type to config class name
        - model_mapping: The mapping from model type to model (or tokenizer) class name
    """

    def __init__(self, config_mapping, model_mapping):
        self._config_mapping = config_mapping
        self._reverse_config_mapping = {v: k for k, v in config_mapping.items()}
        self._model_mapping = model_mapping
        self._extra_content = {}
        self._modules = {}

    def __getitem__(self, key):
        if key in self._extra_content:
            return self._extra_content[key]
        model_type = self._reverse_config_mapping[key.__name__]
        if model_type not in self._model_mapping:
            raise KeyError(key)
        model_name = self._model_mapping[model_type]
        return self._load_attr_from_module(model_type, model_name)

    def _load_attr_from_module(self, model_type, attr):
        if model_type not in self._modules:
            self._modules[model_type] = importlib.import_module(f".{model_type}", "opendelta.delta_models")
        return getattribute_from_module(self._modules[model_type], attr)

    def keys(self):
        mapping_keys = [
            self._load_attr_from_module(key, name)
            for key, name in self._config_mapping.items()
            if key in self._model_mapping.keys()
        ]
        return mapping_keys + list(self._extra_content.keys())

    def get(self, key, default):
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __bool__(self):
        return bool(self.keys())

    def values(self):
        mapping_values = [
            self._load_attr_from_module(key, name)
            for key, name in self._model_mapping.items()
            if key in self._config_mapping.keys()
        ]
        return mapping_values + list(self._extra_content.values())

    def items(self):
        mapping_items = [
            (
                self._load_attr_from_module(key, self._config_mapping[key]),
                self._load_attr_from_module(key, self._model_mapping[key]),
            )
            for key in self._model_mapping.keys()
            if key in self._config_mapping.keys()
        ]
        return mapping_items + list(self._extra_content.items())

    def __iter__(self):
        return iter(self.keys())

    def __contains__(self, item):
        if item in self._extra_content:
            return True
        if not hasattr(item, "__name__") or item.__name__ not in self._reverse_config_mapping:
            return False
        model_type = self._reverse_config_mapping[item.__name__]
        return model_type in self._model_mapping

    def register(self, key, value):
        """
        Register a new model in this mapping.
        """
        if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
            model_type = self._reverse_config_mapping[key.__name__]
            if model_type in self._model_mapping.keys():
                raise ValueError(f"'{key}' is already used by a built-in delta model.")

        self._extra_content[key] = value


LAZY_DELTA_MAPPING = _LazyAutoMapping(DELTA_CONFIG_MAPPING, DELTA_MODEL_MAPPING)
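
# Usage sketch (illustrative): unlike LAZY_CONFIG_MAPPING, this mapping is keyed
# by the config *class* rather than the delta-type string, and it resolves to
# the matching delta model class on access:
#
#   >>> from opendelta.delta_models.lora import LoraConfig
#   >>> LAZY_DELTA_MAPPING[LoraConfig]
#   <class 'opendelta.delta_models.lora.LoraModel'>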


def get_values(model_mapping):
    # Flatten a mapping's values into a single list, expanding lists and tuples.
    result = []
    for model in model_mapping.values():
        if isinstance(model, (list, tuple)):
            result += list(model)
        else:
            result.append(model)

    return result



def getattribute_from_module(module, attr):
    if attr is None:
        return None
    if isinstance(attr, tuple):
        return tuple(getattribute_from_module(module, a) for a in attr)
    if hasattr(module, attr):
        return getattr(module, attr)
    # Some of the mappings have entries model_type -> class of another model type. In that case we try to grab the
    # object at the top level of the `transformers` package, guarding against infinite recursion.
    transformers_module = importlib.import_module("transformers")
    if module != transformers_module:
        return getattribute_from_module(transformers_module, attr)
    raise ValueError(f"Could not find {attr} in {transformers_module}!")
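
# Illustrative behavior (names follow the mappings above):
#
#   >>> lora_mod = importlib.import_module(".lora", "opendelta.delta_models")
#   >>> getattribute_from_module(lora_mod, "LoraModel")           # found on the module itself
#   >>> getattribute_from_module(lora_mod, ("LoraModel", None))   # tuples resolve element-wise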


class AutoDeltaModel:
    r"""
    This is a generic delta model class that will be instantiated as one of the delta model classes of the library
    when created with the :meth:`~AutoDeltaModel.from_finetuned` or :meth:`~AutoDeltaModel.from_config` class method.
    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """
    _delta_model_mapping = LAZY_DELTA_MAPPING

    def __init__(self, *args, **kwargs):
        raise AttributeError(
            f"{self.__class__.__name__} is designed to be instantiated using\n\t(1) `{self.__class__.__name__}.from_finetuned(finetuned_delta_path, backbone_model, *model_args, **kwargs)`\nor\t(2) `{self.__class__.__name__}.from_config(delta_config, backbone_model, **kwargs)`"
        )

    @classmethod
    def from_config(cls, config, backbone_model, **kwargs) -> DeltaBase:
        r"""Automatically instantiate a delta model based on the :obj:`config`. The delta model corresponding to
        :obj:`config` will be loaded and initialized using the arguments in :obj:`config`.

        .. note::
            The :meth:`from_config` method does not load finetuned weight files (e.g., pytorch_model.bin).
            To load them, please use :meth:`from_finetuned` directly.

        Args:
            config (:obj:`BaseDeltaConfig`): The config object of the delta model.
            backbone_model (:obj:`nn.Module`): The backbone model to be modified.

        Examples:

        .. code-block:: python

            config = AutoDeltaConfig.from_finetuned("DeltaHub/lora_t5-base_mrpc")
            delta_model = AutoDeltaModel.from_config(config, backbone_model)

        """
        if type(config) in cls._delta_model_mapping.keys():
            model_class = cls._delta_model_mapping[type(config)]
            return model_class.from_config(config, backbone_model, **kwargs)

        raise ValueError(
            f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._delta_model_mapping.keys())}."
        )
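
    # Note on dispatch (illustrative): the lookup above uses `type(config)`, so
    # the config must be exactly one of the mapped config classes:
    #
    #   >>> config = AutoDeltaConfig.from_dict({"delta_type": "lora"})  # a LoraConfig
    #   >>> delta_model = AutoDeltaModel.from_config(config, backbone_model)  # -> LoraModel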

    @classmethod
    def from_finetuned(cls, finetuned_delta_path, backbone_model, *model_args, **kwargs) -> DeltaBase:
        r"""Automatically instantiate a delta model and load the finetuned checkpoints based on
        :obj:`finetuned_delta_path`, which can either be a string pointing to a local path or a url pointing to
        the delta hub. It will check the hash after loading the delta model to see whether the correct backbone and
        delta checkpoint are used.

        Args:
            finetuned_delta_path (:obj:`str` or :obj:`os.PathLike`, *optional*): Can be either:

                - A string, the model name of a finetuned delta model configuration hosted inside a model repo on `Delta Center <https://www.openbmb.org/toolKits/deltacenter>`_, like ``thunlp/FactQA_T5-large_Adapter``.
                - A path to a directory containing a configuration file saved using the :meth:`~opendelta.utils.saving_loading_utils.SaveLoadMixin.save_finetuned` method, e.g., ``./my_model_directory/``.
                - A path or url to a saved configuration JSON *file*, e.g., ``./my_model_directory/configuration.json``. The last two options are not tested but inherited from huggingface.

            backbone_model (:obj:`nn.Module`): The backbone model to be modified.
            model_args: Other arguments for initializing the model. See :meth:`DeltaBase.from_finetuned` for details.
            kwargs: Other kwargs that will be passed into :meth:`DeltaBase.from_finetuned`. See :meth:`DeltaBase.from_finetuned` for details.

        Example:

        .. code-block:: python

            delta_model = AutoDeltaModel.from_finetuned("thunlp/FactQA_T5-large_Adapter", backbone_model=backbone_model)

        """
        delta_config = kwargs.pop("delta_config", None)

        if not isinstance(delta_config, BaseDeltaConfig):
            delta_config, kwargs = AutoDeltaConfig.from_finetuned(
                finetuned_delta_path, return_unused_kwargs=True, **kwargs
            )
        if type(delta_config) in cls._delta_model_mapping.keys():
            model_class = cls._delta_model_mapping[type(delta_config)]
            return model_class.from_finetuned(finetuned_delta_path, backbone_model, *model_args, delta_config=delta_config, **kwargs)
        raise ValueError(
            f"Unrecognized configuration class {delta_config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
            f"Model type should be one of {', '.join(c.__name__ for c in cls._delta_model_mapping.keys())}."
        )
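
# Sketch of bypassing config resolution (illustrative): passing an explicit
# `delta_config` kwarg skips the AutoDeltaConfig.from_finetuned call above.
#
#   >>> cfg = AutoDeltaConfig.from_dict({"delta_type": "lora", "lora_r": 8})
#   >>> delta_model = AutoDeltaModel.from_finetuned(
#   ...     "thunlp/FactQA_T5-large_Adapter", backbone_model, delta_config=cfg)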


if __name__ == "__main__":
    # Quick smoke test: build a default LoRA config (with a custom rank) and
    # attach it to a RoBERTa classifier loaded from a local cache.
    config = AutoDeltaConfig.from_dict({"delta_type": "lora", "lora_r": 7})

    from transformers import AutoModelForSequenceClassification
    model = AutoModelForSequenceClassification.from_pretrained("../../plm_cache/roberta-base/", num_labels=2)
    delta_model = AutoDeltaModel.from_config(config, model)
    delta_model.freeze_module(exclude=['deltas', 'classifier'], set_state_dict=True)

    # delta_model.save_finetuned("autodelta_try", push_to_hub=True, private=True)
    delta_model = AutoDeltaModel.from_finetuned("ShengdingHu/autodelta_try", model, use_auth_token=True)