pytorch

Форк
0
/
_cycles.py 
447 строк · 14.4 Кб
1
import gc
2
import sys
3
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
4
import types
5
import weakref
6
import json
7
from tempfile import NamedTemporaryFile
8
import torch
9
from torch.cuda._memory_viz import _frames_fmt, _block_extra
10
import atexit
11
import logging
12
logger = logging.getLogger(__name__)
13

14
def observe_garbage(observer):
    """Install a GC hook that reports objects found unreachable by the collector.

    ``observer`` is called with the list of collected objects (``gc.garbage``)
    after each collection, before the objects are actually freed.  Returns a
    zero-argument function that removes the hook.
    """
    enabled = True

    def disable():
        # when GC runs during exit, things like `sys` will already be unloaded
        # so we have to disable the callback to avoid hitting errors.
        nonlocal enabled
        enabled = False
    atexit.register(disable)

    def gc_callback(phase, info):
        nonlocal enabled
        if not enabled:
            return
        if phase == "start":
            # DEBUG_SAVEALL makes the collector move unreachable objects into
            # gc.garbage instead of freeing them, so we can inspect them below.
            gc.set_debug(gc.DEBUG_SAVEALL)
        elif phase == "stop":
            # Defer the real work until after the collector has finished: a
            # profile function fires on the next function call/return event.
            orig_trace = sys.getprofile()
            self_return = [False]

            def do_collect(*args, **kwargs):
                nonlocal enabled
                if not self_return[0]:
                    # The first event is our own return from this machinery;
                    # skip it and act on the following one.
                    self_return[0] = True
                else:
                    sys.setprofile(orig_trace)
                    # Guard against re-entering this callback while we force
                    # collections below.
                    enabled = False
                    try:
                        # things in gc.garbage have survived a collection
                        # so to free them we have to collect a generation greater than them
                        # but that might _also_ free other stuff and we don't want to miss
                        # that stuff. So we have to now force gc at the highest level here,
                        # report all of what we found, _then_ we can free it up.
                        if info['generation'] != 2:
                            gc.collect()
                        observer(gc.garbage)
                        gc.garbage.clear()
                        # we have to re-run GC to clean up the cycles
                        # we saved from before.
                        gc.set_debug(0)
                        before = torch.cuda.memory_allocated()
                        gc.collect()
                        after = torch.cuda.memory_allocated()
                        if before != after:
                            logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after)
                    finally:
                        enabled = True
                if orig_trace is not None:
                    return orig_trace(*args, **kwargs)
            sys.setprofile(do_collect)

    gc.callbacks.append(gc_callback)

    # provide a way to disarm the callback
    def remove():
        gc.callbacks.remove(gc_callback)
    return remove
71

72
# Function to visualize cycles adapated from refcycle:
73
# Copyright 2013 Mark Dickinson
74
#
75
# Licensed under the Apache License, Version 2.0 (the "License");
76
# you may not use this file except in compliance with the License.
77
# You may obtain a copy of the License at
78
#
79
#   http://www.apache.org/licenses/LICENSE-2.0
80
#
81
# Unless required by applicable law or agreed to in writing, software
82
# distributed under the License is distributed on an "AS IS" BASIS,
83
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84
# See the License for the specific language governing permissions and
85
# limitations under the License.
86

87
def _get_cell_type():
    """Return the concrete type of closure cells (equivalent to types.CellType)."""
    def f(x=None):
        return lambda: x
    return type(f().__closure__[0])

CellType = _get_cell_type()

def annotated_references(obj):
    """
    Return known information about references held by the given object.

    Returns a mapping from referents to lists of descriptions.  Note that there
    may be more than one edge leading to any particular referent; hence the
    need for a list.  Descriptions are currently strings.

    """
    # Maps id(referent) -> list of edge labels for that referent.
    references: Dict[int, List[str]] = {}

    def add_reference(name, obj):
        references.setdefault(id(obj), []).append(name)

    def add_attrs(*attrs):
        for attr in attrs:
            if hasattr(obj, attr):
                add_reference(attr, getattr(obj, attr))

    def add_cell_references():
        try:
            add_attrs("cell_contents")
        except ValueError:
            # if cell_contents is empty,
            # accessing it raises ValueError
            # in this case there is no object to
            # annotate
            pass

    def add_function_references():
        # BUG FIX: a missing comma previously concatenated "__doc__" and
        # "__qualname__" into the bogus attribute name "__doc____qualname__",
        # so neither attribute was ever annotated.
        add_attrs("__defaults__",
                  "__closure__",
                  "__globals__",
                  "__code__",
                  "__name__",
                  "__module__",
                  "__doc__",
                  "__qualname__",
                  "__annotations__",
                  "__kwdefaults__")

    def add_sequence_references():
        for position, item in enumerate(obj):
            add_reference(f"[{position}]", item)

    def add_dict_references():
        for key, value in obj.items():
            add_reference("key", key)
            add_reference(f"[{repr(key)}]", value)

    def add_set_references():
        for elt in obj:
            add_reference("element", elt)

    def add_bound_method_references():
        add_attrs("__self__", "__func__", "im_class")

    def add_weakref_references():
        # For subclasses of weakref, we can't reliably distinguish the
        # callback (if any) from other attributes.
        if type(obj) is weakref.ref:
            referents = gc.get_referents(obj)
            if len(referents) == 1:
                target = referents[0]
                add_reference("__callback__", target)

    def add_frame_references():
        f_locals = obj.f_locals
        add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals")
        # Some badly-behaved code replaces the f_locals dict with
        # something that doesn't support the full dict interface.  So we
        # only continue with the annotation if f_locals is a Python dict.
        if type(f_locals) is dict:
            # Iterate the snapshot fetched above rather than re-reading
            # obj.f_locals (which builds a fresh dict on every access).
            for name, local in f_locals.items():
                add_reference(f"local {name}", local)

    def add_getset_descriptor_references():
        add_attrs("__objclass__", "__name__", "__doc__")

    # Dispatch table: the annotator to run for each known type.
    type_based_references = {
        tuple: add_sequence_references,
        list: add_sequence_references,
        dict: add_dict_references,
        set: add_set_references,
        frozenset: add_set_references,
        types.FunctionType: add_function_references,
        types.FrameType: add_frame_references,
        CellType: add_cell_references,
        types.MethodType: add_bound_method_references,
        weakref.ref: add_weakref_references,
        types.GetSetDescriptorType: add_getset_descriptor_references,
    }

    # Run every annotator whose type appears in the object's MRO, so
    # subclasses also get their base types' annotations.
    for type_ in type(obj).__mro__:
        if type_ in type_based_references:
            type_based_references[type_]()

    add_attrs("__dict__", "__class__")
    if isinstance(obj, type):
        add_attrs("__mro__")

    return references
198

199
###############################################################################
200
# Object annotations.
201

202

203
# Types whose repr is short and informative enough to display directly.
BASE_TYPES = (int, float, complex, type(None), str, bytes)
# Frame filenames longer than this are truncated, keeping the tail.
FRAME_FILENAME_LIMIT = 32

def object_annotation(obj):
    """
    Return a string to be used for Graphviz nodes.

    The string should be short but as informative as possible.
    """

    def format_sequence(obj):
        # Show at most the first 8 elements; base types by repr, everything
        # else by type name, then a count of omitted elements.
        body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), obj))
        if len(obj) > 8:
            body = f'{body}, ...{len(obj) - 8}'
        return body

    # For basic types, use the repr.
    if isinstance(obj, BASE_TYPES):
        return repr(obj)
    if type(obj).__name__ == 'function':
        return f"function\n{obj.__name__}"
    elif isinstance(obj, types.MethodType):
        try:
            func_name = obj.__func__.__qualname__
        except AttributeError:
            func_name = "<anonymous>"
        return f"instancemethod\n{func_name}"
    elif isinstance(obj, list):
        return f"[{format_sequence(obj)}]"
    elif isinstance(obj, tuple):
        return f"({format_sequence(obj)})"
    elif isinstance(obj, dict):
        return f"dict[{len(obj)}]"
    elif isinstance(obj, types.ModuleType):
        return f"module\n{obj.__name__}"
    elif isinstance(obj, type):
        return f"type\n{obj.__name__}"
    elif isinstance(obj, weakref.ref):
        referent = obj()
        if referent is None:
            return "weakref (dead referent)"
        else:
            return f"weakref to id 0x{id(referent):x}"
    elif isinstance(obj, types.FrameType):
        # Truncate long paths from the left so the most specific part stays.
        filename = obj.f_code.co_filename
        if len(filename) > FRAME_FILENAME_LIMIT:
            filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):]
        # BUG FIX: previously formatted the literal "(unknown)" here, leaving
        # the truncated `filename` computed above completely unused.
        return f"frame\n{filename}:{obj.f_lineno}"
    else:
        return f"object\n{type(obj).__module__}.{type(obj).__name__}"
253

254

255

256
class Node(NamedTuple):
    """A node of the reference graph built by create_graph."""
    # Short human-readable description of the object (from object_annotation).
    label: str
    # Optional extra context (e.g. a CUDA allocation stack trace), or None.
    context: Optional[str]
    # Whether the object matched create_graph's `filter` predicate (a "root").
    root: bool
    # Outgoing edges as (edge label, index of the referent node).
    # NOTE: spelled "referrents" in the original API; kept for compatibility.
    referrents: List[Tuple[str, int]]
261

262
def create_graph(objects, *, context=None, filter=None):
    """Build a reference graph over ``objects``, pruned to nodes that reach a root.

    Args:
        objects: candidate objects (typically the contents of ``gc.garbage``).
        context: optional ``obj -> Optional[str]`` callable supplying per-object
            context; defaults to CUDA allocation stack traces.
        filter: optional predicate marking "root" objects; defaults to
            ``is_cuda_tensor``.

    Returns:
        A list of ``Node`` whose ``referrents`` edge indices are renumbered to
        index into the returned (filtered) list.
    """
    if context is None:
        context = cuda_allocation_context()
    if filter is None:
        filter = is_cuda_tensor

    nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects]
    # node_referrers[i] holds the indices of nodes that reference nodes[i].
    node_referrers: List[List[int]] = [[] for _ in objects]

    id_to_node = {id(obj): i for i, obj in enumerate(objects)}
    for obj in objects:
        fidx = id_to_node[id(obj)]
        f = nodes[fidx]
        references = annotated_references(obj)
        for referrent in gc.get_referents(obj):
            rid = id(referrent)
            tidx = id_to_node.get(rid)
            if tidx is None:
                # The referent is not among the tracked objects; skip it.
                continue
            labels = references.get(rid, ["?"])
            node_referrers[tidx].append(fidx)
            for label in labels:
                f.referrents.append((label, tidx))

    # Keep only nodes from which a root is reachable: walk backwards from the
    # roots along the referrer lists.
    to_search = [i for i, n in enumerate(nodes) if n.root]
    to_keep = set()
    while to_search:
        idx = to_search.pop()
        if idx in to_keep:
            continue
        to_keep.add(idx)
        to_search.extend(node_referrers[idx])

    # Compact the surviving nodes and renumber their outgoing edges.
    id_to_filtered_id: Dict[int, int] = {}
    filtered: List[Any] = []
    for i, n in enumerate(nodes):
        if i in to_keep:
            id_to_filtered_id[i] = len(id_to_filtered_id)
            filtered.append(n)
    for n in filtered:
        n.referrents[:] = [(label, id_to_filtered_id[idx])
                           for (label, idx) in n.referrents
                           if idx in id_to_filtered_id]
    return filtered
307

308
def escape(n):
    """Quote *n* as a JSON string literal (safe for DOT labels and JS)."""
    encoded = json.dumps(n)
    return encoded
310

311

312
def is_cuda_tensor(obj):
    """Return True for real CUDA tensors (fake tensors are excluded)."""
    if not isinstance(obj, torch.Tensor):
        return False
    return obj.is_cuda and not isinstance(obj, torch._subclasses.FakeTensor)
314

315
def cuda_allocation_context():
    """Return a callable mapping a CUDA tensor to its allocation stack trace.

    Takes a snapshot of the CUDA caching allocator and indexes the active
    allocations by address; the returned function looks up a tensor's storage
    pointer in that index and formats the recorded frames (or returns None).
    """
    snapshot = torch.cuda.memory._snapshot()
    frames_by_addr = {}
    for segment in snapshot['segments']:
        block_addr = segment['address']
        for block in segment['blocks']:
            if block['state'] == 'active_allocated':
                frames, _real_size = _block_extra(block)
                frames_by_addr[block_addr] = frames
            # Blocks are laid out back-to-back within the segment.
            block_addr += block['size']

    def object_context(obj):
        if not is_cuda_tensor(obj):
            return None
        frames = frames_by_addr.get(obj.untyped_storage().data_ptr())
        if frames is None:
            return None
        return '\n'.join(_frames_fmt(frames, full_filename=True))
    return object_context
334

335
def to_dot(nodes):
    """Render the node list as Graphviz DOT source.

    Node indices are used as DOT node names; roots are drawn in red, and
    edge/label text is JSON-quoted to form valid DOT string literals.
    """
    out = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;']
    for idx, node in enumerate(nodes):
        color = "red" if node.root else "black"
        out.append(f'{idx} [label={json.dumps(node.label)}, color={color}];')

    for src, node in enumerate(nodes):
        for label, dst in node.referrents:
            out.append(f'{src} -> {dst} [label = {json.dumps(label)}]')
    out.append("}\n")
    return '\n'.join(out)
345

346
_template = """
347
<!DOCTYPE html>
348
<html>
349
<head>
350
  <style>
351
    body {
352
      margin: 0;
353
      padding: 0;
354
      overflow: hidden;
355
    }
356

357
    #container {
358
      display: flex;
359
      flex-direction: column;
360
      height: 100vh;
361
    }
362

363
    #main {
364
      flex: 2;
365
      overflow: auto;
366
    }
367

368
    #preContainer {
369
      flex: 1;
370
      overflow: auto;
371
    }
372

373
    svg {
374
        overflow: scroll;
375
    }
376

377
    pre {
378
      margin: 0;
379
      padding: 10px;
380
    }
381
  </style>
382
</head>
383
<body>
384
  <div id="container">
385
    <div id="main">
386
    </div>
387
    <div id="preContainer">
388
      <pre id="stacktrace">Mouse over tensor objects to see where they were allocated.</pre>
389
    </div>
390
  </div>
391
<script src='https://cdnjs.cloudflare.com/ajax/libs/viz.js/1.8.0/viz-lite.js'></script>
392
<script>
393
let dot = $DOT
394
let image = Viz(dot, {format: 'svg'});
395
document.getElementById('main').innerHTML = image
396
$LISTENERS
397
</script>
398
</body>
399
</html>
400
"""
401
_listener_template = """
402
document.getElementById('node{id}').addEventListener('mouseover', function(event) {{
403
  document.getElementById("stacktrace").textContent = {stack}
404
}})
405
"""
406
def to_html(nodes):
    """Render the node graph as a standalone HTML page.

    Builds one mouseover listener per node that has context (showing its
    stack trace), then splices the DOT source and the listeners into the
    page template.
    """
    listeners = [
        _listener_template.format(id=str(idx + 1), stack=escape(f'{node.label}:\n{node.context}'))
        for idx, node in enumerate(nodes)
        if node.context is not None
    ]
    page = _template.replace('$DOT', repr(to_dot(nodes)))
    return page.replace('$LISTENERS', '\n'.join(listeners))
415

416
def observe_tensor_cycles(callback):
    """Watch garbage collection for cycles that contain CUDA tensors.

    Enables CUDA memory-history recording, then installs a GC observer that
    renders collected garbage containing CUDA tensors to HTML and passes the
    HTML to ``callback``.  Returns a function that removes the hook.
    """
    torch.cuda.memory._record_memory_history(max_entries=100000)

    def on_garbage(garbage):
        if not garbage:
            return
        if any(is_cuda_tensor(candidate) for candidate in garbage):
            callback(to_html(create_graph(garbage)))
        else:
            logger.info("No CUDA Tensors found in garbage")
    return observe_garbage(on_garbage)
426

427

428
def warn_tensor_cycles():
    """
    Install a warning that reports whenever a cycle that is holding CUDA memory is observed.

    The warning produces an .html file that visualizes the cycle,
    and links it to the stack frame that allocated the CUDA tensor.

    Reference cycles are freed by the cycle collector rather than being cleaned up
    when the objects in the cycle first become unreachable. If a cycle points to a tensor,
    the CUDA memory for that tensor will not be freed until garbage collection runs.
    Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as
    non-deterministic allocation behavior which is harder to debug.
    """
    logger.info("Watching Python reference cycles for CUDA Tensors.")

    def write_and_log(html):
        # Persist the visualization (delete=False so it survives the process)
        # and point the user at it with a warning.
        with NamedTemporaryFile('w', suffix='.html', delete=False) as report:
            report.write(html)
            path = report.name
        logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', path)
    return observe_tensor_cycles(write_and_log)
448

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.