3
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
7
from tempfile import NamedTemporaryFile
9
from torch.cuda._memory_viz import _frames_fmt, _block_extra
12
logger = logging.getLogger(__name__)
14
# NOTE(review): this region is a garbled extraction — indentation is lost, bare
# numeric lines (the original file's line numbers) are interleaved, and several
# statements are missing (e.g. the body of the `disable` helper registered with
# atexit, the initialization of `self_return`, the phase dispatch in
# gc_callback, and the function returned to the caller that removes
# gc_callback). Recover the full definition from version control before editing.
#
# Visible intent: install a hook in gc.callbacks that saves collected garbage
# (gc.DEBUG_SAVEALL), then uses a sys.setprofile trampoline (`do_collect`) to
# run after the collection finishes, force a full collection when the observed
# one was not generation 2, report the contents of gc.garbage to `observer`,
# and log a warning when torch.cuda.memory_allocated() dropped across the
# cleanup collection.
def observe_garbage(observer):
18
# when GC runs during exit, things like `sys` will already be unloaded
19
# so we have to disable the callback to avoid hitting errors.
22
atexit.register(disable)
24
def gc_callback(phase, info):
29
gc.set_debug(gc.DEBUG_SAVEALL)
31
orig_trace = sys.getprofile()
34
def do_collect(*args, **kwargs):
36
if not self_return[0]:
39
sys.setprofile(orig_trace)
42
# things in gc.garbage have survived a collection
43
# so to free them we have to collect a generation greater than them
44
# but that might _also_ free other stuff and we don't want to miss
45
# that stuff. So we have to now force gc at the highest level here,
46
# report all of what we found, _then_ we can free it up.
47
if info['generation'] != 2:
51
# we have to re-run GC to clean up the cycles
52
# we saved from before.
54
before = torch.cuda.memory_allocated()
56
after = torch.cuda.memory_allocated()
58
logger.warning("CUDA Memory changed during GC, %d bytes freed.", before - after)
61
if orig_trace is not None:
62
return orig_trace(*args, **kwargs)
63
sys.setprofile(do_collect)
65
gc.callbacks.append(gc_callback)
67
# provide a way to disarm the callback
69
gc.callbacks.remove(gc_callback)
72
# Function to visualize cycles adapated from refcycle:
73
# Copyright 2013 Mark Dickinson
75
# Licensed under the Apache License, Version 2.0 (the "License");
76
# you may not use this file except in compliance with the License.
77
# You may obtain a copy of the License at
79
# http://www.apache.org/licenses/LICENSE-2.0
81
# Unless required by applicable law or agreed to in writing, software
82
# distributed under the License is distributed on an "AS IS" BASIS,
83
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
84
# See the License for the specific language governing permissions and
85
# limitations under the License.
90
# NOTE(review): the enclosing `def` line (and the local helper `f` whose
# closure cell is inspected) are missing from this extraction; only the return
# statement survives. `CellType` is the runtime type of a closure cell and is
# used below as a key in `annotated_references`' type dispatch table.
return type(f().__closure__[0])
92
CellType = _get_cell_type()
94
# NOTE(review): garbled extraction — indentation lost, original line numbers
# interleaved, and statements elided (the docstring quotes, the `for attr in
# attrs:` loop header in add_attrs, the remaining attribute names in
# add_function_references, the `for elt in obj:` loop in add_set_references,
# the closing brace of type_based_references, and the final `return
# references`). Do not edit until the full text is recovered.
#
# Visible intent: build a mapping {id(referent) -> [edge descriptions]} for the
# references held by `obj`, dispatching on the types in obj's MRO via the
# type_based_references table, then always annotating __dict__/__class__.
def annotated_references(obj):
96
Return known information about references held by the given object.
98
Returns a mapping from referents to lists of descriptions. Note that there
99
may be more than one edge leading to any particular referent; hence the
100
need for a list. Descriptions are currently strings.
103
references: Dict[int, List[str]] = {}
105
def add_reference(name, obj):
106
references.setdefault(id(obj), []).append(name)
108
def add_attrs(*attrs):
110
if hasattr(obj, attr):
111
add_reference(attr, getattr(obj, attr))
113
def add_cell_references():
115
add_attrs("cell_contents")
117
# if cell_contents is empty,
118
# accessing it raises ValueError
119
# in this case there is no object to
123
def add_function_references():
124
add_attrs("__defaults__",
136
def add_sequence_references():
137
for position, item in enumerate(obj):
138
add_reference(f"[{position}]", item)
140
def add_dict_references():
141
for key, value in obj.items():
142
add_reference("key", key)
143
add_reference(f"[{repr(key)}]", value)
145
def add_set_references():
147
add_reference("element", elt)
149
def add_bound_method_references():
150
add_attrs("__self__", "__func__", "im_class")
152
def add_weakref_references():
153
# For subclasses of weakref, we can't reliably distinguish the
154
# callback (if any) from other attributes.
155
if type(obj) is weakref.ref:
156
referents = gc.get_referents(obj)
157
if len(referents) == 1:
158
target = referents[0]
159
add_reference("__callback__", target)
162
def add_frame_references():
163
f_locals = obj.f_locals
164
add_attrs("f_back", "f_code", "f_builtins", "f_globals", "f_trace", "f_locals")
165
# Some badly-behaved code replaces the f_locals dict with
166
# something that doesn't support the full dict interface. So we
167
# only continue with the annotation if f_locals is a Python dict.
168
if type(f_locals) is dict:
169
for name, local in obj.f_locals.items():
170
add_reference(f"local {name}", local)
172
def add_getset_descriptor_references():
173
add_attrs("__objclass__", "__name__", "__doc__")
175
type_based_references = {
176
tuple: add_sequence_references,
177
list: add_sequence_references,
178
dict: add_dict_references,
179
set: add_set_references,
180
frozenset: add_set_references,
181
types.FunctionType: add_function_references,
182
types.FrameType: add_frame_references,
183
CellType: add_cell_references,
184
types.MethodType: add_bound_method_references,
185
weakref.ref: add_weakref_references,
186
types.GetSetDescriptorType: add_getset_descriptor_references,
189
for type_ in type(obj).__mro__:
190
if type_ in type_based_references:
191
type_based_references[type_]()
193
add_attrs("__dict__", "__class__")
194
if isinstance(obj, type):
199
###############################################################################
203
# Types labeled by their repr() rather than by type name.
BASE_TYPES = (int, float, complex, type(None), str, bytes)
# Longest frame filename shown in a label before left-truncation with "...".
FRAME_FILENAME_LIMIT = 32

def object_annotation(obj):
    """Return a string to be used for Graphviz nodes.

    The string should be short but as informative as possible.
    """

    def format_sequence(seq):
        # Show at most the first 8 elements: base types by repr, everything
        # else by type name, followed by a count of the omitted elements.
        body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), seq))
        if len(seq) > 8:
            body = f'{body}, ...{len(seq) - 8}'
        return body

    # For basic types, use the repr.
    if isinstance(obj, BASE_TYPES):
        return repr(obj)
    if type(obj).__name__ == 'function':
        return f"function\n{obj.__name__}"
    elif isinstance(obj, types.MethodType):
        try:
            func_name = obj.__func__.__qualname__
        except AttributeError:
            func_name = "<anonymous>"
        return f"instancemethod\n{func_name}"
    elif isinstance(obj, list):
        return f"[{format_sequence(obj)}]"
    elif isinstance(obj, tuple):
        return f"({format_sequence(obj)})"
    elif isinstance(obj, dict):
        return f"dict[{len(obj)}]"
    elif isinstance(obj, types.ModuleType):
        return f"module\n{obj.__name__}"
    elif isinstance(obj, type):
        return f"type\n{obj.__name__}"
    elif isinstance(obj, weakref.ref):
        referent = obj()
        if referent is None:
            return "weakref (dead referent)"
        return f"weakref to id 0x{id(referent):x}"
    elif isinstance(obj, types.FrameType):
        filename = obj.f_code.co_filename
        if len(filename) > FRAME_FILENAME_LIMIT:
            filename = "..." + filename[-(FRAME_FILENAME_LIMIT - 3):]
        # BUG FIX: `filename` was computed (and truncated) above but never
        # used — the label printed the placeholder "(unknown)" instead.
        return f"frame\n{filename}:{obj.f_lineno}"
    return f"object\n{type(obj).__module__}.{type(obj).__name__}"
256
# NOTE(review): extraction is missing fields of this NamedTuple — later code
# constructs Node(...) with four positional values and reads `n.label` and
# `n.root`, so at least a `label` and a `root` field have been elided here.
# (`referrents` is also a misspelling of "referents", but it is the runtime
# attribute name and cannot be renamed without touching every use site.)
class Node(NamedTuple):
258
context: Optional[str]
260
referrents: List[Tuple[str, int]]
262
# NOTE(review): garbled extraction — indentation lost, original line numbers
# interleaved, statements elided (the `if context is None:` / `if filter is
# None:` guards implied by the keyword-only defaults, the outer loop binding
# `obj`/`f`, the loop binding `rid`/`label`, the visited-set logic of the
# root-reachability walk, the code that appends to `filtered`, and the final
# return). Recover the full text before editing.
#
# Visible intent: build a Node per object (label, allocation context, root
# flag, outgoing edges), wire edges from gc.get_referents + annotated
# reference labels, walk backwards from root (filter-matching) nodes through
# node_referrers to keep only nodes that reach a root, then renumber the
# surviving nodes' edges via id_to_filtered_id.
def create_graph(objects, *, context=None, filter=None):
264
context = cuda_allocation_context()
266
filter = is_cuda_tensor
268
nodes = [Node(object_annotation(obj), context(obj), filter(obj), []) for obj in objects]
269
node_referrers: List[List[int]] = [[] for obj in objects]
271
id_to_node = {id(obj): i for i, obj in enumerate(objects)}
273
fidx = id_to_node[id(obj)]
275
references = annotated_references(obj)
276
for referrent in gc.get_referents(obj):
278
tidx = id_to_node.get(rid, None)
282
labels = references.get(rid, ["?"])
283
node_referrers[tidx].append(fidx)
285
f.referrents.append((label, tidx))
287
to_search = [i for i, n in enumerate(nodes) if n.root]
290
idx = to_search.pop()
294
referrers = node_referrers[idx]
295
to_search.extend(referrers)
296
id_to_filtered_id: Dict[int, int] = {}
297
filtered: List[Any] = []
298
for i, n in enumerate(nodes):
300
id_to_filtered_id[i] = len(id_to_filtered_id)
303
n.referrents[:] = [(label, id_to_filtered_id[idx])
304
for (label, idx) in n.referrents
305
if idx in id_to_filtered_id]
312
def is_cuda_tensor(obj):
    """Return True only for genuine CUDA-resident torch.Tensors.

    FakeTensor instances (torch._subclasses) are explicitly excluded even
    though they subclass torch.Tensor.
    """
    if not isinstance(obj, torch.Tensor):
        return False
    if isinstance(obj, torch._subclasses.FakeTensor):
        return False
    return obj.is_cuda
315
def cuda_allocation_context():
    """Return a function mapping an object to its CUDA allocation stack (or None).

    Reads the CUDA caching-allocator snapshot once, builds a table from block
    start address to the recorded allocation frames, and returns a closure
    that resolves a CUDA tensor's storage data pointer against that table.
    """
    snapshot = torch.cuda.memory._snapshot()
    # BUG FIX: addr_to_frame was used without ever being initialized.
    addr_to_frame = {}
    for seg in snapshot['segments']:
        addr = seg['address']
        for blk in seg['blocks']:
            if blk['state'] == 'active_allocated':
                frames, _real_size = _block_extra(blk)
                addr_to_frame[addr] = frames
            # BUG FIX: advance past every block (regardless of state) —
            # snapshot blocks are recorded back-to-back from the segment base
            # address, so without this every block mapped to seg['address'].
            addr += blk['size']

    def object_context(obj):
        # Only CUDA tensors can be resolved against the allocator snapshot;
        # anything else (and unmatched addresses) yields None.
        if is_cuda_tensor(obj):
            addr = obj.untyped_storage().data_ptr()
            frames = addr_to_frame.get(addr)
            if frames is not None:
                return '\n'.join(_frames_fmt(frames, full_filename=True))
        return None

    return object_context
336
# NOTE(review): the enclosing function's `def` line is missing from this
# extraction (presumably a `def to_dot(nodes):`-style renderer), as is the
# append of the closing "}" before the return. The visible code emits a
# Graphviz digraph: one rect node per Node (red when n.root, else black) and
# one labeled edge per (label, target-index) entry in each node's referrents.
lines = ["digraph GraphName {", "node [shape=rect];", 'rankdir=LR;']
337
for i, n in enumerate(nodes):
338
lines.append(f'{i} [label={escape(n.label)}, color={ "red" if n.root else "black"}];')
340
for i, f in enumerate(nodes):
341
for label, j in f.referrents:
342
lines.append(f'{i} -> {j} [label = {escape(label)}]')
344
return '\n'.join(lines)
359
flex-direction: column;
387
<div id="preContainer">
388
<pre id="stacktrace">Mouse over tensor objects to see where they were allocated.</pre>
391
<script src='https://cdnjs.cloudflare.com/ajax/libs/viz.js/1.8.0/viz-lite.js'></script>
394
let image = Viz(dot, {format: 'svg'});
395
document.getElementById('main').innerHTML = image
401
_listener_template = """
402
document.getElementById('node{id}').addEventListener('mouseover', function(event) {{
403
document.getElementById("stacktrace").textContent = {stack}
408
# NOTE(review): the enclosing function's `def` line and the initialization of
# `listeners` (plus the append of `s` to it and the computation of `dot`) are
# missing from this extraction. Visible intent: for every node that has an
# allocation context, instantiate _listener_template (a mouseover handler that
# shows "label:\ncontext" in the #stacktrace element — node ids appear to be
# 1-based in the rendered SVG, hence i + 1), then splice the dot source and
# the listeners into the $DOT / $LISTENERS placeholders of _template.
for i, n in enumerate(nodes):
409
if n.context is None:
411
s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}'))
414
return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners))
416
# NOTE(review): extraction has elided the branch structure inside `observer`
# (the logger.info call and the callback call cannot both run unconditionally;
# an `if`/`else` on whether any CUDA tensor is present appears to be missing).
#
# Visible intent: enable CUDA memory-history recording so allocation stacks
# are available, then register an observer of collected garbage that either
# logs that no CUDA tensors were found or renders the cycle graph to HTML and
# hands it to `callback`. Returns observe_garbage's disarm handle.
def observe_tensor_cycles(callback):
417
torch.cuda.memory._record_memory_history(max_entries=100000)
419
def observer(garbage):
421
if not any(is_cuda_tensor(obj) for obj in garbage):
422
logger.info("No CUDA Tensors found in garbage")
424
callback(to_html(create_graph(garbage)))
425
return observe_garbage(observer)
428
# NOTE(review): extraction has elided the docstring's triple-quote delimiters
# and the body of the `with` block (the write of `html` into the temp file
# must precede the logger.warning that publishes f.name). The temp file is
# created with delete=False deliberately, so the .html survives for the user
# to open after the process moves on.
def warn_tensor_cycles():
430
Install a warning that reports whenever a cycle that is holding CUDA memory is observed.
432
The warning produces an .html file that visualizes the cycle,
433
and links it to the stack frame that allocated the CUDA tensor.
435
Reference cycles are freed by the cycle collector rather than being cleaned up
436
when the objects in the cycle first become unreachable. If a cycle points to a tensor,
437
the CUDA memory for that tensor will not be freed until garbage collection runs.
438
Accumulation of CUDA allocations can lead to out of memory errors (OOMs), as well as
439
non-deterministic allocation behavior which is harder to debug.
441
logger.info("Watching Python reference cycles for CUDA Tensors.")
443
def write_and_log(html):
444
with NamedTemporaryFile('w', suffix='.html', delete=False) as f:
446
logger.warning('Reference cycle includes a CUDA Tensor see visualization of cycle %s', f.name)
447
return observe_tensor_cycles(write_and_log)