pytorch-lightning

Форк
0
400 строк · 11.2 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import ast
16
import inspect
17
from pathlib import Path
18
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union
19

20
if TYPE_CHECKING:
21
    from lightning.app.core import LightningFlow, LightningWork
22

23

24
class LightningVisitor(ast.NodeVisitor):
    """Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected to
    define class_name and may override the analyze_class_def method.

    Attributes
    ----------
    class_name: Optional[str]
        Name of the base class to identify, to be defined in subclasses. When ``None`` (the default here) nothing
        matches.
    found: List[Dict[str, Any]]
        One entry per matching class definition, in visiting order. Each entry has ``name`` (the class name) and
        ``type`` (``class_name``) keys, plus whatever ``analyze_class_def`` returns.

    """

    class_name: Optional[str] = None

    def __init__(self) -> None:
        self.found: List[Dict[str, Any]] = []

    def analyze_class_def(self, node: ast.ClassDef) -> Dict[str, Any]:
        """Return extra metadata to merge into the entry for a matched class.

        The base implementation adds nothing; subclasses may override it.

        """
        return {}

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Record ``node`` in ``found`` when one of its direct bases is named ``class_name``.

        Bases are compared by name only: for a dotted base such as ``pl.LightningModule`` only the final attribute
        (``LightningModule``) is compared. Other base expressions (e.g. subscripts or calls) are ignored. This method
        does not call ``generic_visit``, so class definitions nested inside a class body are not visited.

        """
        bases = []
        for base in node.bases:
            if isinstance(base, ast.Attribute):
                bases.append(base.attr)
            elif isinstance(base, ast.Name):
                bases.append(base.id)
        if self.class_name in bases:
            entry = {"name": node.name, "type": self.class_name}
            entry.update(self.analyze_class_def(node))
            self.found.append(entry)
54

55

56
class LightningModuleVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``LightningModule`` by base-class name.

    Attributes
    ----------
    class_name: Optional[str]
        Base-class name that marks a match (``"LightningModule"``).
    methods: Set[str]
        Method names this visitor treats as belonging to the LightningModule API.
    hooks: Set[str]
        Hook names this visitor treats as belonging to the LightningModule API.

    """

    class_name: Optional[str] = "LightningModule"

    methods: Set[str] = {
        # step methods
        "training_step",
        "training_step_end",
        "validation_step",
        "validation_step_end",
        "test_step",
        "test_step_end",
        # core model definition
        "forward",
        "configure_optimizers",
        # logging / output
        "log",
        "log_dict",
        "print",
        # persistence, freezing and export
        "save_hyperparameters",
        "freeze",
        "unfreeze",
        "to_onnx",
        "to_torchscript",
    }

    hooks: Set[str] = {
        # lifecycle
        "setup",
        "teardown",
        "prepare_data",
        "on_fit_start",
        "on_fit_end",
        "on_pretrain_routine_start",
        "on_pretrain_routine_end",
        # data
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
        "transfer_batch_to_device",
        # training loop
        "on_train_epoch_start",
        "on_train_epoch_end",
        "on_train_batch_start",
        "on_train_batch_end",
        # validation loop
        "on_validation_epoch_start",
        "on_validation_epoch_end",
        "on_validation_batch_start",
        "on_validation_batch_end",
        # test loop
        "on_test_epoch_start",
        "on_test_epoch_end",
        "on_test_batch_start",
        "on_test_batch_end",
        # optimization
        "backward",
        "manual_backward",
        "on_after_backward",
        "on_before_zero_grad",
        "optimizer_step",
        "optimizer_zero_grad",
        "manual_optimizer_step",
        # checkpointing
        "on_save_checkpoint",
        "on_load_checkpoint",
        # progress bar
        "get_progress_bar_dict",
    }
126

127

128
class LightningDataModuleVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``LightningDataModule`` by base-class name.

    Attributes
    ----------
    class_name: Optional[str]
        Base-class name that marks a match (``"LightningDataModule"``).
    methods: Set[str]
        Method names this visitor treats as belonging to the LightningDataModule API.

    """

    class_name: Optional[str] = "LightningDataModule"

    methods: Set[str] = {
        # preparation
        "prepare_data",
        "setup",
        # loaders
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
        # device placement
        "transfer_batch_to_device",
    }
150

151

152
class LightningLoggerVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``Logger`` by base-class name.

    Attributes
    ----------
    class_name: Optional[str]
        Base-class name that marks a match (``"Logger"``).
    methods: Set[str]
        Method names this visitor treats as belonging to the Logger API.

    """

    class_name: Optional[str] = "Logger"

    methods: Set[str] = {
        "log_metrics",
        "log_hyperparams",
    }
167

168

169
class LightningCallbackVisitor(LightningVisitor):
    """Finds Lightning callbacks based on class inheritance.

    Attributes
    ----------
    class_name: Optional[str]
        Name of class to identify (``"Callback"``).
    methods: Set[str]
        Names of methods that are part of the Callback API.

    """

    class_name: Optional[str] = "Callback"

    methods: Set[str] = {
        "setup",
        "teardown",
        "on_init_start",
        "on_init_end",
        "on_fit_start",
        "on_fit_end",
        "on_sanity_check_start",
        "on_sanity_check_end",
        "on_train_batch_start",
        "on_train_batch_end",
        "on_train_epoch_start",
        "on_train_epoch_end",
        "on_validation_epoch_start",
        "on_validation_epoch_end",
        "on_test_epoch_start",
        "on_test_epoch_end",
        "on_epoch_start",
        "on_epoch_end",
        "on_batch_start",
        "on_validation_batch_start",
        "on_validation_batch_end",
        "on_test_batch_start",
        "on_test_batch_end",
        "on_batch_end",
        "on_train_start",
        "on_train_end",
        "on_pretrain_routine_start",
        "on_pretrain_routine_end",
        "on_validation_start",
        "on_validation_end",
        "on_test_start",
        "on_test_end",
        "on_keyboard_interrupt",
        "on_save_checkpoint",
        "on_load_checkpoint",
    }
220

221

222
class LightningStrategyVisitor(LightningVisitor):
    """Finds Lightning strategies based on class inheritance.

    Attributes
    ----------
    class_name: Optional[str]
        Name of class to identify (``"Strategy"``).
    methods: Set[str]
        Names of methods that are part of the Strategy API.

    """

    class_name: Optional[str] = "Strategy"

    methods: Set[str] = {
        "setup",
        "train",
        "training_step",
        "validation_step",
        "test_step",
        "backward",
        "barrier",
        "broadcast",
        "sync_tensor",
    }
247

248

249
class LightningTrainerVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``Trainer`` by base-class name."""

    class_name: Optional[str] = "Trainer"
251

252

253
class LightningCLIVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``LightningCLI`` by base-class name."""

    class_name: Optional[str] = "LightningCLI"
255

256

257
class LightningPrecisionPluginVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``PrecisionPlugin`` by base-class name."""

    class_name: Optional[str] = "PrecisionPlugin"
259

260

261
class LightningAcceleratorVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``Accelerator`` by base-class name."""

    class_name: Optional[str] = "Accelerator"
263

264

265
class TorchMetricVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``Metric`` by base-class name."""

    class_name: Optional[str] = "Metric"
267

268

269
class FabricVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``Fabric`` by base-class name."""

    class_name: Optional[str] = "Fabric"
271

272

273
class LightningProfilerVisitor(LightningVisitor):
    """Visitor that detects subclasses of ``Profiler`` by base-class name."""

    class_name: Optional[str] = "Profiler"
275

276

277
class Scanner:
    """Finds relevant Lightning objects in files in the file system.

    Attributes
    ----------
    visitor_classes: List[Type]
        List of visitor classes to use when traversing files.
    paths: List[Path]
        Files to scan. This is a list, so ``has_class`` and ``scan`` can be called any number of times.

    Parameters
    ----------
    path: str
        Path to file, or directory where to look for files to scan.
    glob_pattern: str
        Glob pattern to use when looking for files in the path,
        applied when path is a directory. Default is "**/*.py".

    """

    # TODO: Finalize introspecting the methods from all the discovered methods.
    visitor_classes: List[Type] = [
        LightningCLIVisitor,
        LightningTrainerVisitor,
        LightningModuleVisitor,
        LightningDataModuleVisitor,
        LightningCallbackVisitor,
        LightningStrategyVisitor,
        LightningPrecisionPluginVisitor,
        LightningAcceleratorVisitor,
        LightningLoggerVisitor,
        TorchMetricVisitor,
        FabricVisitor,
        LightningProfilerVisitor,
    ]

    def __init__(self, path: str, glob_pattern: str = "**/*.py"):
        path_ = Path(path)
        if path_.is_dir():
            # Materialize the glob: a bare generator is exhausted after the first pass, so a later call to
            # ``has_class``/``scan`` would silently see no files. Sorting makes the scan order deterministic.
            self.paths: List[Path] = sorted(path_.glob(glob_pattern))
        else:
            self.paths = [path_]

        self.modules_found: List[Dict[str, Any]] = []

    @staticmethod
    def _parse_module(path: Path) -> Optional[ast.Module]:
        """Parse the file at ``path`` into an AST, or report it and return ``None`` if it cannot be parsed."""
        try:
            # Passing bytes lets ``ast.parse`` honour PEP 263 encoding declarations, and ``read_bytes`` closes the
            # file right away (``path.open().read()`` left the handle to the garbage collector).
            return ast.parse(path.read_bytes(), filename=str(path))
        except (SyntaxError, ValueError):
            # ValueError: e.g. source containing null bytes on older Python versions.
            print(f"Error while parsing {path}: SKIPPING")
            return None

    @staticmethod
    def _to_namespace(path: Path) -> str:
        """Convert a file path into a dotted module path, e.g. ``a/b/c.py`` -> ``a.b.c``."""
        # Strip only a trailing ``.py`` suffix: ``str.replace(".py", "")`` also mangled names such as
        # ``test.pyramid.py``. ``as_posix`` normalizes Windows separators to ``/`` before dotting.
        if path.suffix == ".py":
            path = path.with_suffix("")
        return path.as_posix().replace("/", ".")

    def has_class(self, cls) -> bool:
        """Return whether the name of ``cls`` is imported or called as an attribute in any scanned file.

        Parameters
        ----------
        cls:
            Class whose ``__name__`` is looked for.

        """
        # This method isn't strong enough as it is using only `ImportFrom` and attribute calls.
        # TODO: Use proper classDef scanning.
        classes: Set[str] = set()

        for path in self.paths:
            module = self._parse_module(path)
            if module is None:
                continue

            for node in ast.walk(module):
                if isinstance(node, ast.ImportFrom):
                    classes.update(alias.name for alias in node.names)

                if isinstance(node, ast.Call):
                    cls_name = getattr(node.func, "attr", None)
                    if cls_name:
                        classes.add(cls_name)

        return cls.__name__ in classes

    def scan(self) -> Dict[str, List[Dict[str, Any]]]:
        """Finds Lightning modules in files, returning importable objects.

        Returns
        -------
        Dict[str, List[Dict[str, Any]]]
            Maps each visitor's ``class_name`` to the classes found for it across all scanned files. Each entry
            holds the visitor's metadata plus ``file`` (the path) and ``namespace`` (the dotted module path)
            required to import it.

        """
        modules_found: Dict[str, List[Dict[str, Any]]] = {}

        for path in self.paths:
            module = self._parse_module(path)
            if module is None:
                continue
            ns_info = {
                "file": str(path),
                "namespace": self._to_namespace(path),
            }
            for visitor_class in self.visitor_classes:
                visitor = visitor_class()
                visitor.visit(module)
                if not visitor.found:
                    continue
                # Accumulate across files; plain assignment kept only the last file's matches.
                modules_found.setdefault(visitor_class.class_name, []).extend(
                    {**entry, **ns_info} for entry in visitor.found
                )

        return modules_found
374

375

376
def _is_method_context(component: Union["LightningFlow", "LightningWork"], selected_caller_name: str) -> bool:
377
    """Checks whether the call to a component originates from within the context of the component's ``__init__``
378
    method."""
379
    frame = inspect.currentframe().f_back
380

381
    while frame is not None:
382
        caller_name = frame.f_code.co_name
383
        caller_self = frame.f_locals.get("self")
384
        if caller_name == selected_caller_name and caller_self is component:
385
            # the call originates from a frame under component.__init__
386
            return True
387
        frame = frame.f_back
388

389
    return False
390

391

392
def _is_init_context(component: Union["LightningFlow", "LightningWork"]) -> bool:
    """Tell whether the current call is happening while ``component.__init__`` is on the call stack.

    This is ``_is_method_context`` with the method name fixed to ``"__init__"``.
    """
    init_method_name = "__init__"
    return _is_method_context(component, selected_caller_name=init_method_name)
396

397

398
def _is_run_context(component: Union["LightningFlow", "LightningWork"]) -> bool:
    """Checks whether the call to a component originates from within the context of the component's ``run`` method
    or its ``load_state_dict`` method.

    The ``run`` frame is searched for first; ``load_state_dict`` is only checked when no ``run`` frame of
    ``component`` is found.

    """
    return _is_method_context(component, "run") or _is_method_context(component, "load_state_dict")
401

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.