# mypy: allow-untyped-defs
# Allows one to expose an API defined in a private submodule publicly, as per the
# definition in PyTorch's public API policy.
#
# This is a temporary solution while we figure out whether it should become the
# long-term solution or whether we should amend PyTorch's public API policy instead.
# The concern is that this approach may not be very robust, because it is not clear
# what __module__ is used for. However, both NumPy and JAX overwrite the __module__
# attribute of their APIs without problems, so it seems fine.
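def exposed_in(module):
    # Returns a decorator that overwrites fn.__module__ with `module`, so the API
    # reports the public module instead of the private submodule that defines it.
    def wrapper(fn):
        fn.__module__ = module
        return fn

    return wrapper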
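A minimal usage sketch of the decorator above, assuming the hypothetical module names `mypkg.public` (the advertised public location) and the helper `public_path` (introduced only for illustration). It shows that the decorator returns the same object with only `__module__` changed, and that this works for classes as well as functions.

```python
# Sketch only: "mypkg.public" and public_path() are hypothetical, for illustration.


def exposed_in(module):
    # Same logic as the decorator above: overwrite __module__, return fn unchanged.
    def wrapper(fn):
        fn.__module__ = module
        return fn

    return wrapper


def public_path(obj):
    # Fully qualified name as tools that read __module__ would report it.
    return f"{obj.__module__}.{obj.__qualname__}"


@exposed_in("mypkg.public")
def add_one(x):
    return x + 1


@exposed_in("mypkg.public")
class Config:
    pass


print(public_path(add_one))  # mypkg.public.add_one
print(public_path(Config))   # mypkg.public.Config
print(add_one(41))           # 42
```

Because the object is not wrapped, its behavior is unaffected: it still resolves globals through the module where it was defined. One place where the new value does matter is `pickle`, which locates functions and classes by `__module__` and `__qualname__`, so the advertised module must actually export the object for pickling to succeed.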