pytorch

Форк
0
/
backend_registration.py 
339 строк · 16.4 Кб
1
import torch
2
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
3
from typing import List, Optional, Union
4

5
__all__ = [
    "rename_privateuse1_backend",
    "generate_methods_for_privateuse1_backend",
]

# Python-side copy of the name currently assigned to the ``privateuse1``
# backend; ``rename_privateuse1_backend`` updates it together with the C++ side.
# TODO: This should be read via `torch._C._get_privateuse1_backend_name()`
# instead, but that call fails under torch.jit.script, so the name is kept in
# this module-level variable for now.
_privateuse1_backend_name = "privateuseone"
12

13
def rename_privateuse1_backend(backend_name: str) -> None:
14
    r"""
15
    Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.
16

17
    The steps are:
18

19
    (1) (In C++) implement kernels for various torch operations, and register them
20
        to the PrivateUse1 dispatch key.
21
    (2) (In python) call torch.utils.rename_privateuse1_backend("foo")
22

23
    You can now use "foo" as an ordinary device string in python.
24

25
    Note: this API can only be called once per process. Attempting to change
26
    the external backend after it's already been set will result in an error.
27

28
    Note(AMP): If you want to support AMP on your device, you can register a custom backend module.
29
    The backend must register a custom backend module with ``torch._register_device_module("foo", BackendModule)``.
30
    BackendModule needs to have the following API's:
31

32
    (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
33
        get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype.
34

35
    (2) ``is_autocast_enabled() -> bool``
36
        check the AMP is enabled or not on your "foo" device.
37

38
    (3) ``get_autocast_dtype() -> torch.dtype``
39
        get the supported dtype on your "foo" device in AMP, which is set by ``set_autocast_dtype`` or the
40
        default dtype, and the default dtype is ``torch.float16``.
41

42
    (4) ``set_autocast_enabled(bool) -> None``
43
        enable the AMP or not on your "foo" device.
44

45
    (5) ``set_autocast_dtype(dtype) -> None``
46
        set the supported dtype on your "foo" device in AMP, and the dtype be contained in the dtypes got
47
        from ``get_amp_supported_dtype``.
48

49
    Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:
50

51
    (1) ``_is_in_bad_fork() -> bool``
52
        Return ``True`` if now it is in bad_fork, else return ``False``.
53

54
    (2) ``manual_seed_all(seed int) -> None``
55
        Sets the seed for generating random numbers for your devices.
56

57
    (3) ``device_count() -> int``
58
        Returns the number of "foo"s available.
59

60
    (4) ``get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor``
61
        Returns a list of ByteTensor representing the random number states of all devices.
62

63
    (5) ``set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None``
64
        Sets the random number generator state of the specified "foo" device.
65

66
    And there are some common funcs:
67

68
    (1) ``is_available() -> bool``
69
        Returns a bool indicating if "foo" is currently available.
70

71
    (2) ``current_device() -> int``
72
        Returns the index of a currently selected device.
73

74
    For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
75
    For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
76

77
    Example::
78

79
        >>> # xdoctest: +SKIP("failing")
80
        >>> torch.utils.rename_privateuse1_backend("foo")
81
        # This will work, assuming that you've implemented the right C++ kernels
82
        # to implement torch.ones.
83
        >>> a = torch.ones(2, device="foo")
84

85
    """
86
    _rename_privateuse1_backend(backend_name)
87
    global _privateuse1_backend_name
88
    _privateuse1_backend_name = backend_name
89

90
def _check_register_once(module, attr):
91
    if hasattr(module, attr):
92
        raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}")
93

94

95
def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int:
96
    def _get_current_device_index():
97
        _get_device_index = "current_device"
98
        if hasattr(torch, custom_backend_name) and \
99
                hasattr(getattr(torch, custom_backend_name), _get_device_index):
100
            return getattr(getattr(torch, custom_backend_name), _get_device_index)()
101
        else:
102
            # The default device index is 0.
103
            return 0
104

105
    if device is None:
106
        return _get_current_device_index()
107
    # if isinstance(device, str), this means that the parameter passed in is in the string format "foo:0"
108
    # convert str object to torch.device object, and then process it uniformly
109
    elif isinstance(device, str):
110
        device = torch.device(device)
111

112
    # variable devcie can only be torch.device type or int type
113
    if isinstance(device, torch.device):
114
        if device.type != custom_backend_name:
115
            raise RuntimeError(f"Invalid device, must be {custom_backend_name} device")
116
        elif device.index is None:
117
            device_idx = _get_current_device_index()
118
        else:
119
            device_idx = device.index
120
    # if isinstance(device, int), we can take the index number directly
121
    else:
122
        device_idx = device
123
    return device_idx
124

125

126
def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    r"""Attach ``is_<name>`` (property) and ``<name>()`` (method) to :class:`torch.Tensor`.

    Both attribute names are validated before either one is set, so a name
    clash raises ``RuntimeError`` without leaving :class:`torch.Tensor`
    half-patched.

    Args:
        custom_backend_name (str): name of the renamed privateuse1 backend, e.g. ``"foo"``.

    Raises:
        RuntimeError: if :class:`torch.Tensor` already has ``is_<name>`` or ``<name>``.
    """
    is_attr = f'is_{custom_backend_name}'
    # Validate both names up front so a clash on the second one cannot leave
    # the first one already registered.
    _check_register_once(torch.Tensor, is_attr)
    _check_register_once(torch.Tensor, custom_backend_name)

    @property  # type: ignore[misc]
    def wrap_tensor_backend(self: torch.Tensor) -> bool:
        r"""Return ``True`` if this tensor lives on the custom backend device."""
        return self.device.type == custom_backend_name

    def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, str, torch.device]] = None,
                       non_blocking=False, **kwargs) -> torch.Tensor:
        r"""Perform Tensor device conversion. Call the to operator implementation.

        .. note::
            If the ``self`` Tensor already
            has the correct :class:`torch.device`, then ``self`` is returned.
            Otherwise, the returned tensor is a copy of ``self`` with the desired :class:`torch.device`.

        Args:
            device (int, str or torch.device, optional): the target device. An ``int`` is used as
                the device index on this backend; a string such as ``"foo:0"`` or a
                :class:`torch.device` must name this backend. If ``None``, the current device
                is used.
            non_blocking (bool): If ``True`` and the source is in pinned memory,
                the copy will be asynchronous with respect to the host. Otherwise,
                the argument has no effect.
            **kwargs (dict): For compatibility, may contain the key ``memory_format`` argument.
        """
        device_idx = _normalization_device(custom_backend_name, device)
        return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)

    setattr(torch.Tensor, is_attr, wrap_tensor_backend)
    setattr(torch.Tensor, custom_backend_name, wrap_tensor_to)
155

156

157
def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
158
    # Generate Module attributes and methods depends on Tensor methods,
159
    # so we need to check whether Tensor methods is already registered.
160
    if not hasattr(torch.Tensor, custom_backend_name):
161
        raise RuntimeError(
162
            f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module."
163
            f"Because torch.Tensor doesn't has the method {custom_backend_name}()."
164
            f"For this error, you can try setting for_tensor=True.")
165

166
    def wrap_module_to(self: torch.nn.modules.module.T,
167
                       device: Optional[Union[int, torch.device]] = None) -> torch.nn.modules.module.T:
168
        r"""Move all model parameters and buffers to the custom device.
169

170
        This also makes associated parameters and buffers different objects. So
171
        it should be called before constructing optimizer if the module will
172
        live on device while being optimized.
173

174
        .. note::
175
            This method modifies the module in-place.
176

177
        Args:
178
            device (int, optional): if specified, all parameters will be copied to that device
179
        """
180
        return self._apply(lambda t: getattr(t, custom_backend_name)(device))
181

182
    _check_register_once(torch.nn.Module, custom_backend_name)
183
    setattr(torch.nn.Module, custom_backend_name, wrap_module_to)
184

185

186
def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
                                                      unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
    r"""Attach ``is_<name>`` and ``<name>()`` to the storage classes.

    The attributes are registered on ``torch.storage._StorageBase`` (which
    :class:`torch.UntypedStorage` obtains through inheritance) and on
    :class:`torch.storage.TypedStorage`. All four attribute names are validated
    before any of them is set, so a name clash raises ``RuntimeError`` without
    leaving the storage classes half-patched.

    Args:
        custom_backend_name (str): name of the renamed privateuse1 backend, e.g. ``"foo"``.
        unsupported_dtype (List[torch.dtype], optional): dtypes for which
            ``TypedStorage.<name>()`` raises ``RuntimeError`` instead of copying.

    Raises:
        RuntimeError: if any of the four attributes is already registered.
    """
    is_attr = f'is_{custom_backend_name}'

    # Validate every attribute name before patching anything (same order in
    # which they are set below), so a clash never leaves a partial registration.
    for storage_cls in (torch.storage._StorageBase, torch.storage.TypedStorage):
        _check_register_once(storage_cls, is_attr)
        _check_register_once(storage_cls, custom_backend_name)

    # Attribute is registered in the _StorageBase class
    # and UntypedStorage obtains through inheritance.
    @property  # type: ignore[misc]
    def wrap_storage_backend(self: torch.storage._StorageBase) -> bool:
        r"""Return ``True`` if this storage lives on the custom backend device."""
        return self.device.type == custom_backend_name

    def wrap_storage_to(self, device=None, non_blocking=False):
        r"""Return a copy of this object in custom device memory.

        If this object is already in device memory and on the correct device, then
        no copy is performed and the original object is returned.

        Args:
            device (int, str or torch.device, optional): The destination device.
                Defaults to the current device.
            non_blocking (bool): If ``True`` and the source is in pinned memory,
                the copy will be asynchronous with respect to the host. Otherwise,
                the argument has no effect.

        Raises:
            RuntimeError: if this storage is sparse.
        """
        # There should be a judgment related to storage device and a judgment related to storage type,
        # but it depends on the extended function, so this part is temporarily omitted in the automatic generation.
        device_idx = _normalization_device(custom_backend_name, device)

        # Storage is already on the requested custom device: no copy needed.
        if getattr(self, is_attr) and self.get_device() == device_idx:
            return self
        # For sparse storage, custom backends need to extend the implementation by themselves.
        if self.is_sparse:
            raise RuntimeError(f"Can not support a sparse storage move to {custom_backend_name} backend")
        # Allocate an untyped storage of the same size on the target device and copy the data.
        untyped_storage = torch.UntypedStorage(
            self.size(), device=torch.device(f'{custom_backend_name}:{device_idx}')
        )
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage

    # Register the corresponding attribute for the TypedStorage class.
    # When the TypedStorage class is removed, the registration is also removed.

    @property  # type: ignore[misc]
    def wrap_typed_storage_backend(self: torch.storage.TypedStorage) -> bool:
        r"""Return ``True`` if the underlying untyped storage lives on the custom backend device."""
        torch.storage._warn_typed_storage_removal()
        return self._untyped_storage.device.type == custom_backend_name

    def wrap_typed_storage_to(self: torch.storage.TypedStorage,
                              device=None, non_blocking=False, **kwargs) -> torch.storage.TypedStorage:
        r"""Return a copy of this TypedStorage on the custom backend device.

        Raises:
            RuntimeError: if ``self.dtype`` is listed in ``unsupported_dtype``.
        """
        torch.storage._warn_typed_storage_removal()
        if unsupported_dtype and self.dtype in unsupported_dtype:
            raise RuntimeError(f"Cannot create {custom_backend_name} storage "
                               f"as {self.dtype} dtype is not supported by this backend")
        # NOTE(review): this presumably resolves to ``wrap_storage_to`` above, which accepts
        # no extra keyword arguments, so a non-empty ``kwargs`` would raise TypeError —
        # confirm whether kwargs are meant to be forwarded.
        custom_backend_storage: torch.UntypedStorage = getattr(
            self._untyped_storage, custom_backend_name)(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(custom_backend_storage)

    setattr(torch.storage._StorageBase, is_attr, wrap_storage_backend)
    setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to)
    setattr(torch.storage.TypedStorage, is_attr, wrap_typed_storage_backend)
    setattr(torch.storage.TypedStorage, custom_backend_name, wrap_typed_storage_to)
254

255

256
def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
257
                                             for_storage: bool = False,
258
                                             unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
259
    r"""
260
    Automatically generate attributes and methods for the custom backend after rename privateuse1 backend.
261

262
    In the default scenario, storage-related methods will not be generated automatically.
263

264
    When you implement kernels for various torch operations, and register them to the PrivateUse1 dispatch key.
265
    And call the function torch.rename_privateuse1_backend("foo") to rename your backend name.
266
    At this point, you can easily register specific methods and attributes by calling this function.
267
    Just like torch.Tensor.foo(), torch.Tensor.is_foo, torch.Storage.foo(), torch.Storage.is_foo.
268

269
    Note: We recommend you use generic functions (check devices are equal or to(device=)).
270
    We provide these methods for convenience only and they will be "monkey patched" onto the objects
271
    and so will not be properly typed. For Storage methods generate, if you need to support sparse data storage,
272
    you need to extend the implementation yourself.
273

274
    Args:
275
        for_tensor (bool): whether register related methods for torch.Tensor class.
276
        for_module (bool): whether register related methods for torch.nn.Module class.
277
        for_storage (bool): whether register related methods for torch.Storage class.
278
        unsupported_dtype (List[torch.dtype]): takes effect only when the storage method needs to be generated,
279
            indicating that the storage does not support the torch.dtype type.
280

281
    Example::
282

283
        >>> # xdoctest: +SKIP("failing")
284
        >>> torch.utils.rename_privateuse1_backend("foo")
285
        >>> torch.utils.generate_methods_for_privateuse1_backend()
286
        # Then automatically generate backend-related attributes and methods.
287
        >>> a = torch.tensor(2).foo()
288
        >>> a.is_foo
289
        >>> hasattr(torch.nn.Module, 'foo')
290
    """
291
    custom_backend_name = _get_privateuse1_backend_name()
292

293
    if for_tensor:
294
        _generate_tensor_methods_for_privateuse1_backend(custom_backend_name)
295

296
    if for_module:
297
        _generate_module_methods_for_privateuse1_backend(custom_backend_name)
298

299
    if for_storage:
300
        _generate_storage_methods_for_privateuse1_backend(custom_backend_name, unsupported_dtype)
301

302
def _get_custom_mod_func(func_name: str):
303
    r"""
304
    Return the func named `func_name` defined in custom device module. If not defined,
305
    return `None`. And the func is registered with `torch.utils.rename_privateuse1_backend('foo')`
306
    and `torch._register_device_module('foo', BackendModule)`.
307
    If the custom device module or the func is not defined, it will give warning or error message.
308
    Args:
309
        func_name (str): return the callable func named func_name defined in custom device module.
310
    Example::
311
        class DummyfooModule:
312
            @staticmethod
313
            def is_available():
314
                return True
315
            @staticmethod
316
            def func_name(*args, **kwargs):
317
                ....
318
        torch.utils.rename_privateuse1_backend("foo")
319
        torch._register_device_module("foo", DummyfooModule)
320
        foo_is_available_func = torch.utils.backend_registration._get_custom_mod_func("is_available")
321
        if foo_is_available_func:
322
            foo_is_available = foo_is_available_func()
323
        func_ = torch.utils.backend_registration._get_custom_mod_func("func_name")
324
        if func_:
325
            result = func_(*args, **kwargs)
326
    Attention: This function is not meant to be used directly by users, which is why
327
    it is marked as private. It is a convenience function for backend implementers to
328
    more easily call the hooks into their backend extensions.
329
    """
330
    assert isinstance(func_name, str), f"func_name must be `str`, but got `{type(func_name)}`."
331
    backend_name = _get_privateuse1_backend_name()
332
    custom_device_mod = getattr(torch, backend_name, None)  # type: ignore[arg-type]
333
    function = getattr(custom_device_mod, func_name, None)  # type: ignore[arg-type]
334
    if custom_device_mod is None or function is None:
335
        message = f'Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend '
336
        message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And "
337
        message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
338
        raise RuntimeError(message)
339
    return function
340

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.