# pytorch-lightning
# 400 lines · 11.2 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import ast
import inspect
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union

if TYPE_CHECKING:
    from lightning.app.core import LightningFlow, LightningWork
22
23
class LightningVisitor(ast.NodeVisitor):
    """Base class for visitor that finds class definitions based on class inheritance. Derived classes are expected to
    define class_name and implement the analyze_class_def method.

    Attributes
    ----------
    class_name: str
        Name of class to identify, to be defined in subclasses.
    found: List[Dict[str, Any]]
        One entry per matching class definition, in visiting order. Each entry has ``name`` (the class name) and
        ``type`` (``class_name``), merged with whatever ``analyze_class_def`` returns.

    """

    class_name: Optional[str] = None

    def __init__(self):
        super().__init__()
        self.found: List[Dict[str, Any]] = []

    def analyze_class_def(self, node: ast.ClassDef) -> Dict[str, Any]:
        """Return extra metadata to merge into the entry recorded for ``node``.

        The base implementation contributes nothing; subclasses may override it.

        """
        return {}

    @staticmethod
    def _base_names(node: ast.ClassDef) -> List[str]:
        """Return the unqualified names of the direct bases of ``node``.

        ``class A(B)`` contributes ``"B"`` and ``class A(pkg.mod.B)`` also contributes ``"B"``. Bases written as any
        other expression (calls, subscripts, ...) are ignored.

        """
        names: List[str] = []
        for base in node.bases:
            if isinstance(base, ast.Attribute):
                names.append(base.attr)
            elif isinstance(base, ast.Name):
                names.append(base.id)
        return names

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Record ``node`` if one of its direct bases is named ``class_name``.

        ``generic_visit`` is not called, so class definitions nested inside a class body are not visited.

        """
        if self.class_name in self._base_names(node):
            entry = {"name": node.name, "type": self.class_name}
            entry.update(self.analyze_class_def(node))
            self.found.append(entry)
54
55
class LightningModuleVisitor(LightningVisitor):
    """Visitor that collects classes deriving from ``LightningModule``.

    Attributes
    ----------
    class_name: Optional[str]
        Base-class name a definition must inherit from to be collected.
    methods: Set[str]
        Method names belonging to the LightningModule API.
    hooks: Set[str]
        Hook names belonging to the LightningModule API.

    """

    class_name: Optional[str] = "LightningModule"

    methods: Set[str] = {
        # model definition & optimizers
        "configure_optimizers",
        "forward",
        # per-stage steps
        "training_step",
        "training_step_end",
        "validation_step",
        "validation_step_end",
        "test_step",
        "test_step_end",
        # logging & printing
        "log",
        "log_dict",
        "print",
        # parameters & export
        "freeze",
        "unfreeze",
        "save_hyperparameters",
        "to_onnx",
        "to_torchscript",
    }

    hooks: Set[str] = {
        # backward pass & optimizer
        "backward",
        "manual_backward",
        "on_after_backward",
        "on_before_zero_grad",
        "optimizer_step",
        "optimizer_zero_grad",
        "manual_optimizer_step",
        # fit / routine lifecycle
        "on_fit_start",
        "on_fit_end",
        "on_pretrain_routine_start",
        "on_pretrain_routine_end",
        "setup",
        "teardown",
        # checkpointing
        "on_save_checkpoint",
        "on_load_checkpoint",
        # batch & epoch boundaries
        "on_train_batch_start",
        "on_train_batch_end",
        "on_train_epoch_start",
        "on_train_epoch_end",
        "on_validation_batch_start",
        "on_validation_batch_end",
        "on_validation_epoch_start",
        "on_validation_epoch_end",
        "on_test_batch_start",
        "on_test_batch_end",
        "on_test_epoch_start",
        "on_test_epoch_end",
        # data
        "prepare_data",
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
        "transfer_batch_to_device",
        # progress bar
        "get_progress_bar_dict",
    }
126
127
class LightningDataModuleVisitor(LightningVisitor):
    """Visitor that collects classes deriving from ``LightningDataModule``.

    Attributes
    ----------
    class_name: Optional[str]
        Base-class name a definition must inherit from to be collected.
    methods: Set[str]
        Method names belonging to the LightningDataModule API.

    """

    class_name = "LightningDataModule"

    methods: Set[str] = {
        # preparation
        "prepare_data",
        "setup",
        # dataloaders
        "train_dataloader",
        "val_dataloader",
        "test_dataloader",
        # device transfer
        "transfer_batch_to_device",
    }
150
151
class LightningLoggerVisitor(LightningVisitor):
    """Visitor that collects classes deriving from ``Logger``.

    Attributes
    ----------
    class_name: Optional[str]
        Base-class name a definition must inherit from to be collected.
    methods: Set[str]
        Method names belonging to the Logger API.

    """

    class_name = "Logger"

    methods: Set[str] = {
        "log_metrics",
        "log_hyperparams",
    }
167
168
class LightningCallbackVisitor(LightningVisitor):
    """Finds Lightning callbacks based on class inheritance.

    Attributes
    ----------
    class_name: Optional[str]
        Name of class to identify.
    methods: Set[str]
        Names of methods that are part of the Callback API.

    """

    class_name = "Callback"

    # NOTE(review): some names below (e.g. ``on_init_start``, ``on_epoch_start``, ``on_batch_start``,
    # ``on_keyboard_interrupt``, ``on_pretrain_routine_start``) look like hooks from older Lightning releases;
    # verify against the Callback API of the Lightning version this scanner targets.
    methods: Set[str] = {
        "setup",
        "teardown",
        "on_init_start",
        "on_init_end",
        "on_fit_start",
        "on_fit_end",
        "on_sanity_check_start",
        "on_sanity_check_end",
        "on_train_batch_start",
        "on_train_batch_end",
        "on_train_epoch_start",
        "on_train_epoch_end",
        "on_validation_epoch_start",
        "on_validation_epoch_end",
        "on_test_epoch_start",
        "on_test_epoch_end",
        "on_epoch_start",
        "on_epoch_end",
        "on_batch_start",
        "on_validation_batch_start",
        "on_validation_batch_end",
        "on_test_batch_start",
        "on_test_batch_end",
        "on_batch_end",
        "on_train_start",
        "on_train_end",
        "on_pretrain_routine_start",
        "on_pretrain_routine_end",
        "on_validation_start",
        "on_validation_end",
        "on_test_start",
        "on_test_end",
        "on_keyboard_interrupt",
        "on_save_checkpoint",
        "on_load_checkpoint",
    }
220
221
class LightningStrategyVisitor(LightningVisitor):
    """Finds Lightning strategies based on class inheritance.

    Attributes
    ----------
    class_name: Optional[str]
        Name of class to identify.
    methods: Set[str]
        Names of methods that are part of the Strategy API.

    """

    class_name = "Strategy"

    methods: Set[str] = {
        "setup",
        "train",
        "training_step",
        "validation_step",
        "test_step",
        "backward",
        "barrier",
        "broadcast",
        "sync_tensor",
    }
247
248
class LightningTrainerVisitor(LightningVisitor):
    """Finds ``Trainer`` subclasses based on class inheritance."""

    class_name = "Trainer"
251
252
class LightningCLIVisitor(LightningVisitor):
    """Finds ``LightningCLI`` subclasses based on class inheritance."""

    class_name = "LightningCLI"
255
256
class LightningPrecisionPluginVisitor(LightningVisitor):
    """Finds ``PrecisionPlugin`` subclasses based on class inheritance."""

    class_name = "PrecisionPlugin"
259
260
class LightningAcceleratorVisitor(LightningVisitor):
    """Finds ``Accelerator`` subclasses based on class inheritance."""

    class_name = "Accelerator"
263
264
class TorchMetricVisitor(LightningVisitor):
    """Finds ``Metric`` subclasses based on class inheritance."""

    class_name = "Metric"
267
268
class FabricVisitor(LightningVisitor):
    """Finds ``Fabric`` subclasses based on class inheritance."""

    class_name = "Fabric"
271
272
class LightningProfilerVisitor(LightningVisitor):
    """Finds ``Profiler`` subclasses based on class inheritance."""

    class_name = "Profiler"
275
276
class Scanner:
    """Finds relevant Lightning objects in files in the file system.

    Attributes
    ----------
    visitor_classes: List[Type]
        List of visitor classes to use when traversing files.
    paths: List[Path]
        Files to scan, resolved once at construction time.

    Parameters
    ----------
    path: str
        Path to file, or directory where to look for files to scan.
    glob_pattern: str
        Glob pattern to use when looking for files in the path,
        applied when path is a directory. Default is "**/*.py".

    """

    # TODO: Finalize introspecting the methods from all the discovered methods.
    visitor_classes: List[Type] = [
        LightningCLIVisitor,
        LightningTrainerVisitor,
        LightningModuleVisitor,
        LightningDataModuleVisitor,
        LightningCallbackVisitor,
        LightningStrategyVisitor,
        LightningPrecisionPluginVisitor,
        LightningAcceleratorVisitor,
        LightningLoggerVisitor,
        TorchMetricVisitor,
        FabricVisitor,
        LightningProfilerVisitor,
    ]

    def __init__(self, path: str, glob_pattern: str = "**/*.py"):
        path_ = Path(path)
        if path_.is_dir():
            # ``Path.glob`` returns a one-shot generator. Materialize it so that ``has_class`` and ``scan`` can each
            # be called any number of times and still see every file. Directories whose names happen to match the
            # pattern (e.g. ``foo.py/``) cannot be parsed, so only regular files are kept.
            self.paths: List[Path] = [p for p in path_.glob(glob_pattern) if p.is_file()]
        else:
            self.paths = [path_]

        # NOTE(review): not populated by ``scan``, which returns its result instead; kept for compatibility.
        self.modules_found: List[Dict[str, Any]] = []

    def _iter_parsed_modules(self) -> Iterator[Tuple[Path, ast.Module]]:
        """Yield ``(path, module_ast)`` for every scanned file, skipping (with a message) files that fail to parse."""
        for path in self.paths:
            try:
                # Parse raw bytes so ``ast.parse`` applies the file's own encoding declaration (UTF-8 by default)
                # rather than the platform locale, and so no file handle is left open.
                module = ast.parse(path.read_bytes(), filename=str(path))
            except SyntaxError:
                print(f"Error while parsing {path}: SKIPPING")
                continue
            yield path, module

    @staticmethod
    def _namespace(path: Path) -> str:
        """Return the dotted namespace for ``path``, e.g. ``pkg/sub/mod.py`` -> ``pkg.sub.mod``.

        Only a trailing ``.py`` suffix is removed (a ``.py`` elsewhere in the path is kept), and path separators,
        including Windows backslashes, become dots.

        """
        if path.suffix == ".py":
            path = path.with_suffix("")
        return path.as_posix().replace("/", ".")

    def has_class(self, cls: Type) -> bool:
        """Return whether ``cls.__name__`` appears in any scanned file.

        A name counts if it is imported with ``from ... import <name>`` or called as an attribute
        (``<obj>.<name>(...)``). Plain calls such as ``<name>(...)`` and class definitions are not inspected.

        """
        # This heuristic is not strong enough: it only looks at `ImportFrom` names and attribute calls.
        # TODO: Use proper classDef scanning.
        classes: Set[str] = set()

        for _, module in self._iter_parsed_modules():
            for node in ast.walk(module):
                if isinstance(node, ast.ImportFrom):
                    classes.update(alias.name for alias in node.names)
                elif isinstance(node, ast.Call):
                    cls_name = getattr(node.func, "attr", None)
                    if cls_name:
                        classes.add(cls_name)

        return cls.__name__ in classes

    def scan(self) -> Dict[str, List[Dict[str, Any]]]:
        """Finds Lightning modules in files, returning importable objects.

        Returns
        -------
        Dict[str, List[Dict[str, Any]]]
            Mapping from each visitor's ``class_name`` to the entries found across *all* scanned files. Each entry
            holds the visitor's metadata plus the ``file`` it was found in and its dotted ``namespace``. Visitors
            that matched nothing have no key.

        """
        modules_found: Dict[str, List[Dict[str, Any]]] = {}

        for path, module in self._iter_parsed_modules():
            ns_info = {"file": str(path), "namespace": self._namespace(path)}
            for visitor_class in self.visitor_classes:
                visitor = visitor_class()
                visitor.visit(module)
                if not visitor.found:
                    continue
                # Accumulate across files; plain assignment would keep only the last file's matches.
                modules_found.setdefault(visitor_class.class_name, []).extend(
                    {**entry, **ns_info} for entry in visitor.found
                )

        return modules_found
374
375
376def _is_method_context(component: Union["LightningFlow", "LightningWork"], selected_caller_name: str) -> bool:
377"""Checks whether the call to a component originates from within the context of the component's ``__init__``
378method."""
379frame = inspect.currentframe().f_back
380
381while frame is not None:
382caller_name = frame.f_code.co_name
383caller_self = frame.f_locals.get("self")
384if caller_name == selected_caller_name and caller_self is component:
385# the call originates from a frame under component.__init__
386return True
387frame = frame.f_back
388
389return False
390
391
def _is_init_context(component: Union["LightningFlow", "LightningWork"]) -> bool:
    """Return whether the current call is made, directly or indirectly, from ``component.__init__``.

    Delegates to ``_is_method_context``, which searches the call stack for the ``__init__`` frame of ``component``.

    """
    constructor_name = "__init__"
    return _is_method_context(component, selected_caller_name=constructor_name)
396
397
def _is_run_context(component: Union["LightningFlow", "LightningWork"]) -> bool:
    """Checks whether the call to a component originates from within the context of the component's ``run`` method
    or its ``load_state_dict`` method."""
    return _is_method_context(component, "run") or _is_method_context(component, "load_state_dict")
401