pytorch-lightning

Форк
0
193 строки · 6.5 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import functools
16
import inspect
17
import runpy
18
import sys
19
import time
20
from pathlib import Path
21
from typing import Any, Optional
22

23

24
def get_default_args(func):
    """Return a mapping of parameter name to default value for ``func``.

    Only parameters that declare a default are included, in signature order;
    parameters without one (including ``*args`` / ``**kwargs``) are omitted.
    """
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is inspect.Parameter.empty:
            continue
        defaults[name] = param.default
    return defaults
27

28

29
def wrap_fn(fn, cls, method_name, trace, stack_level=1, pre_fn=None, post_fn=None, is_class_method=None):
    """Wrap a function so that its execution can be traced and its args and return values modified.

    Every call of the returned wrapper records an entry in ``trace`` at
    ``trace[cls.__qualname__][id(self)][method_name][frame_id]``. The entry contains:

    - ``"frame"``: ``filename``, ``lineno`` and ``function`` of the recorded call site, plus
      ``depth`` (``len(inspect.stack()) - 1`` as seen from inside the wrapper);
    - ``"default_args"``: the parameters of ``fn`` that declare a default, with those defaults;
    - ``"call_args"``: the keyword arguments of the call (positional arguments are not recorded);
    - ``"call"``: ``start`` / ``end`` timestamps from :func:`time.time_ns` around the call
      of ``fn`` (``end`` is absent if ``fn`` raises);
    - ``"pre"`` / ``"post"``: the trace data returned by ``pre_fn`` / ``post_fn``, if given.

    ``frame_id`` is the ``id()`` of the ``inspect.FrameInfo`` record for the call site.
    NOTE(review): a new record is created on every call, so in practice these ids separate
    individual calls rather than grouping them by call site. An id can be reused once its
    record is freed, which would merge entries. Confirm which behavior is intended.

    Args:
        fn: The callable to wrap, as obtained with ``getattr(cls, method_name)``.
        cls: The class owning the method; its ``__qualname__`` is the top-level trace key.
        method_name: Key under which the calls are recorded.
        trace: Dict mutated in place with the entries described above.
        stack_level: Index into the stack, seen from inside the wrapper, of the frame that is
            recorded as the call site. ``1`` is the wrapper's direct caller.
        pre_fn: Optional ``pre_fn(self, *args, **kwargs) -> (pre_trace, args, kwargs)``. The
            returned args/kwargs replace the ones passed on to ``fn``.
        post_fn: Optional ``post_fn(self, ret) -> (post_trace, ret)``. The returned ``ret``
            replaces the value handed back to the caller.
        is_class_method: Truthy when ``fn`` is a classmethod bound to ``cls``. The wrapper must
            then be installed as ``classmethod(wrapper)``, so that ``self`` receives the class.

    Returns:
        The tracing wrapper, with ``fn``'s metadata copied by :func:`functools.wraps`.

    """
    class_name = cls.__qualname__

    @functools.wraps(fn)
    def fn_with_tracing(self, *args: Any, **kwargs: Any):
        if class_name not in trace:
            trace[class_name] = {}

        self_id = id(self)
        # context=0: only filename/lineno/function are read from the stack records below,
        # so skip loading source lines for every frame (the default context=1 reads them).
        stack = inspect.stack(context=0)
        frame = stack[stack_level]
        frame_id = id(frame)
        stack_len = len(stack) - 1

        per_method = trace[class_name].setdefault(self_id, {}).setdefault(method_name, {})
        trace_entry = per_method.setdefault(frame_id, {})

        if pre_fn:
            # If a pre_fn is specified, it can both record information
            # in a trace, as well as return modified args and kwargs
            # that will be provided to the actual fn being wrapped
            pre_trace, args, kwargs = pre_fn(self, *args, **kwargs)
            trace_entry["pre"] = pre_trace

        # We record the invocation and the calling location in the trace
        trace_entry["frame"] = {
            "filename": frame.filename,
            "lineno": frame.lineno,
            "function": frame.function,
            "depth": stack_len,
        }

        # we cache the default parameters declared by fn
        trace_entry["default_args"] = get_default_args(fn)

        # we cache also the parameters used during the function call (keyword arguments only)
        trace_entry["call_args"] = kwargs

        trace_entry["call"] = {"start": time.time_ns()}

        if is_class_method:
            # ``fn`` is bound to the class it was looked up on, while ``self`` is the class the
            # classmethod wrapper was invoked through (possibly a subclass). Call the
            # underlying function with ``self`` so ``cls`` is right for subclasses. Fall back
            # to the bound callable when there is no ``__func__`` (e.g. builtin classmethods).
            func = getattr(fn, "__func__", None)
            ret = fn(*args, **kwargs) if func is None else func(self, *args, **kwargs)
        else:
            ret = fn(self, *args, **kwargs)

        trace_entry["call"]["end"] = time.time_ns()

        if post_fn:
            # If a post_fn is specified, it can both record information
            # in a trace, as well as modify the value returned from fn
            post_trace, ret = post_fn(self, ret)
            trace_entry["post"] = post_trace

        return ret

    return fn_with_tracing
91

92

93
class Tracer:
    """Trace calls to selected class methods while running a Python script.

    Register methods with :meth:`add_traced`, then call :meth:`trace` with the script path
    and its command-line arguments. The registered methods are replaced by tracing wrappers
    (built with ``wrap_fn``) only for the duration of the run.
    """

    # Marker stored in ``_orig_raw`` when a traced attribute was not defined in the class's
    # own ``__dict__`` (e.g. it was inherited). Restoring then means deleting the wrapper
    # from the class instead of re-assigning a value that would shadow the parent's.
    _NOT_IN_CLASS_DICT = object()

    def __init__(self):
        # (cls, method_name, stack_level, pre_fn, post_fn) tuples registered via add_traced
        self.methods = []
        # cls -> {method_name: attribute as returned by getattr(cls, method_name)}
        self.orig = {}
        # cls -> {method_name: raw entry of cls.__dict__, or _NOT_IN_CLASS_DICT}
        self._orig_raw = {}
        # trace output: nested dicts while tracing, reshaped by _cleanup afterwards
        self.res = {}

    def add_traced(self, cls, method_name, stack_level=1, pre_fn=None, post_fn=None):
        """Record the fact that we will want to trace method_name in class cls.

        Optionally provide two functions that will execute prior to and after the method. The functions also have a
        chance to modify the input arguments and the return values of the methods:
        ``pre_fn(self, *args, **kwargs)`` must return ``(pre_trace, args, kwargs)`` and
        ``post_fn(self, ret)`` must return ``(post_trace, ret)``.

        """
        self.methods.append((cls, method_name, stack_level, pre_fn, post_fn))

    def _instrument(self):
        """Modify classes by wrapping methods that need to be traced.

        Initialize the output trace dict and the record of the original attributes.

        """
        self.res = {}
        self.orig = {}
        self._orig_raw = {}
        for cls, method, stack_level, pre_fn, post_fn in self.methods:
            fn = getattr(cls, method)
            # getattr on a class returns a bound method (with __self__) for classmethods
            fn_is_class_method: bool = hasattr(fn, "__self__")

            # setdefault: if the same method is registered twice, the second lookup returns
            # the first wrapper. Keep the true original so that _restore undoes both.
            self.orig.setdefault(cls, {}).setdefault(method, fn)
            # Keep the raw class-dict entry as well: re-assigning the getattr result would turn
            # a classmethod into a plain bound method and a staticmethod into a function.
            self._orig_raw.setdefault(cls, {}).setdefault(method, cls.__dict__.get(method, self._NOT_IN_CLASS_DICT))

            wrapped_fn = wrap_fn(
                fn,
                cls,
                method,
                self.res,
                stack_level=stack_level,
                pre_fn=pre_fn,
                post_fn=post_fn,
                is_class_method=fn_is_class_method,
            )

            # this is needed to wrap class methods
            if fn_is_class_method:
                wrapped_fn = classmethod(wrapped_fn)

            setattr(cls, method, wrapped_fn)

    def _restore(self):
        """Restore original methods so classes go back to their initial state."""
        for cls, methods in self.orig.items():
            raw_methods = self._orig_raw.get(cls, {})
            for method, fn in methods.items():
                # fall back to the looked-up attribute when no raw entry was recorded
                raw = raw_methods.get(method, fn)
                if raw is self._NOT_IN_CLASS_DICT:
                    # the attribute was inherited: drop the wrapper so lookup reaches the parent
                    if method in cls.__dict__:
                        delattr(cls, method)
                else:
                    setattr(cls, method, raw)

    def _cleanup(self):
        """Reshape the raw trace, replacing the id-keyed dicts with lists.

        ``res[class_name][instance_id][method_name][frame_id] -> entry`` becomes
        ``res[class_name] = [{"id": instance_id, method_name: [entry, ...], ...}, ...]``.
        Each entry is updated in place with an ``"id"`` key holding its frame id.
        """
        out = {}
        for class_name, instances in self.res.items():
            out_instances = []
            for self_id, instance in instances.items():
                out_instance = {"id": self_id}
                for method_name, method in instance.items():
                    frames = []
                    for frame_id, frame in method.items():
                        frame["id"] = frame_id
                        frames.append(frame)
                    out_instance[method_name] = frames
                out_instances.append(out_instance)
            out[class_name] = out_instances
        self.res = out

    def trace(self, *args: Any, init_globals=None) -> Optional[dict]:
        """Execute the command-line arguments in args after instrumenting for tracing.

        Restore the classes to their initial state after tracing.

        Args:
            *args: The script path followed by its command-line arguments; the script sees
                them as ``sys.argv``. The script's directory is appended to ``sys.path``.
            init_globals: Globals to seed the script's namespace with. When falsy, this
                module's ``globals()`` are used.

        Returns:
            The script's resulting globals as returned by ``runpy.run_path`` (run as
            ``"__main__"``), with the cleaned-up trace added under ``"tracer_res"``.

        The traced classes, ``sys.path`` and ``sys.argv`` are restored even if the script
        raises; the exception then propagates.

        """
        args = list(args)
        script = args[0]
        script_dir = Path(script).parent.absolute()

        sys_path = sys.path[:]
        sys_argv = sys.argv[:]

        sys.path.append(str(script_dir))
        sys.argv = args

        try:
            self._instrument()
            res = runpy.run_path(script, run_name="__main__", init_globals=init_globals or globals())
        finally:
            # Undo patching even when the script fails (including SystemExit raised by
            # sys.exit or argparse), so classes and interpreter state do not stay modified.
            self._restore()
            sys.path = sys_path[:]
            sys.argv = sys_argv[:]

        self._cleanup()
        res["tracer_res"] = self.res

        return res
194

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.