# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LightningDataModule for loading DataLoaders with ease."""

import inspect
from typing import IO, Any, Dict, Iterable, Optional, Union, cast

from lightning_utilities import apply_to_collection
from torch.utils.data import DataLoader, Dataset, IterableDataset
from typing_extensions import Self

import lightning.pytorch as pl
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.pytorch.core.hooks import DataHooks
from lightning.pytorch.core.mixins import HyperparametersMixin
from lightning.pytorch.core.saving import _load_from_checkpoint
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS


class LightningDataModule(DataHooks, HyperparametersMixin):
    """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage
    is consistent data splits, data preparation and transforms across models.

    Example::

        import lightning as L
        import torch
        import torch.utils.data as data
        from lightning.pytorch.demos.boring_classes import RandomDataset

        class MyDataModule(L.LightningDataModule):
            def prepare_data(self):
                # download, IO, etc. Useful with shared filesystems
                # only called on 1 GPU/TPU in distributed
                ...

            def setup(self, stage):
                # make assignments here (val/train/test split)
                # called on every process in DDP
                dataset = RandomDataset(1, 100)
                self.train, self.val, self.test = data.random_split(
                    dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
                )

            def train_dataloader(self):
                return data.DataLoader(self.train)

            def val_dataloader(self):
                return data.DataLoader(self.val)

            def test_dataloader(self):
                return data.DataLoader(self.test)

            def teardown(self, stage):
                # clean up state after the trainer stops, delete files...
                # called on every process in DDP
                ...

    """

    name: Optional[str] = None
    CHECKPOINT_HYPER_PARAMS_KEY = "datamodule_hyper_parameters"
    CHECKPOINT_HYPER_PARAMS_NAME = "datamodule_hparams_name"
    CHECKPOINT_HYPER_PARAMS_TYPE = "datamodule_hparams_type"

    def __init__(self) -> None:
        super().__init__()
        # Pointer to the trainer object
        self.trainer: Optional["pl.Trainer"] = None

    @classmethod
    def from_datasets(
        cls,
        train_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None,
        val_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None,
        test_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None,
        predict_dataset: Optional[Union[Dataset, Iterable[Dataset]]] = None,
        batch_size: int = 1,
        num_workers: int = 0,
        **datamodule_kwargs: Any,
    ) -> "LightningDataModule":
        r"""Create an instance from torch.utils.data.Dataset.

        Args:
            train_dataset: Optional dataset or iterable of datasets to be used for train_dataloader()
            val_dataset: Optional dataset or iterable of datasets to be used for val_dataloader()
            test_dataset: Optional dataset or iterable of datasets to be used for test_dataloader()
            predict_dataset: Optional dataset or iterable of datasets to be used for predict_dataloader()
            batch_size: Batch size to use for each dataloader. Default is 1. This parameter gets forwarded to the
                ``__init__`` if the datamodule has such a name defined in its signature.
            num_workers: Number of subprocesses to use for data loading. 0 means that the
                data will be loaded in the main process. This parameter gets forwarded to the
                ``__init__`` if the datamodule has such a name defined in its signature.
            **datamodule_kwargs: Additional parameters that get passed down to the datamodule's ``__init__``.

        """

        def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
            # Shuffling is disabled for IterableDataset, which does not support it.
            shuffle &= not isinstance(ds, IterableDataset)
            return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True)

        def train_dataloader() -> TRAIN_DATALOADERS:
            return apply_to_collection(train_dataset, Dataset, dataloader, shuffle=True)

        def val_dataloader() -> EVAL_DATALOADERS:
            return apply_to_collection(val_dataset, Dataset, dataloader)

        def test_dataloader() -> EVAL_DATALOADERS:
            return apply_to_collection(test_dataset, Dataset, dataloader)

        def predict_dataloader() -> EVAL_DATALOADERS:
            return apply_to_collection(predict_dataset, Dataset, dataloader)

        candidate_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
        accepted_params = inspect.signature(cls.__init__).parameters
        accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
        if accepts_kwargs:
            special_kwargs = candidate_kwargs
        else:
            # Only forward batch_size/num_workers if the subclass ``__init__`` declares them explicitly.
            accepted_param_names = set(accepted_params)
            accepted_param_names.discard("self")
            special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}

        datamodule = cls(**datamodule_kwargs, **special_kwargs)
        if train_dataset is not None:
            datamodule.train_dataloader = train_dataloader  # type: ignore[method-assign]
        if val_dataset is not None:
            datamodule.val_dataloader = val_dataloader  # type: ignore[method-assign]
        if test_dataset is not None:
            datamodule.test_dataloader = test_dataloader  # type: ignore[method-assign]
        if predict_dataset is not None:
            datamodule.predict_dataloader = predict_dataloader  # type: ignore[method-assign]
        return datamodule
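
    # Hedged usage sketch for ``from_datasets``: the datasets below are stand-ins. The resulting
    # datamodule simply wraps each provided dataset in a DataLoader with the given
    # batch_size/num_workers, shuffling only the train split (as in the helper above).
    #
    #     from lightning.pytorch.demos.boring_classes import RandomDataset
    #
    #     dm = LightningDataModule.from_datasets(
    #         train_dataset=RandomDataset(32, 640),
    #         val_dataset=RandomDataset(32, 64),
    #         batch_size=16,
    #         num_workers=2,
    #     )
    #     train_loader = dm.train_dataloader()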

    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint, implement to generate and save datamodule state.

        Returns:
            A dictionary containing datamodule state.

        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.

        Args:
            state_dict: the datamodule state returned by ``state_dict``.

        """
        pass
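
    # Hedged example (an assumed subclass, not part of this file): small pieces of datamodule
    # state can be persisted across checkpoints by overriding the two hooks above, e.g. a split seed.
    #
    #     class SeededDataModule(LightningDataModule):
    #         def __init__(self):
    #             super().__init__()
    #             self.split_seed = 42
    #
    #         def state_dict(self):
    #             return {"split_seed": self.split_seed}
    #
    #         def load_state_dict(self, state_dict):
    #             self.split_seed = state_dict["split_seed"]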

    @_restricted_classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: Union[_PATH, IO],
        map_location: _MAP_LOCATION_TYPE = None,
        hparams_file: Optional[_PATH] = None,
        **kwargs: Any,
    ) -> Self:
        r"""Primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint it stores the
        arguments passed to ``__init__`` in the checkpoint under ``"datamodule_hyper_parameters"``.

        Any arguments specified through \*\*kwargs will override args stored in ``"datamodule_hyper_parameters"``.

        Args:
            checkpoint_path: Path to checkpoint. This can also be a URL, or a file-like object.
            map_location:
                If your checkpoint saved a GPU model and you now load on CPUs
                or a different number of GPUs, use this to map to the new setup.
                The behaviour is the same as in :func:`torch.load`.
            hparams_file: Optional path to a ``.yaml`` or ``.csv`` file with hierarchical structure
                as in this example::

                    dataloader:
                        batch_size: 32

                You most likely won't need this since Lightning will always save the hyperparameters
                to the checkpoint.
                However, if your checkpoint weights don't have the hyperparameters saved,
                use this method to pass in a ``.yaml`` file with the hparams you'd like to use.
                These will be converted into a :class:`~dict` and passed into your
                :class:`LightningDataModule` for use.

                If your datamodule's ``hparams`` argument is :class:`~argparse.Namespace`
                and the ``.yaml`` file has hierarchical structure, you need to refactor your datamodule to treat
                ``hparams`` as :class:`~dict`.
            \**kwargs: Any extra keyword args needed to init the datamodule. Can also be used to override saved
                hyperparameter values.

        Return:
            :class:`LightningDataModule` instance with loaded weights and hyperparameters (if available).

        Note:
            ``load_from_checkpoint`` is a **class** method. You must use your :class:`LightningDataModule`
            **class** to call it instead of the :class:`LightningDataModule` instance, or a
            ``TypeError`` will be raised.

        Example::

            # load weights without mapping ...
            datamodule = MyLightningDataModule.load_from_checkpoint('path/to/checkpoint.ckpt')

            # or load weights and hyperparameters from separate files.
            datamodule = MyLightningDataModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                hparams_file='/path/to/hparams_file.yaml'
            )

            # override some of the params with new values
            datamodule = MyLightningDataModule.load_from_checkpoint(
                PATH,
                batch_size=32,
                num_workers=10,
            )

        """
        loaded = _load_from_checkpoint(
            cls,  # type: ignore[arg-type]
            checkpoint_path,
            map_location=map_location,
            hparams_file=hparams_file,
            strict=None,
            **kwargs,
        )
        return cast(Self, loaded)