pytorch-lightning
193 строки · 6.5 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import functools
16import inspect
17import runpy
18import sys
19import time
20from pathlib import Path
21from typing import Any, Optional
22
23
def get_default_args(func):
    """Return a ``{parameter name: default value}`` dict for the parameters of ``func`` that declare a default.

    Parameters without a default (their default is ``inspect.Parameter.empty``) are left out.

    """
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults
27
28
def wrap_fn(fn, cls, method_name, trace, stack_level=1, pre_fn=None, post_fn=None, is_class_method=None):
    """Build a replacement for ``fn`` that records each invocation in ``trace`` before delegating to ``fn``.

    Entries are stored under ``trace[cls.__qualname__][id(self)][method_name][frame_id]``.

    Args:
        fn: the callable being wrapped.
        cls: the class owning ``fn``; its ``__qualname__`` is the top-level key in ``trace``.
        method_name: the key under which calls are stored for a given instance.
        trace: the dict that receives the trace entries (mutated in place).
        stack_level: index into ``inspect.stack()`` selecting the frame recorded as the calling location.
        pre_fn: optional hook called as ``pre_fn(self, *args, **kwargs)``; it must return
            ``(pre_trace, args, kwargs)``. ``pre_trace`` is stored under ``"pre"`` and the returned
            ``args`` / ``kwargs`` replace the ones forwarded to ``fn``.
        post_fn: optional hook called as ``post_fn(self, ret)``; it must return ``(post_trace, ret)``.
            ``post_trace`` is stored under ``"post"`` and the returned ``ret`` becomes the return value.
        is_class_method: when truthy, ``fn`` is called without the leading ``self`` argument.

    Returns:
        The tracing wrapper, decorated with ``functools.wraps(fn)``.

    """
    owner_name = cls.__qualname__

    @functools.wraps(fn)
    def fn_with_tracing(self, *args: Any, **kwargs: Any):
        per_class = trace.setdefault(owner_name, {})

        instance_key = id(self)
        # The entry key below is the id() of the selected stack record itself, so the stack list and the
        # record both stay bound to locals for the remainder of this call.
        call_stack = inspect.stack()
        caller = call_stack[stack_level]
        caller_key = id(caller)
        depth = len(call_stack) - 1

        per_instance = per_class.setdefault(instance_key, {})
        per_method = per_instance.setdefault(method_name, {})
        entry = per_method.setdefault(caller_key, {})

        if pre_fn:
            # A pre_fn can both contribute information to the trace and
            # hand back modified args and kwargs, which are the ones
            # forwarded to the wrapped fn
            pre_trace, args, kwargs = pre_fn(self, *args, **kwargs)
            entry["pre"] = pre_trace

        # Record where the call came from
        entry["frame"] = {
            "filename": caller.filename,
            "lineno": caller.lineno,
            "function": caller.function,
            "depth": depth,
        }

        # Record the default parameters declared by the wrapped function
        entry["default_args"] = get_default_args(fn)

        # Record the keyword arguments used for this call
        entry["call_args"] = kwargs

        entry["call"] = {"start": time.time_ns()}

        if is_class_method:
            ret = fn(*args, **kwargs)
        else:
            ret = fn(self, *args, **kwargs)

        entry["call"]["end"] = time.time_ns()

        if post_fn:
            # A post_fn can both contribute information to the trace and
            # replace the value returned from fn
            post_trace, ret = post_fn(self, ret)
            entry["post"] = post_trace

        return ret

    return fn_with_tracing
91
92
class Tracer:
    """Record calls to selected methods while a Python script is executed.

    Methods are registered with :meth:`add_traced`. :meth:`trace` wraps them, runs the script with
    ``runpy.run_path`` and returns the script's globals together with the collected trace.

    """

    def __init__(self):
        # (cls, method_name, stack_level, pre_fn, post_fn) tuples registered through ``add_traced``
        self.methods = []
        # cls -> {method_name: attribute read from the class before it was wrapped}
        self.orig = {}
        # trace output: nested dicts while tracing, reshaped into lists by ``_cleanup``
        self.res = {}

    def add_traced(self, cls, method_name, stack_level=1, pre_fn=None, post_fn=None):
        """Record the fact that we will want to trace method_name in class cls.

        Optionally provide two functions that will execute prior to and after the method. The functions also have a
        chance to modify the input arguments and the return values of the methods.

        """
        self.methods.append((cls, method_name, stack_level, pre_fn, post_fn))

    def _instrument(self):
        """Modify classes by wrapping methods that need to be traced.

        Initialize the output trace dict.

        """
        self.res = {}
        for cls, method, stack_level, pre_fn, post_fn in self.methods:
            fn = getattr(cls, method)
            # this checks if the passed function is a class method
            fn_is_class_method: bool = hasattr(fn, "__self__")

            if cls not in self.orig:
                self.orig[cls] = {}
            # remember what was on the class so that ``_restore`` can put it back
            self.orig[cls][method] = fn
            wrapped_fn = wrap_fn(
                fn,
                cls,
                method,
                self.res,
                stack_level=stack_level,
                pre_fn=pre_fn,
                post_fn=post_fn,
                is_class_method=fn_is_class_method,
            )

            # this is needed to wrap class methods
            if fn_is_class_method:
                wrapped_fn = classmethod(wrapped_fn)

            setattr(cls, method, wrapped_fn)

    def _restore(self):
        """Restore original methods so classes go back to their initial state."""
        for cls, methods in self.orig.items():
            for method, fn in methods.items():
                setattr(cls, method, fn)

    def _cleanup(self):
        """Cleanup trace by converting trace[class_name][instance_id][method_name][frame_id] to
        trace[class_name][][method_name][] thereby removing references to instance ids.

        The instance id and the frame id are kept as an ``"id"`` entry inside each converted item.

        """
        out = {}
        for class_name in self.res:
            out[class_name] = []
            for self_id in self.res[class_name]:
                instance = self.res[class_name][self_id]
                out_instance = {"id": self_id}
                for method_name, method in instance.items():
                    frames = []
                    for frame_id, frame in method.items():
                        frame["id"] = frame_id
                        frames.append(frame)
                    out_instance[method_name] = frames
                out[class_name].append(out_instance)
        self.res = out

    def trace(self, *args: Any, init_globals=None) -> Optional[dict]:
        """Execute the command-line arguments in args after instrumenting for tracing.

        ``args[0]`` is the path of the script to run. While it runs, ``sys.argv`` is set to ``args`` and the
        script's directory is appended to ``sys.path``.

        The traced classes, ``sys.path`` and ``sys.argv`` are restored to their initial state after tracing,
        including when instrumentation or the script raises. In that case the exception propagates and
        ``self.res`` is left in its nested, un-cleaned form.

        Args:
            *args: the script path followed by the arguments the script should see in ``sys.argv``.
            init_globals: globals used to pre-populate the script's namespace; when falsy, the globals of
                this module are used instead.

        Returns:
            The globals dict produced by ``runpy.run_path``, with the cleaned trace under ``"tracer_res"``.

        Raises:
            IndexError: if no script path is given.

        """
        args = list(args)
        script = args[0]
        script_dir = Path(script).parent.absolute()

        saved_path = sys.path[:]
        saved_argv = sys.argv[:]

        try:
            sys.path.append(str(script_dir))
            sys.argv = args

            self._instrument()

            res = runpy.run_path(script, run_name="__main__", init_globals=init_globals or globals())
        finally:
            # Undo the monkey-patching and interpreter changes no matter how the run ended, so that a
            # failing script does not leave wrapped methods or a modified sys.path / sys.argv behind.
            self._restore()
            sys.path = saved_path
            sys.argv = saved_argv

        self._cleanup()

        res["tracer_res"] = self.res

        return res
194