# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import contextlib
import csv
import inspect
import logging
import os
from argparse import Namespace
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
from warnings import warn

import torch
import yaml
from lightning_utilities.core.apply_func import apply_to_collection

import lightning.pytorch as pl
from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.data import AttributeDict
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.migration import pl_legacy_patch
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.parsing import parse_class_init_keys
from lightning.pytorch.utilities.rank_zero import rank_zero_warn

if TYPE_CHECKING:
    from torch.storage import UntypedStorage

log = logging.getLogger(__name__)
# the older keys shall be listed first
CHECKPOINT_PAST_HPARAMS_KEYS = ("hparams", "module_arguments")  # used in 0.7.6


def _load_from_checkpoint(
    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
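    """Load a checkpoint file and instantiate ``cls`` from it; backs the public ``load_from_checkpoint`` methods.

    Legacy checkpoints are migrated to the current format first. Hyperparameters from ``hparams_file`` and then
    ``kwargs`` override the ones stored in the checkpoint. A restored ``LightningModule`` is moved to the device
    of the first tensor found in its state dict.
    """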
    map_location = map_location or _default_map_location
    with pl_legacy_patch():
        checkpoint = pl_load(checkpoint_path, map_location=map_location)

    # convert legacy checkpoints to the new format
    checkpoint = _pl_migrate_checkpoint(
        checkpoint, checkpoint_path=(checkpoint_path if isinstance(checkpoint_path, (str, Path)) else None)
    )

    if hparams_file is not None:
        extension = str(hparams_file).split(".")[-1]
        if extension.lower() == "csv":
            hparams = load_hparams_from_tags_csv(hparams_file)
        elif extension.lower() in ("yml", "yaml"):
            hparams = load_hparams_from_yaml(hparams_file)
        else:
            raise ValueError(".csv, .yml or .yaml is required for `hparams_file`")

        # overwrite hparams with the ones from the given file
        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams

    # TODO: make this a migration:
    # past checkpoints need the new key added
    checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {})
    # override the hparams with values that were passed in
    checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

    if issubclass(cls, pl.LightningDataModule):
        return _load_state(cls, checkpoint, **kwargs)
    if issubclass(cls, pl.LightningModule):
        model = _load_state(cls, checkpoint, strict=strict, **kwargs)
        state_dict = checkpoint["state_dict"]
        if not state_dict:
            rank_zero_warn(f"The state dict in {checkpoint_path!r} contains no parameters.")
            return model

        device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
        assert isinstance(model, pl.LightningModule)
        return model.to(device)

    raise NotImplementedError(f"Unsupported {cls}")


def _default_map_location(storage: "UntypedStorage", location: str) -> Optional["UntypedStorage"]:
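    """Map storages tagged for an accelerator that is unavailable locally (MPS, CUDA, XLA) to CPU; returning
    ``None`` keeps the default ``torch.load()`` restore location."""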
    if (
        location.startswith("mps")
        and not MPSAccelerator.is_available()
        or location.startswith("cuda")
        and not CUDAAccelerator.is_available()
        or location.startswith("xla")
        and not XLAAccelerator.is_available()
    ):
        return storage.cpu()
    return None  # default behavior by `torch.load()`


def _load_state(
    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
    checkpoint: Dict[str, Any],
    strict: Optional[bool] = None,
    **cls_kwargs_new: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
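    """Instantiate ``cls`` from the hyperparameters stored in ``checkpoint``, with ``cls_kwargs_new`` taking
    priority over the loaded ones, then restore its state: ``load_state_dict`` for a ``LightningDataModule``,
    ``on_load_checkpoint`` plus ``load_state_dict`` for a ``LightningModule``.
    """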
    cls_spec = inspect.getfullargspec(cls.__init__)
    cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()

    self_var, args_var, kwargs_var = parse_class_init_keys(cls)
    drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
    cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))

    cls_kwargs_loaded = {}
    # pass in the values we saved automatically
    if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
        if issubclass(cls, pl.LightningModule):
            # TODO: make this a migration:
            # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
            for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
                cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))

        # 2. Try to restore model hparams from checkpoint using the new key
        cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))

        # 3. Ensure that `cls_kwargs_loaded` has the right type, for backward compatibility between dict and Namespace
        cls_kwargs_loaded = _convert_loaded_hparams(
            cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE)
        )

        # 4. Update `cls_kwargs_new` with `cls_kwargs_loaded`, such that new has higher priority
        args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
        if args_name and args_name in cls_init_args_name:
            cls_kwargs_loaded = {args_name: cls_kwargs_loaded}

    _cls_kwargs = {}
    _cls_kwargs.update(cls_kwargs_loaded)
    _cls_kwargs.update(cls_kwargs_new)

    instantiator = None
    instantiator_path = _cls_kwargs.pop("_instantiator", None)
    if instantiator_path is not None:
        # import the custom instantiator from its dotted path
        module_path, name = instantiator_path.rsplit(".", 1)
        instantiator = getattr(__import__(module_path, fromlist=[name]), name)

    if not cls_spec.varkw:
        # filter kwargs according to the class init, unless it allows any argument via kwargs
        _cls_kwargs = {k: v for k, v in _cls_kwargs.items() if k in cls_init_args_name}

    obj = instantiator(cls, _cls_kwargs) if instantiator else cls(**_cls_kwargs)

    if isinstance(obj, pl.LightningDataModule):
        if obj.__class__.__qualname__ in checkpoint:
            obj.load_state_dict(checkpoint[obj.__class__.__qualname__])
        return obj

    if isinstance(obj, pl.LightningModule):
        if obj._strict_loading is not None and strict is not None and strict != obj.strict_loading:
            raise ValueError(
                f"You set `.load_from_checkpoint(..., strict={strict!r})` which is in conflict with"
                f" `{cls.__name__}.strict_loading={obj.strict_loading!r}`. Please set the same value for both of them."
            )
        strict = obj.strict_loading if strict is None else strict

        if is_overridden("configure_model", obj):
            obj.configure_model()

        # give the model a chance to load something
        obj.on_load_checkpoint(checkpoint)

        # load the state_dict on the model automatically
        keys = obj.load_state_dict(checkpoint["state_dict"], strict=strict)

        if not strict:
            if keys.missing_keys:
                rank_zero_warn(
                    f"Found keys that are in the model state dict but not in the checkpoint: {keys.missing_keys}"
                )
            if keys.unexpected_keys:
                rank_zero_warn(
                    f"Found keys that are not in the model state dict but in the checkpoint: {keys.unexpected_keys}"
                )

    return obj


def _convert_loaded_hparams(
    model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None
) -> Dict[str, Any]:
    """Convert hparams according to the given type, passed either as a callable or (for past checkpoints) as a
    string."""
    # if no hparams type is defined, keep the args as they are
    if not hparams_type:
        return model_args
    # if a past checkpoint was loaded, convert the str type to a callable
    if isinstance(hparams_type, str):
        hparams_type = AttributeDict
    # convert hparams
    return hparams_type(model_args)


def update_hparams(hparams: dict, updates: dict) -> None:
    """Overrides hparams with new values.

    >>> hparams = {'c': 4}
    >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1})
    >>> hparams['a']['b'], hparams['c']
    (2, 1)
    >>> update_hparams(hparams, {'a': {'b': 4}, 'c': 7})
    >>> hparams['a']['b'], hparams['c']
    (4, 7)

    Args:
        hparams: the original params, which are also the target object; updated in place
        updates: new params to be used as the update

    """
    for k, v in updates.items():
        # if missing, add the key
        if k not in hparams:
            hparams[k] = v
            continue

        # recurse if the value is a dictionary
        if isinstance(v, dict):
            update_hparams(hparams[k], updates[k])
        else:
            # update the value
            hparams.update({k: v})


def load_hparams_from_tags_csv(tags_csv: _PATH) -> Dict[str, Any]:
    """Load hparams from a CSV file of tags.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = os.path.join('.', 'testing-hparams.csv')
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)

    """
    fs = get_filesystem(tags_csv)
    if not fs.exists(tags_csv):
        rank_zero_warn(f"Missing Tags: {tags_csv}.", category=RuntimeWarning)
        return {}

    with fs.open(tags_csv, "r", newline="") as fp:
        csv_reader = csv.reader(fp, delimiter=",")
        return {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}


def save_hparams_to_tags_csv(tags_csv: _PATH, hparams: Union[dict, Namespace]) -> None:
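    """Save hparams as ``key,value`` rows in a CSV file with a header row; see the doctest of
    :func:`load_hparams_from_tags_csv` for a round-trip example."""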
    fs = get_filesystem(tags_csv)
    if not _is_dir(fs, os.path.dirname(tags_csv)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(tags_csv)}.")

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with fs.open(tags_csv, "w", newline="") as fp:
        fieldnames = ["key", "value"]
        writer = csv.DictWriter(fp, fieldnames=fieldnames)
        writer.writerow({"key": "key", "value": "value"})
        for k, v in hparams.items():
            writer.writerow({"key": k, "value": v})


def load_hparams_from_yaml(config_yaml: _PATH, use_omegaconf: bool = True) -> Dict[str, Any]:
    """Load hparams from a YAML file.

    Args:
        config_yaml: Path to config yaml file
        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
            the hparams will be converted to ``DictConfig`` if possible.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)

    """
    fs = get_filesystem(config_yaml)
    if not fs.exists(config_yaml):
        rank_zero_warn(f"Missing Tags: {config_yaml}.", category=RuntimeWarning)
        return {}

    with fs.open(config_yaml, "r") as fp:
        hparams = yaml.full_load(fp)

    if _OMEGACONF_AVAILABLE and use_omegaconf:
        from omegaconf import OmegaConf
        from omegaconf.errors import UnsupportedValueType, ValidationError

        with contextlib.suppress(UnsupportedValueType, ValidationError):
            return OmegaConf.create(hparams)
    return hparams


def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
    """Save hparams to a YAML file.

    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved
        use_omegaconf: If omegaconf is available and ``use_omegaconf=True``,
            the hparams will be converted to ``DictConfig`` if possible.

    """
    fs = get_filesystem(config_yaml)
    if not _is_dir(fs, os.path.dirname(config_yaml)):
        raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.")

    # convert Namespace or AttributeDict to dict
    if isinstance(hparams, Namespace):
        hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        hparams = dict(hparams)

    # saving with OmegaConf objects
    if _OMEGACONF_AVAILABLE and use_omegaconf:
        from omegaconf import OmegaConf
        from omegaconf.dictconfig import DictConfig
        from omegaconf.errors import UnsupportedValueType, ValidationError

        # deepcopy: hparams from the user shouldn't be resolved
        hparams = deepcopy(hparams)
        hparams = apply_to_collection(hparams, DictConfig, OmegaConf.to_container, resolve=True)
        with fs.open(config_yaml, "w", encoding="utf-8") as fp:
            try:
                OmegaConf.save(hparams, fp)
                return
            except (UnsupportedValueType, ValidationError):
                pass

    if not isinstance(hparams, dict):
        raise TypeError("hparams must be a dictionary")

    hparams_allowed = {}
    # drop parameters with datatypes that cannot be dumped to YAML, e.g. fsspec objects
    for k, v in hparams.items():
        try:
            v = v.name if isinstance(v, Enum) else v
            yaml.dump(v)
        except TypeError:
            warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
            hparams[k] = type(v).__name__
        else:
            hparams_allowed[k] = v

    # saving the standard way
    with fs.open(config_yaml, "w", newline="") as fp:
        yaml.dump(hparams_allowed, fp)


def convert(val: str) -> Union[int, float, bool, str]:
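    """Parse a string into a Python literal if possible (used e.g. when reading values back from the tags CSV),
    falling back to the raw string.

    >>> convert('42')
    42
    >>> convert('True')
    True
    >>> convert('not a literal')
    'not a literal'

    """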
    try:
        return ast.literal_eval(val)
    except (ValueError, SyntaxError) as err:
        log.debug(err)
        return val