import torch
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
from typing import List, Optional, Union

__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]

# Default name of the PrivateUse1 dispatch key; updated by rename_privateuse1_backend().
_privateuse1_backend_name = "privateuseone"


def rename_privateuse1_backend(backend_name: str) -> None:
    r"""
    Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.

    The steps are:

    (1) (In C++) implement kernels for various torch operations, and register them
        to the PrivateUse1 dispatch key.
    (2) (In python) call torch.utils.rename_privateuse1_backend("foo")

    You can now use "foo" as an ordinary device string in python.

    Note: this API can only be called once per process. Attempting to change
    the external backend after it's already been set will result in an error.

    Note(AMP): If you want to support AMP on your device, you need to register a custom backend module
    with ``torch._register_device_module("foo", BackendModule)``.
    BackendModule needs to have the following API's (a minimal sketch follows this list):

    (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
        Return the dtypes supported by AMP on your "foo" device (there may be more than one).

    (2) ``is_autocast_enabled() -> bool``
        Check whether AMP is enabled on your "foo" device.

    (3) ``get_autocast_dtype() -> torch.dtype``
        Return the current AMP dtype on your "foo" device, which is set by ``set_autocast_dtype``
        and defaults to ``torch.float16``.

    (4) ``set_autocast_enabled(bool) -> None``
        Enable or disable AMP on your "foo" device.

    (5) ``set_autocast_dtype(dtype) -> None``
        Set the AMP dtype on your "foo" device; the dtype must be one of the dtypes returned
        by ``get_amp_supported_dtype``.
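
    A minimal sketch of the AMP-related part of such a ``BackendModule`` (the class body below is
    illustrative only, not a required implementation)::

        class BackendModule:
            _autocast_enabled = False
            _autocast_dtype = torch.float16

            @staticmethod
            def get_amp_supported_dtype() -> List[torch.dtype]:
                return [torch.float16, torch.bfloat16]

            @staticmethod
            def is_autocast_enabled() -> bool:
                return BackendModule._autocast_enabled

            @staticmethod
            def get_autocast_dtype() -> torch.dtype:
                return BackendModule._autocast_dtype

            @staticmethod
            def set_autocast_enabled(enable: bool) -> None:
                BackendModule._autocast_enabled = enable

            @staticmethod
            def set_autocast_dtype(dtype: torch.dtype) -> None:
                BackendModule._autocast_dtype = dtype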

    Note(random): If you want to support setting seeds for your device, BackendModule needs to have the
    following API's (see the sketch after the common functions below):

    (1) ``_is_in_bad_fork() -> bool``
        Return ``True`` if the process is currently in a bad fork state, else return ``False``.

    (2) ``manual_seed_all(seed: int) -> None``
        Sets the seed for generating random numbers on all of your "foo" devices.

    (3) ``device_count() -> int``
        Returns the number of "foo"s available.

    (4) ``get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor``
        Returns a ByteTensor representing the random number state of the specified "foo" device.

    (5) ``set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None``
        Sets the random number generator state of the specified "foo" device.

    And there are some common funcs:

    (1) ``is_available() -> bool``
        Returns a bool indicating if "foo" is currently available.

    (2) ``current_device() -> int``
        Returns the index of the currently selected device.
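
    In the same spirit, the RNG hooks and common functions could look like this (the bodies are
    illustrative placeholders)::

        class BackendModule:
            @staticmethod
            def _is_in_bad_fork() -> bool:
                return False

            @staticmethod
            def manual_seed_all(seed: int) -> None:
                ...  # forward the seed to every "foo" device

            @staticmethod
            def device_count() -> int:
                return 1

            @staticmethod
            def get_rng_state(device='foo') -> torch.Tensor:
                ...  # return the RNG state of the given device as a ByteTensor

            @staticmethod
            def set_rng_state(new_state: torch.Tensor, device='foo') -> None:
                ...  # restore the RNG state on the given device

            @staticmethod
            def is_available() -> bool:
                return True

            @staticmethod
            def current_device() -> int:
                return 0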

    For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
    For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example

    Example::

        >>> # xdoctest: +SKIP("failing")
        >>> torch.utils.rename_privateuse1_backend("foo")
        # This will work, assuming that you've implemented the right C++ kernels
        # to implement torch.ones.
        >>> a = torch.ones(2, device="foo")
    """
    _rename_privateuse1_backend(backend_name)
    global _privateuse1_backend_name
    _privateuse1_backend_name = backend_name


def _check_register_once(module, attr):
    if hasattr(module, attr):
        raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}")


def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int:
    def _get_current_device_index():
        _get_device_index = "current_device"
        if hasattr(torch, custom_backend_name) and \
                hasattr(getattr(torch, custom_backend_name), _get_device_index):
            return getattr(getattr(torch, custom_backend_name), _get_device_index)()
        # If the backend module does not expose current_device, fall back to index 0.
        return 0

    if device is None:
        return _get_current_device_index()
    # A string such as "foo" or "foo:0" is normalized through torch.device.
    elif isinstance(device, str):
        device = torch.device(device)

    # At this point device is either a torch.device or an int index.
    if isinstance(device, torch.device):
        if device.type != custom_backend_name:
            raise RuntimeError(f"Invalid device, must be {custom_backend_name} device")
        elif device.index is None:
            device_idx = _get_current_device_index()
        else:
            device_idx = device.index
    else:
        device_idx = device
    return device_idx


def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    @property  # type: ignore[misc]
    def wrap_tensor_backend(self: torch.Tensor) -> bool:
        return self.device.type == custom_backend_name

    _check_register_once(torch.Tensor, f'is_{custom_backend_name}')
    setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend)

    def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False,
                       **kwargs) -> torch.Tensor:
        r"""Perform Tensor device conversion. Call the to operator implementation.

        .. note::
            If the ``self`` Tensor already has the correct :class:`torch.device`, then ``self`` is returned.
            Otherwise, the returned tensor is a copy of ``self`` with the desired :class:`torch.device`.

        Args:
            device (int, optional): if specified, the Tensor will be copied to that device
            non_blocking (bool): If ``True`` and the source is in pinned memory,
                the copy will be asynchronous with respect to the host. Otherwise,
                the argument has no effect.
            **kwargs (dict): For compatibility, may contain the key ``memory_format`` argument.
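
        Example (assuming the privateuse1 backend has been renamed to "foo"; the name is illustrative)::

            >>> # xdoctest: +SKIP
            >>> x = torch.randn(2)
            >>> y = x.foo()                      # copy ``x`` to the current "foo" device
            >>> y = x.foo(0, non_blocking=True)  # copy ``x`` to "foo:0", asynchronously if ``x`` is pinned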
        """
        device_idx = _normalization_device(custom_backend_name, device)
        return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)

    _check_register_once(torch.Tensor, custom_backend_name)
    setattr(torch.Tensor, custom_backend_name, wrap_tensor_to)


def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    # The generated torch.nn.Module method depends on the torch.Tensor method,
    # so check that the Tensor method has been registered first.
    if not hasattr(torch.Tensor, custom_backend_name):
        raise RuntimeError(
            f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module. "
            f"Because torch.Tensor doesn't have the method {custom_backend_name}(). "
            f"For this error, you can try setting for_tensor=True.")

    def wrap_module_to(self: torch.nn.modules.module.T,
                       device: Optional[Union[int, torch.device]] = None) -> torch.nn.modules.module.T:
        r"""Move all model parameters and buffers to the custom device.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on the device while being optimized.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be copied to that device
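
        Example (assuming the backend has been renamed to "foo"; the name is illustrative)::

            >>> # xdoctest: +SKIP
            >>> module = torch.nn.Linear(4, 4).foo(0)  # move parameters and buffers to "foo:0"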
        """
        return self._apply(lambda t: getattr(t, custom_backend_name)(device))

    _check_register_once(torch.nn.Module, custom_backend_name)
    setattr(torch.nn.Module, custom_backend_name, wrap_module_to)


def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
                                                      unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
    # The attribute is registered on the _StorageBase class,
    # and UntypedStorage obtains it through inheritance.
    @property  # type: ignore[misc]
    def wrap_storage_backend(self: torch.storage._StorageBase) -> bool:
        r"""Return whether the storage is on the custom backend device."""
        return self.device.type == custom_backend_name

    _check_register_once(torch.storage._StorageBase, f'is_{custom_backend_name}')
    setattr(torch.storage._StorageBase, f'is_{custom_backend_name}', wrap_storage_backend)

    def wrap_storage_to(self, device=None, non_blocking=False):
        r"""Return a copy of this object in custom device memory.

        If this object is already in device memory and on the correct device, then
        no copy is performed and the original object is returned.

        Args:
            device (int): The destination device id. Defaults to the current device.
            non_blocking (bool): If ``True`` and the source is in pinned memory,
                the copy will be asynchronous with respect to the host. Otherwise,
                the argument has no effect.
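
        Example (assuming storage methods were generated for a backend renamed to "foo";
        the name is illustrative)::

            >>> # xdoctest: +SKIP
            >>> s = torch.empty(4).untyped_storage().foo(0)  # copy the storage to "foo:0"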
        """
        device_idx = _normalization_device(custom_backend_name, device)

        if getattr(self, f'is_{custom_backend_name}'):
            # The storage is already on the expected device.
            if self.get_device() == device_idx:
                return self
        # Sparse storages are not supported; backends need to extend this themselves.
        if self.is_sparse:
            raise RuntimeError(f"Can not support a sparse storage move to {custom_backend_name} backend")
        # Create an untyped storage on the target device and copy the data over.
        untyped_storage = torch.UntypedStorage(
            self.size(), device=torch.device(f'{custom_backend_name}:{device_idx}')
        )
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage

    _check_register_once(torch.storage._StorageBase, custom_backend_name)
    setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to)

    # Register the corresponding attributes on the TypedStorage class as well.
    # When the TypedStorage class is removed, this registration goes away with it.
    @property  # type: ignore[misc]
    def wrap_typed_storage_backend(self: torch.storage.TypedStorage) -> bool:
        torch.storage._warn_typed_storage_removal()
        return self._untyped_storage.device.type == custom_backend_name

    _check_register_once(torch.TypedStorage, f'is_{custom_backend_name}')
    setattr(torch.storage.TypedStorage, f'is_{custom_backend_name}', wrap_typed_storage_backend)

    def wrap_typed_storage_to(self: torch.storage.TypedStorage,
                              device=None, non_blocking=False, **kwargs) -> torch.storage.TypedStorage:
        torch.storage._warn_typed_storage_removal()
        if unsupported_dtype and self.dtype in unsupported_dtype:
            raise RuntimeError(f"Cannot create {custom_backend_name} storage "
                               f"as {self.dtype} dtype is not supported by this backend")
        custom_backend_storage: torch.UntypedStorage = getattr(
            self._untyped_storage, custom_backend_name)(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(custom_backend_storage)

    _check_register_once(torch.TypedStorage, custom_backend_name)
    setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to)


def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
                                             for_storage: bool = False,
                                             unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
    r"""
    Automatically generate attributes and methods for the custom backend after renaming the privateuse1 backend.
    In the default scenario, storage-related methods will not be generated automatically.

    Once you have implemented kernels for various torch operations and registered them to the PrivateUse1
    dispatch key, and have called torch.utils.rename_privateuse1_backend("foo") to rename your backend,
    you can easily register backend-specific methods and attributes by calling this function,
    such as torch.Tensor.foo(), torch.Tensor.is_foo, torch.Storage.foo() and torch.Storage.is_foo.

    Note: We recommend you use generic functions (check devices are equal or to(device=)).
    We provide these methods for convenience only and they will be "monkey patched" onto the objects
    and so will not be properly typed. For the generated Storage methods, if you need to support
    sparse data storage, you need to extend the implementation yourself.

    Args:
        for_tensor (bool): whether to register related methods for the torch.Tensor class.
        for_module (bool): whether to register related methods for the torch.nn.Module class.
        for_storage (bool): whether to register related methods for the torch.Storage class.
        unsupported_dtype (List[torch.dtype]): takes effect only when storage methods are generated,
            indicating the dtypes that the storage does not support.

    Example::

        >>> # xdoctest: +SKIP("failing")
        >>> torch.utils.rename_privateuse1_backend("foo")
        >>> torch.utils.generate_methods_for_privateuse1_backend()
        # Then automatically generate backend-related attributes and methods.
        >>> a = torch.tensor(2).foo()
        >>> hasattr(torch.nn.Module, 'foo')
    """
    custom_backend_name = _get_privateuse1_backend_name()

    if for_tensor:
        _generate_tensor_methods_for_privateuse1_backend(custom_backend_name)

    if for_module:
        _generate_module_methods_for_privateuse1_backend(custom_backend_name)

    if for_storage:
        _generate_storage_methods_for_privateuse1_backend(custom_backend_name, unsupported_dtype)


def _get_custom_mod_func(func_name: str):
    r"""
    Return the function named `func_name` defined in the custom device module. The module is
    registered with `torch.utils.rename_privateuse1_backend('foo')` and
    `torch._register_device_module('foo', BackendModule)`.
    If the custom device module or the function is not defined, a RuntimeError is raised.

    Args:
        func_name (str): the name of the callable function defined in the custom device module.

    Example::

        class DummyfooModule:
            @staticmethod
            def is_available():
                return True
            @staticmethod
            def func_name(*args, **kwargs):
                ...

        torch.utils.rename_privateuse1_backend("foo")
        torch._register_device_module("foo", DummyfooModule)
        foo_is_available_func = torch.utils.backend_registration._get_custom_mod_func("is_available")
        if foo_is_available_func:
            foo_is_available = foo_is_available_func()
        func_ = torch.utils.backend_registration._get_custom_mod_func("func_name")
        if func_:
            result = func_(*args, **kwargs)

    Attention: This function is not meant to be used directly by users, which is why
    it is marked as private. It is a convenience function for backend implementers to
    more easily call the hooks into their backend extensions.
    """
    assert isinstance(func_name, str), f"func_name must be `str`, but got `{type(func_name)}`."
    backend_name = _get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, backend_name, None)
    function = getattr(custom_device_mod, func_name, None)
    if custom_device_mod is None or function is None:
        message = f'Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend '
        message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And "
        message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
        raise RuntimeError(message)
    return function