pytorch-lightning

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import contextlib
import csv
import inspect
import logging
import os
from argparse import Namespace
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
from warnings import warn

import torch
import yaml
from lightning_utilities.core.apply_func import apply_to_collection

import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.data import AttributeDict
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.migration import pl_legacy_patch
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.parsing import parse_class_init_keys
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

if TYPE_CHECKING:
    from torch.storage import UntypedStorage

log = logging.getLogger(__name__)
# legacy hparams keys, listed oldest first
CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments")  # used in 0.7.6


def _load_from_checkpoint(
    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
    map_location = map_location or _default_map_location
    with pl_legacy_patch():
        checkpoint = pl_load(checkpoint_path, map_location=map_location)

    # convert legacy checkpoints to the new format
    checkpoint = _pl_migrate_checkpoint(
        checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
    )

    if hparams_file is not None:
        extension = str(hparams_file).split(".")[-1]
        if extension.lower() == "csv":
            hparams = load_hparams_from_tags_csv(hparams_file)
        elif extension.lower() in ("yml", "yaml"):
            hparams = load_hparams_from_yaml(hparams_file)
        else:
            raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

        # overwrite hparams with the ones parsed from the given file
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

    # TODO: make this a migration:
    # for past checkpoints the new key needs to be added
    checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
    # override the hparams with values that were passed in
    checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

    if issubclass(cls, pl.LightningDataModule):
        return _load_state(cls, checkpoint, **kwargs)
    if issubclass(cls, pl.LightningModule):
        model = _load_state(cls, checkpoint, strict=strict, **kwargs)
        state_dict = checkpoint["state_dict"]
        if not state_dict:
            rank_zero_warn(f"The state dict in {checkpoint_path!r} contains no parameters.")
            return model

        device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
        assert isinstance(model, pl.LightningModule)
        return model.to(device)

    raise NotImplementedError(f"Unsupported {cls}")
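
# A minimal usage sketch of the public classmethods that delegate here
# (`LightningModule.load_from_checkpoint` / `LightningDataModule.load_from_checkpoint`);
# `MyModel` and the checkpoint path are hypothetical placeholders:
#
#     model = MyModel.load_from_checkpoint(
#         "path/to/checkpoint.ckpt",
#         map_location="cpu",   # forwarded to `torch.load`
#         strict=False,         # tolerate missing/unexpected keys in the state dict
#         learning_rate=1e-4,   # extra kwargs override hparams stored in the checkpoint
#     )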


def _default_map_location(storage: "UntypedStorage", location: str) -> Optional["UntypedStorage"]:
    if (
        location.startswith("mps")
        and not MPSAccelerator.is_available()
        or location.startswith("cuda")
        and not CUDAAccelerator.is_available()
        or location.startswith("xla")
        and not XLAAccelerator.is_available()
    ):
        return storage.cpu()
    return None  # default behavior by `torch.load()`
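
# `_default_map_location` follows the callable form of `map_location` that `torch.load` accepts
# (it is invoked once per serialized storage with the storage and its location tag). A hedged
# sketch, assuming a checkpoint produced on a CUDA machine is opened on a CPU-only machine
# (the path is a placeholder):
#
#     checkpoint = torch.load("path/to/checkpoint.ckpt", map_location=_default_map_location)
#     # storages tagged "cuda:..." are remapped to CPU because CUDAAccelerator.is_available() is False;
#     # on a machine with CUDA available, None is returned and torch.load keeps its default behavior.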


def _load_state(
    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
    checkpoint: Dict[str, Any],
    strict: Optional[bool] = None,
    **cls_kwargs_new: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
    cls_spec = inspect.getfullargspec(cls.__init__)
    cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()

    self_var, args_var, kwargs_var = parse_class_init_keys(cls)
    drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
    cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))

    cls_kwargs_loaded = {}
    # pass in the values we saved automatically
    if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
        if issubclass(cls, pl.LightningModule):
            # TODO: make this a migration:
            # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
            for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))

        # 2. Try to restore model hparams from checkpoint using the new key
        cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))

        # 3. Ensure that `cls_kwargs_loaded` has the right type, back compatibility between dict and Namespace
        cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))

        # 4. Update cls_kwargs_new with cls_kwargs_loaded, such that new has higher priority
        args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
        if args_name and args_name in cls_init_args_name:
            cls_kwargs_loaded = {args_name: cls_kwargs_loaded}

    _cls_kwargs = {}
    _cls_kwargs.update(cls_kwargs_loaded)
    _cls_kwargs.update(cls_kwargs_new)

    instantiator = None
    instantiator_path = _cls_kwargs.pop("_instantiator", None)
    if instantiator_path is not None:
        # import custom instantiator
        module_path, name = instantiator_path.rsplit(".", 1)
        instantiator = getattr(__import__(module_path, fromlist=[name]), name)

    if not cls_spec.varkw:
        # filter kwargs according to class init unless it allows any argument via kwargs
        _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

    obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)

    if isinstance(obj, pl.LightningDataModule):
        if obj.__class__.__qualname__ in checkpoint:
            obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
        return obj

    if isinstance(obj, pl.LightningModule):
        if obj._strict_loading is not None and strict is not None and strict != obj.strict_loading:
            raise ValueError(
                f"You set `.load_from_checkpoint(..., strict={strict!r})` which is in conflict with"
                f" `{cls.__name__}.strict_loading={obj.strict_loading!r}`. Please set the same value for both of them."
            )
        strict = obj.strict_loading if strict is None else strict

        if is_overridden("configure_model", obj):
            obj.configure_model()

        # give model a chance to load something
        obj.on_load_checkpoint(checkpoint)

    # load the state_dict on the model automatically
    keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)

    if not strict:
        if keys.missing_keys:
            rank_zero_warn(
                f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
            )
        if keys.unexpected_keys:
            rank_zero_warn(
                f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
            )

    return obj
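
# Hedged illustration of the precedence implemented above: hparams stored in the checkpoint are
# applied first and kwargs passed at load time win. With a hypothetical `MyModel(lr, hidden)` and a
# checkpoint whose CHECKPOINT_HYPER_PARAMS_KEY entry is {"lr": 0.1, "hidden": 64},
#
#     model = _load_state(MyModel, checkpoint, lr=0.01)
#
# instantiates `MyModel(lr=0.01, hidden=64)` (assuming both names appear in `MyModel.__init__` and
# the class does not accept **kwargs) and then restores its weights from checkpoint["state_dict"].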


def _convert_loaded_hparams(
    model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None
) -> Dict[str, Any]:
    """Convert hparams according to the given type, passed either as a callable or (for past checkpoints) a string."""
    # if the hparams type is not defined, keep the plain dict
    if not hparams_type:
        return model_args
    # if a past checkpoint was loaded, the type is a string; fall back to AttributeDict
    if isinstance(hparams_type, str):
        hparams_type = AttributeDict
    # convert hparams
    return hparams_type(model_args)
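
# A small sketch of the legacy-type handling above: older checkpoints stored the hparams container
# type as a string (any string value triggers the fallback), in which case the loaded dict is
# wrapped in ``AttributeDict``:
#
#     _convert_loaded_hparams({"lr": 0.1})                 # -> {"lr": 0.1}
#     _convert_loaded_hparams({"lr": 0.1}, "Namespace")    # -> AttributeDict with the same items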


def update_hparams(hparams: dict, updates: dict) -> None:
    """Overrides hparams with new values.

    >>> hparams = {'c': 4}
    >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
    >>> hparams['a']['b'], hparams['c']
    (2, 1)
    >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
    >>> hparams['a']['b'], hparams['c']
    (4, 7)

    Args:
        hparams: the original params and also target object
        updates: new params to be used as update

    """
    for k, v in updates.items():
        # if missing, add the key
        if k not in hparams:
            hparams[k] = v
            continue

        # recurse if dictionary
        if isinstance(v, dict):
            update_hparams(hparams[k], updates[k])
        else:
            # update the value
            hparams.update({k: v})


def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = os.path.join('.', 'testing-hparams.csv')
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)

    """
    fs = get_filesystem(tags_csv)
    if not fs.exists(tags_csv):
        rank_zero_warn(f"Missing Tags: {tags_csv}.", category=RuntimeWarning)
        return {}

    with fs.open(tags_csv, "r", newline="") as fp:
        csv_reader = csv.reader(fp, delimiter=",")
        return {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}


def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) -> None:
    fs = get_filesystem(tags_csv)
    if not _is_dir(fs, os.path.dirname(tags_csv)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with fs.open(tags_csv, "w", newline="") as fp:
        fieldnames = ["key", "value"]
        writer = csv.DictWriter(fp, fieldnames=fieldnames)
        writer.writerow({"key": "key", "value": "value"})
        for k, v in hparams.items():
            writer.writerow({"key": k, "value": v})


def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]:
    """Load hparams from a file.

    Args:
        config_yaml: Path to config yaml file
        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
            the hparams will be converted to ``DictConfig`` if possible.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)

    """
    fs = get_filesystem(config_yaml)
    if not fs.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", category=RuntimeWarning)
        return {}

    with fs.open(config_yaml, "r") as fp:
        hparams = yaml.full_load(fp)

    if _OMEGACONF_AVAILABLE and use_omegaconf:
        from omegaconf import OmegaConf
        from omegaconf.errors import UnsupportedValueType, ValidationError

        with contextlib.suppress(UnsupportedValueType, ValidationError):
            return OmegaConf.create(hparams)
    return hparams


def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
    """
    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved
        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
            the hparams will be converted to ``DictConfig`` if possible.

    """
    fs = get_filesystem(config_yaml)
    if not _is_dir(fs, os.path.dirname(config_yaml)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

    # convert Namespace or AttributeDict to dict
    if isinstance(hparams, Namespace):
        hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        hparams = dict(hparams)

    # saving with OmegaConf objects
    if _OMEGACONF_AVAILABLE and use_omegaconf:
        from omegaconf import OmegaConf
        from omegaconf.dictconfig import DictConfig
        from omegaconf.errors import UnsupportedValueType, ValidationError

        # deepcopy: hparams from user shouldn't be resolved
        hparams = deepcopy(hparams)
        hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)
        with fs.open(config_yaml, "w", encoding="utf-8") as fp:
            try:
                OmegaConf.save(hparams, fp)
                return
            except (UnsupportedValueType, ValidationError):
                pass

    if not isinstance(hparams, dict):
        raise TypeError("hparams must be dictionary")

    hparams_allowed = {}
    # drop parameters whose values cannot be safely dumped to YAML (e.g. fsspec objects)
    for k, v in hparams.items():
        try:
            v = v.name if isinstance(v, Enum) else v
            yaml.dump(v)
        except TypeError:
            warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
            hparams[k] = type(v).__name__
        else:
            hparams_allowed[k] = v

    # saving the standard way
    with fs.open(config_yaml, "w", newline="") as fp:
        yaml.dump(hparams_allowed, fp)
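
# A hedged round-trip sketch of the Enum handling above (``Color`` and the file name are
# hypothetical placeholders); on the plain-YAML path, Enum members are stored by name and come
# back as strings:
#
#     class Color(Enum):
#         RED = 0
#
#     save_hparams_to_yaml("hparams.yaml", {"lr": 0.01, "color": Color.RED}, use_omegaconf=False)
#     load_hparams_from_yaml("hparams.yaml", use_omegaconf=False)   # -> {"lr": 0.01, "color": "RED"}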


def convert(val: str) -> Union[int, float, bool, str]:
    try:
        return ast.literal_eval(val)
    except (ValueError, SyntaxError) as err:
        log.debug(err)
        return val
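
# `convert` relies on `ast.literal_eval` and falls back to the raw string for anything that is not
# a Python literal, e.g.:
#
#     convert("32")       # -> 32
#     convert("0.001")    # -> 0.001
#     convert("True")     # -> True
#     convert("./data")   # -> "./data" (literal_eval raises, the original string is returned)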