pytorch
1# mypy: allow-untyped-defs
2"""
3APIs related to torch.compile which lazily import torch._dynamo to avoid
4circular dependencies.
5"""
6
7import functools
8
9
def _disable_dynamo(fn=None, recursive=True):
    """
    Lazily-importing variant of ``torch._dynamo.disable`` for use inside torch.

    External users should keep using ``torch._dynamo.disable``. This helper
    exists because calling ``_dynamo.disable`` from within torch itself
    commonly causes circular-import problems. It sidesteps them by moving the
    ``import torch._dynamo`` from import time to the first invocation of the
    decorated function.

    Works both as ``@_disable_dynamo`` and as
    ``@_disable_dynamo(recursive=False)``.

    Args:
        fn: The function to wrap. When ``None``, a decorator is returned that
            expects the original function as its argument.
        recursive: Forwarded as the second positional argument to
            ``torch._dynamo.disable``.

    Returns:
        A ``functools.wraps``-decorated wrapper around ``fn``. When ``fn`` is
        ``None``, returns a ``functools.partial`` of this function with
        ``recursive`` bound.
    """
    if fn is None:
        # Called with keyword arguments only, e.g.
        # ``@_disable_dynamo(recursive=False)``: hand back a decorator that
        # will receive the function to wrap on the next call.
        return functools.partial(_disable_dynamo, recursive=recursive)

    @functools.wraps(fn)
    def inner(*args, **kwargs):
        # Build the disabled callable once and memoize it on ``fn``, so later
        # calls only pay for an attribute lookup.
        # NOTE(review): the cache key is only ``fn``; ``recursive`` is not part
        # of it, so two wrappers of the same ``fn`` with different
        # ``recursive`` values would share whichever was built first. Confirm
        # this is acceptable.
        disabled = getattr(fn, "__dynamo_disable", None)
        if disabled is None:
            # Deferred import: this is what breaks the circular dependency.
            import torch._dynamo

            disabled = torch._dynamo.disable(fn, recursive)
            fn.__dynamo_disable = disabled
        return disabled(*args, **kwargs)

    return inner
39