pytorch / _utils.py (937 lines · 34.0 KB)
import copyreg
import functools
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, DefaultDict, List, Optional

import torch


def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self
    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)


def _hpu(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in HPU memory.

    If this object is already in HPU memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination HPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs)
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    if self.is_hpu:
        if device is None:
            device = hpu.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with hpu.device(device):
        assert not self.is_sparse, "sparse storage is not supported for HPU tensors"
        untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu"))
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage


def _cuda(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in CUDA memory.

    If this object is already in CUDA memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination GPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs)
    if self.is_cuda:
        if device is None:
            device = torch.cuda.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with torch.cuda.device(device):
        if self.is_sparse:
            new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
            indices = torch.Tensor._indices(self).cuda(device, non_blocking)
            values = torch.Tensor._values(self).cuda(device, non_blocking)
            return new_type(indices, values, self.size())
        else:
            untyped_storage = torch.UntypedStorage(
                self.size(), device=torch.device("cuda")
            )
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage


def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        **kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]
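
# For illustration, how the helper above resolves the flag (a minimal sketch,
# derived directly from the code):
#
#     _get_async_or_non_blocking("cuda", False, {})               # -> False
#     _get_async_or_non_blocking("cuda", False, {"async": True})  # warns, -> True
#     _get_async_or_non_blocking("cuda", False, {"foo": 1})       # raises TypeError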


# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables.  This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
#   - It's fragile.  If you serialize a backward hook into a saved
#     model, and then you rename the function associated with the hook,
#     now your saved model is broken and you can't load it anymore.
#
#   - It's not actually used.  The standard recommendation is to
#     serialize the *state_dict* of a model, not the model itself
#     (since this is more stable to code changes affecting the model
#     serialization), and the state dict saves "data" only, thus
#     stripping the backward hooks.  In some cases, hooks are
#     essential to the well-functioning of a model (e.g., DDP),
#     but DDP already manages re-adding the hooks!
#
#   - We didn't serialize them in many cases.  Prior to #10220, we
#     were dropping backward hooks in ForkingPickler.  We "fixed" this
#     to be consistent with other serialization sites, but lack of
#     serializing backward hooks wasn't actually the root cause of
#     the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a hook
# has the property _torch_serialize_ignore, we will not emit a warning if we
# attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass a None you'll run afoul of #12219.


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)


def get_tensor_metadata(tensor):
    # Tensor's Metadata for serializing.
    # Currently, this only returns a dict[string, bool] specifying whether
    # `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


def _rebuild_tensor_v2(
    storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor
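
# A minimal sketch of how the rebuild helper above can be exercised by hand
# (illustrative only; it assumes ``Tensor.storage()`` returns a TypedStorage,
# which is the kind of storage the serialization code passes in here):
#
#     from collections import OrderedDict
#     t = torch.arange(6.0).reshape(2, 3)
#     t2 = _rebuild_tensor_v2(t.storage(), 0, (2, 3), (3, 1), False, OrderedDict())
#     # t2 is a (2, 3) view over the same storage as t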


def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t


_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have been already unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because due
# to Pickler semantics, we have to use the same (non-validating) function for
# unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                # TODO: Validation currently involves an expensive traversal
                # on CPU, which may include a device transfer.
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )

    finally:
        _sparse_tensors_to_validate.clear()


def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (str): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # For BC:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")


def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy


def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )


def _rebuild_wrapper_subclass(
    cls, dtype, size, stride, storage_offset, layout, device, requires_grad
):
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param


def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    # Restore state on Parameter like python attr.
    param = _set_obj_state(param, state)
    return param


def _get_obj_state(obj):
    # Get the state of the python subclass
    # This loosely mimics the function on the object class, but since Tensor does not
    # inherit from it, we cannot call that function directly
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
    # the else branch will never be taken.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state


def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)
    return obj


def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj
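
# For illustration (a minimal sketch): _import_dotted_name("torch.nn.Linear")
# imports ``torch`` and then walks the attribute chain, returning the
# ``torch.nn.Linear`` class. _type() above uses this to resolve string type
# names such as "torch.cuda.FloatTensor".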


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)


def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
          unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
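
# A minimal round-trip sketch for the two dense helpers above (illustrative,
# assuming two CPU float tensors):
#
#     a, b = torch.ones(2, 2), torch.zeros(3)
#     flat = _flatten_dense_tensors([a, b])            # 1D tensor with 7 elements
#     a2, b2 = _unflatten_dense_tensors(flat, [a, b])
#     # a2 and b2 have the shapes of a and b, with values taken from flat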


def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of same sparse type, and that flat is given
    by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
          tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
          unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)


def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
          the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
          reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)


def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields one chunk at a time,
    each containing tensors of the same type, up to a certain byte limit in
    total size.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within its types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf
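
# A worked sketch of the grouping above (illustrative): with three float32
# tensors of 100 elements each (400 bytes apiece) and size_limit=800,
#
#     list(_take_tensors([t1, t2, t3], 800))
#
# yields [t1, t2] (exactly 800 bytes) and then [t3], since adding t3 to the
# first chunk would exceed the limit.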


# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec
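
# For example (a minimal sketch):
#
#     @annotate(torch.Tensor, x=torch.Tensor)
#     def identity(x):
#         return x
#
#     # identity.__annotations__ == {"x": torch.Tensor, "return": torch.Tensor}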


def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r


# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds reference to all the objects in its temporary scope)
# holding reference to the traceback.


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions take a non-str first argument but explicitly
            # have a message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate since we don't know how to
            raise RuntimeError(msg) from None
        raise exception
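
# A minimal usage sketch (illustrative): capture an exception on a worker and
# re-raise it later on the main thread, with the original traceback embedded
# in the message.
#
#     try:
#         raise ValueError("bad sample")
#     except Exception:
#         wrapper = ExceptionWrapper(where="in DataLoader worker process 0")
#     ...
#     wrapper.reraise()
#     # raises ValueError with a message starting with
#     # "Caught ValueError in DataLoader worker process 0." followed by the
#     # original traceback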


def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None


def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None


def _get_current_device_index():
    # current device index
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # all device indices
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # all device properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]


def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.
    Arguments: ``None``
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1


def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has an index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    i.e., the current default CUDA device will be returned if CUDA runtime is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got:{device}"
            )
    return device_idx
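
# Illustrative behaviour of _get_device_index (a minimal sketch; the
# ``optional=True`` case assumes a machine with an available device backend):
#
#     _get_device_index(torch.device("cuda", 1))              # -> 1
#     _get_device_index(3)                                     # -> 3
#     _get_device_index(torch.device("cpu"), allow_cpu=True)   # -> -1
#     _get_device_index(torch.device("cuda"), optional=True)   # -> current device index
#     _get_device_index(torch.device("cuda"))                  # raises ValueError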


def _handle_complex(tensor):
    """
    Returns a real view of a tensor if it has a complex dtype, else the tensor itself.
    We need to check for UninitializedParameter because otherwise checking is_complex
    is an error for a LazyModule.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )
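
# For example (a minimal sketch): a complex64 tensor of shape (2, 3) is viewed
# as a float32 tensor of shape (2, 3, 2), where the last dimension holds the
# real and imaginary parts; a real tensor is returned unchanged.
#
#     z = torch.zeros(2, 3, dtype=torch.complex64)
#     _handle_complex(z).shape   # torch.Size([2, 3, 2])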


def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3
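
# Illustrative results (a minimal sketch): _element_size(torch.float32) -> 4,
# _element_size(torch.int64) -> 8, _element_size(torch.bool) -> 1, and
# _element_size(torch.complex64) -> 8, since finfo reports the bits of the
# underlying float component, hence the ``>> 2`` for complex dtypes.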


class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)
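
# A minimal usage sketch (illustrative; the Widget class is made up):
#
#     class Widget:
#         @classproperty
#         def kind(cls):
#             return cls.__name__
#
#     Widget.kind      # -> "Widget", accessed on the class without parentheses
#     Widget().kind    # -> "Widget" as well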


def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
    """
    return torch.compiler.is_compiling()


def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate our updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`).
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]


@functools.lru_cache(2)
def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module


def _dummy_type(name: str) -> type:
    def get_err_fn(is_init: bool):
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )
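
# A minimal usage sketch (illustrative; the class name is made up). The
# generated placeholder type fails loudly as soon as it is instantiated:
#
#     _FakeStream = _dummy_type("_FakeStream")
#     _FakeStream()
#     # RuntimeError: Tried to instantiate dummy base class _FakeStream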


class _LazySeedTracker:
    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on the current device. We track the order of **latest**
    # calls between these two APIs.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        return self.call_order