pytorch

Форк
0
/
_utils.py 
1019 строк · 36.1 Кб
1
# mypy: allow-untyped-defs
2
import copyreg
3
import functools
4
import logging
5
import sys
6
import traceback
7
import warnings
8
from collections import defaultdict
9
from typing import Any, Callable, DefaultDict, Generic, List, Optional
10
from typing_extensions import ParamSpec
11

12
import torch
13

14

15
def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Return the type name of this object, or cast it to ``dtype``.

    With ``dtype=None`` the fully qualified class name
    (``"<module>.<ClassName>"``) is returned. Otherwise the object is cast to
    the requested type; if it already has exactly that type it is returned
    as-is and no copy is made.

    Args:
        dtype (type or str, optional): Target type, either the type object
            itself or its dotted name (resolved with ``_import_dotted_name``).
        non_blocking (bool): If ``True`` and the copy is between pinned host
            memory and the GPU (in either direction), it may run
            asynchronously with respect to the host. Otherwise ignored.
        **kwargs: May only contain the deprecated ``async`` key, a synonym
            for ``non_blocking``.

    Raises:
        RuntimeError: When casting sparse -> dense or dense -> sparse.
        TypeError: When an unexpected keyword argument is supplied.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    target = _import_dotted_name(dtype) if isinstance(dtype, str) else dtype
    if target == type(self):
        return self

    if not self.is_sparse:
        if target.is_sparse:
            raise RuntimeError("Cannot cast dense tensor to sparse tensor")
        return target(self.size()).copy_(self, non_blocking)

    if not target.is_sparse:
        raise RuntimeError("Cannot cast sparse tensor to dense tensor")
    # Strip ".sparse" from the target's module path so that the values and
    # indices are cast to the corresponding dense types.
    dense_module = target.__module__.replace(".sparse", "")
    values = torch.Tensor._values(self).type(
        dense_module + "." + target.__name__, non_blocking
    )
    indices = torch.Tensor._indices(self).type(
        dense_module + ".LongTensor", non_blocking
    )
    return target(indices, values, self.size())
53

54

55
def _to(self, device, non_blocking=False):
56
    """Returns a copy of this object in device memory.
57

58
    If this object is already on the correct device, then no copy is performed
59
    and the original object is returned.
60

61
    Args:
62
        device (int): The destination device.
63
        non_blocking (bool): If ``True`` and the source is in pinned memory,
64
            the copy will be asynchronous with respect to the host. Otherwise,
65
            the argument has no effect.
66
    """
67
    if self.device == device:
68
        return self
69

70
    device_module = getattr(torch, device.type, None)
71
    assert (
72
        device_module is not None
73
    ), f"{device.type.upper()} device module is not loaded"
74
    with device_module.device(device):
75
        if self.is_sparse and hasattr(device_module, "sparse"):
76
            new_type = getattr(device_module.sparse, self.__class__.__name__)
77
            indices = getattr(torch.Tensor._indices(self), device.type)(
78
                device, non_blocking
79
            )
80
            values = getattr(torch.Tensor._values(self), device.type)(
81
                device, non_blocking
82
            )
83
            return new_type(indices, values, self.size())
84
        else:
85
            assert (
86
                not self.is_sparse
87
            ), f"sparse storage is not supported for {device.type.upper()} tensors"
88
            untyped_storage = torch.UntypedStorage(self.size(), device=device)
89
            untyped_storage.copy_(self, non_blocking)
90
            return untyped_storage
91

92

93
def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
94
    """Return the non-blocking flag given the function name and kwargs.
95

96
    Args:
97
        function_name (str): the name of the function being used.
98
        non_blocking (bool): the default value.
99
        **kwargs (dict): the kwargs passed to the function.
100
    """
101
    if not kwargs:
102
        return non_blocking
103
    if len(kwargs) != 1 or "async" not in kwargs:
104
        message = "{}() got an unexpected keyword argument '{}'"
105
        argument = list(kwargs.keys()).pop()
106
        raise TypeError(message.format(function_name, argument))
107
    warnings.warn("'async' is deprecated; use 'non_blocking'")
108
    return kwargs["async"]
109

110

111
def _get_restore_location(device):
112
    """Return the map_location location.
113

114
    Used for rebuild functions where the tensor device is distinct from the storage
115
    """
116

117
    map_location = torch.serialization._serialization_tls.map_location
118
    if map_location is None:
119
        return device
120
    else:
121
        if isinstance(map_location, dict):
122
            return map_location.get(device, device)
123
        elif isinstance(map_location, (str, torch.device)):
124
            return map_location
125
        else:
126
            assert callable(map_location)
127
            raise RuntimeError(
128
                "Callable map_location not supported with _rebuild_wrapper_subclass "
129
                "or _rebuild_device_tensor_from_numpy"
130
            )
131

132

133
# Note [Don't serialize hooks]
134
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
135
# Since time immemorial, we have serialized the backward hooks associated with
136
# variables.  This kind of half-worked--Python can pickle global functions
137
# (but not closures!)--but there were problems.
138
#
139
#   - It's fragile.  If you serialize a backward hook into a saved
140
#     model, and then you rename the function associated with the hook,
141
#     now your saved model is broken and you can't load it anymore.
142
#
143
#   - It's not actually used.  The standard recommendation is to
144
#     serialize the *state_dict* of a model, not the model itself
145
#     (since this is more stable to code changes affecting the model
146
#     serialization), and the state dict saves "data" only, thus
147
#     stripping the backward hooks.  In some cases, hooks are
148
#     essential to the well-functioning of a model (e.g., DDP),
149
#     but DDP already manages re-adding the hooks!
150
#
151
#   - We didn't serialize them in many cases.  Prior to #10220, we
152
#     were dropping backward hooks in ForkingPickler.  We "fixed" this
153
#     to be convenient with other serialization sites, but lack of
154
#     serializing backward hooks wasn't actually the root cause of
155
#     the bug.
156
#
157
# With these cases in mind, we have decided that a better strategy
158
# is to just NOT serialize hooks at all.
159
#
160
# Since this is a BC-breaking change, we should warn when we previously
161
# serialized a hook, but no longer do so. This will be done by adding a special
162
# sentinel property to hooks, which will be used to suppress this warning. If a hook
163
# has the property _torch_serialize_ignore, we will not emit a warning if we
164
# attempt to serialize a Tensor with this hook attached to it.
165
#
166
# By the way, when _backward_hooks is skipped, we must give an EMPTY
167
# OrderedDict(); if you pass None you'll run afoul of #12219.
168

169

170
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
171
# be a TypedStorage
172
def _rebuild_tensor(storage, storage_offset, size, stride):
173
    # first construct a tensor with the correct dtype/device
174
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
175
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
176

177

178
def get_tensor_metadata(tensor):
    """Return the serializable metadata of ``tensor``.

    The metadata comes from ``torch._C._get_tensor_metadata``. Currently it is
    a ``dict[str, bool]`` recording whether the ``conj`` or ``neg`` bit is
    set. ``set_tensor_metadata`` applies it back to a tensor.
    """
    assert isinstance(tensor, torch.Tensor)
    metadata = torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]
    return metadata
184

185

186
def set_tensor_metadata(tensor, metadata):
    """Apply ``metadata`` (as produced by ``get_tensor_metadata``) to ``tensor``.

    ``tensor`` is updated in place through ``torch._C._set_tensor_metadata``;
    nothing is returned.
    """
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(  # type: ignore[attr-defined]
        tensor, metadata
    )
191

192

193
def _rebuild_tensor_v2(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    metadata=None,
):
    """Rebuild a dense tensor from its serialized components (v2 format).

    The dtype and device come from ``storage`` (see ``_rebuild_tensor``).
    ``requires_grad``, the optional ``metadata`` and ``backward_hooks`` are
    then restored on the result.
    """
    result = _rebuild_tensor(storage, storage_offset, size, stride)
    result.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(result, metadata)

    # NB: backward_hooks is assigned only for backwards compatibility; it is
    # normally an empty OrderedDict. See Note [Don't serialize hooks]
    result._backward_hooks = backward_hooks
    return result
212

213

214
def _rebuild_tensor_v3(
215
    storage,
216
    storage_offset,
217
    size,
218
    stride,
219
    requires_grad,
220
    backward_hooks,
221
    dtype,
222
    metadata=None,
223
):
224
    t = torch.empty(
225
        (0,),
226
        dtype=dtype,
227
        device=storage._untyped_storage.device,
228
        requires_grad=requires_grad,
229
    )
230
    t.set_(storage._untyped_storage, storage_offset, size, stride)
231
    if metadata:
232
        set_tensor_metadata(t, metadata)
233
    t._backward_hooks = backward_hooks
234
    return t
235

236

237
_sparse_tensors_to_validate: List["torch.Tensor"] = []
238

239

240
# In _legacy_load() in serialization.py we unpickle storages after the sparse
241
# tensors have been already unpickled. Those storages contain data necessary for
242
# validating sparse tensors: indices and values. That's why sparse tensors are
243
# first unpickled without any validation, and then this function is called just
244
# before _legacy_load() returns, so that all the sparse tensors can be validated
245
# in bulk.
246
#
247
# The same procedure must be followed by _load() in serialization.py because due
248
# to Pickler semantics, we have to use the same (non-validating) function for
249
# unpickling sparse tensors, regardless of the caller.
250
def _validate_loaded_sparse_tensors():
251
    try:
252
        for t in _sparse_tensors_to_validate:
253
            if t.layout is torch.sparse_coo:
254
                torch._validate_sparse_coo_tensor_args(
255
                    t._indices(), t._values(), t.size(), t.is_coalesced()
256
                )
257
            elif t.layout in {
258
                torch.sparse_csr,
259
                torch.sparse_csc,
260
                torch.sparse_bsr,
261
                torch.sparse_bsc,
262
            }:
263
                # TODO: Validation currently involves an expensive traversal
264
                # on CPU, which may include a device transfer.
265
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
266
                    compressed_indices, plain_indices = (
267
                        t.crow_indices(),
268
                        t.col_indices(),
269
                    )
270
                else:
271
                    compressed_indices, plain_indices = (
272
                        t.ccol_indices(),
273
                        t.row_indices(),
274
                    )
275
                torch._validate_sparse_compressed_tensor_args(
276
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
277
                )
278
            else:
279
                raise NotImplementedError(
280
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
281
                )
282

283
    finally:
284
        _sparse_tensors_to_validate.clear()
285

286

287
def _rebuild_sparse_tensor(layout, data):
    """
    Rebuild a sparse tensor from its serialized representation.

    The tensor is constructed with ``check_invariants=False`` and appended to
    ``_sparse_tensors_to_validate``. Validation is deferred because the data
    needed to validate it may not have been unpickled yet.

    Args:
        layout (torch.layout): ``torch.sparse_coo`` or one of
            ``torch.sparse_csr``/``sparse_csc``/``sparse_bsr``/``sparse_bsc``.
        data (tuple): For COO, ``(indices, values, size)`` (older format) or
            ``(indices, values, size, is_coalesced)``. For the compressed
            layouts, ``(compressed_indices, plain_indices, values, size)``.

    Raises:
        NotImplementedError: If ``layout`` is not one of the layouts above.
    """
    if layout == torch.sparse_coo:
        # For BC: the 3-tuple form carries no coalesced flag.
        if len(data) == 3:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        rebuilt = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
    elif layout in (
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    ):
        compressed_indices, plain_indices, values, size = data
        rebuilt = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
    else:
        raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")

    _sparse_tensors_to_validate.append(rebuilt)
    return rebuilt
327

328

329
def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
330
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)
331

332

333
def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    """Rebuild a tensor that was serialized as a numpy array.

    ``device`` is first remapped through the active ``map_location`` (see
    ``_get_restore_location``). The array is then wrapped with
    ``torch.from_numpy`` and cast/moved with a single ``.to`` call.
    """
    target_device = _get_restore_location(device)
    result = torch.from_numpy(data).to(dtype=dtype, device=target_device)
    result.requires_grad = requires_grad
    return result
338

339

340
# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
341
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy  # legacy name still referenced by old pickles; keep the alias
342

343

344
def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
345
    return torch.empty_strided(
346
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
347
    )
348

349

350
def _rebuild_wrapper_subclass(
    cls,
    dtype,
    size,
    stride,
    storage_offset,
    layout,
    device,
    requires_grad,
):
    """Rebuild an instance of the wrapper tensor subclass ``cls``.

    ``device`` is remapped through the active ``map_location`` first. All
    other metadata is passed to ``torch.Tensor._make_wrapper_subclass``.
    """
    wrapper_kwargs = dict(
        strides=stride,
        dtype=dtype,
        storage_offset=storage_offset,
        layout=layout,
        device=_get_restore_location(device),
        requires_grad=requires_grad,
    )
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls, size, **wrapper_kwargs
    )
371

372

373
# TODO: Once we decide to break serialization FC, `storage` no longer needs to
374
# be a TypedStorage
375
def _rebuild_qtensor(
376
    storage,
377
    storage_offset,
378
    size,
379
    stride,
380
    quantizer_params,
381
    requires_grad,
382
    backward_hooks,
383
):
384
    qscheme = quantizer_params[0]
385
    if qscheme == torch.per_tensor_affine:
386
        _, scale, zero_point = quantizer_params
387
        tensor = torch._empty_affine_quantized(
388
            size,
389
            scale=scale,
390
            zero_point=zero_point,
391
            dtype=storage.dtype,
392
            device=storage.device,
393
        )
394
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
395
        _, scales, zero_points, axis = quantizer_params
396
        if type(scales) is list and type(zero_points) is list:
397
            if qscheme == torch.per_channel_affine:
398
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
399
                zero_points = torch.tensor(
400
                    zero_points, dtype=torch.long, device=storage.device
401
                )
402
            else:
403
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
404
                zero_points = torch.tensor(
405
                    zero_points, dtype=torch.float, device=storage.device
406
                )
407
        tensor = torch._empty_per_channel_affine_quantized(
408
            size,
409
            scales=scales,
410
            zero_points=zero_points,
411
            axis=axis,
412
            dtype=storage.dtype,
413
            device=storage.device,
414
        )
415
    else:
416
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
417
    tensor.set_(storage, storage_offset, size, stride)
418
    tensor.requires_grad = requires_grad
419
    # NB: This line exists only for backwards compatibility; the
420
    # general expectation is that backward_hooks is an empty
421
    # OrderedDict.  See Note [Don't serialize hooks]
422
    tensor._backward_hooks = backward_hooks
423
    return tensor
424

425

426
def _rebuild_parameter(data, requires_grad, backward_hooks):
427
    param = torch.nn.Parameter(data, requires_grad)
428
    # NB: This line exists only for backwards compatibility; the
429
    # general expectation is that backward_hooks is an empty
430
    # OrderedDict.  See Note [Don't serialize hooks]
431
    param._backward_hooks = backward_hooks
432

433
    return param
434

435

436
def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    """Rebuild an ``nn.Parameter`` and restore its Python-level ``state``.

    ``state`` is applied with ``_set_obj_state``, and whatever that returns is
    returned. Presumably ``state`` was produced by ``_get_obj_state`` at save
    time.
    """
    rebuilt = torch.nn.Parameter(data, requires_grad)
    # Backwards compatibility only; see Note [Don't serialize hooks]
    rebuilt._backward_hooks = backward_hooks
    # Restore state on the Parameter like python attributes.
    return _set_obj_state(rebuilt, state)
446

447

448
def _get_obj_state(obj):
449
    # Get the state of the python subclass
450
    # This loosely mimicks the function on the object class but since Tensor do not inherit
451
    # from it, we cannot call that function directly
452
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
453
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
454
    # the else branch will never be taken.
455
    getstate_fn = getattr(obj, "__getstate__", None)
456
    if getstate_fn:
457
        state = getstate_fn()
458
    else:
459
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
460
        if slots_to_save:
461
            state = (
462
                obj.__dict__,
463
                {
464
                    name: getattr(obj, name)
465
                    for name in slots_to_save
466
                    if hasattr(obj, name)
467
                },
468
            )
469
        else:
470
            state = obj.__dict__
471

472
    return state
473

474

475
def _set_obj_state(obj, state):
476
    if isinstance(state, tuple):
477
        if not len(state) == 2:
478
            raise RuntimeError(f"Invalid serialized state: {state}")
479
        dict_state = state[0]
480
        slots_state = state[1]
481
    else:
482
        dict_state = state
483
        slots_state = None
484

485
    # Starting with Python 3.11, the __dict__ attribute is lazily created
486
    # and is serialized as None when not needed.
487
    if dict_state:
488
        for k, v in dict_state.items():
489
            setattr(obj, k, v)
490

491
    if slots_state:
492
        for k, v in slots_state.items():
493
            setattr(obj, k, v)
494
    return obj
495

496

497
def _import_dotted_name(name):
498
    components = name.split(".")
499
    obj = __import__(components[0])
500
    for component in components[1:]:
501
        obj = getattr(obj, component)
502
    return obj
503

504

505
def _flatten_dense_tensors(tensors):
506
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
507
    same dense type.
508

509
    Since inputs are dense, the resulting tensor will be a concatenated 1D
510
    buffer. Element-wise operation on this buffer will be equivalent to
511
    operating individually.
512

513
    Args:
514
        tensors (Iterable[Tensor]): dense tensors to flatten.
515

516
    Returns:
517
        A contiguous 1D buffer containing input tensors.
518
    """
519
    return torch._C._nn.flatten_dense_tensors(tensors)
520

521

522
def _flatten_sparse_tensors(tensors):
523
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
524
    one of values. Assume tensors are of same sparse type.
525

526
    Args:
527
        tensors (Iterable[Tensor]): sparse tensors to flatten.
528

529
    Returns:
530
        A tuple of two contiguous 1D buffers, one containing input tensors'
531
        indices and the other containing the values.
532
    """
533
    flat_indices = torch._C._nn.flatten_dense_tensors(
534
        [torch.Tensor._indices(t) for t in tensors]
535
    )
536
    flat_values = torch._C._nn.flatten_dense_tensors(
537
        [torch.Tensor._values(t) for t in tensors]
538
    )
539
    return flat_indices, flat_values
540

541

542
def _unflatten_dense_tensors(flat, tensors):
543
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
544
    same dense type, and that flat is given by _flatten_dense_tensors.
545

546
    Args:
547
        flat (Tensor): flattened dense tensors to unflatten.
548
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
549
          unflatten flat.
550

551
    Returns:
552
        Unflattened dense tensors with sizes same as tensors and values from
553
        flat.
554
    """
555
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)
556

557

558
def _unflatten_sparse_tensors(flat, tensors):
559
    """View flat buffer (containing indices and values) using the sizes of
560
    tensors. Assume that tensors are of same sparse type, and that flat is given
561
    by _flatten_sparse_tensors.
562

563
    Args:
564
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
565
          tensors to unflatten.
566
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
567
          unflatten flat.
568

569
    Returns:
570
        Unflattened sparse tensors with sizes same as tensors and values from
571
        flat.
572
    """
573
    flat_indices, flat_values = flat
574
    indices = torch._C._nn.unflatten_dense_tensors(
575
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
576
    )
577
    values = torch._C._nn.unflatten_dense_tensors(
578
        flat_values, [torch.Tensor._values(t) for t in tensors]
579
    )
580
    outputs = []
581
    for t, i, v in zip(tensors, indices, values):
582
        outputs.append(t.new(i, v, t.size()))
583
    return tuple(outputs)
584

585

586
def _reorder_tensors_as(tensors, ordered_tensors):
587
    """Assume that tensors are of same order as ordered_tensors within their
588
    types, e.g., from _take_tensors. Reorder them to be of same order as
589
    ordered_tensors.
590

591
    Args:
592
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
593
          the same order as ordered_tensors within their own types.
594
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
595
          reference.
596

597
    Returns:
598
        Ordered tuple of tensors with contents from tensors and order of
599
        ordered_tensors.
600
    """
601
    type_dict = defaultdict(list)
602
    for tensor in tensors:
603
        type_dict[tensor.type()].append(tensor)
604
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
605
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)
606

607

608
def _take_tensors(tensors, size_limit):
609
    """Group tensors into chunks. This generator yields a chunk at each time,
610
    each containing tensors of same type up to certain byte limit in total size.
611

612
    Args:
613
        tensors (Sequence): A sequence of tensors to be separated into chunks.
614
        size_limit (int): The limit of each chunk in bytes.
615

616
    Yields:
617
        Blocks of tensors of same type and within size_limit. The yielded
618
        tensors are only ordered as the original sequence within its types.
619
    """
620
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
621
    for tensor in tensors:
622
        t = tensor.type()
623
        if tensor.is_sparse:
624
            indices = torch.Tensor._indices(tensor)
625
            values = torch.Tensor._values(tensor)
626
            size = (
627
                indices.numel() * indices.element_size()
628
                + values.numel() * values.element_size()
629
            )
630
        else:
631
            size = tensor.numel() * tensor.element_size()
632
        buf_and_size = buf_dict[t]
633
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
634
            yield buf_and_size[0]
635
            buf_and_size = buf_dict[t] = [[], 0]
636
        buf_and_size[0].append(tensor)
637
        buf_and_size[1] += size
638
    for buf, _ in buf_dict.values():
639
        if len(buf) > 0:
640
            yield buf
641

642

643
# Decorator that sets a function's annotations explicitly. It was
# historically used to get annotations in a way compatible with both
# Python 2 and 3.
def annotate(ret, **kwargs):
    """Return a decorator that replaces ``fun.__annotations__``.

    The new annotations are ``kwargs`` (parameter name -> annotation) plus
    ``"return": ret``. The decorated function itself is returned unchanged.
    """

    def dec(fun):
        fun.__annotations__ = {**kwargs, "return": ret}
        return fun

    return dec
652

653

654
def render_call(fn, args, kwargs):
655
    str_fn = torch.overrides.resolve_name(fn)
656
    if str_fn is None:
657
        str_fn = str(fn)
658

659
    str_args: List[str] = []
660
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
661
        str_args.extend(repr(a) for a in args)
662
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
663
        r = f"{str_fn}({', '.join(str_args)})"
664
    return r
665

666

667
# NOTE [ Python Traceback Reference Cycle Problem ]
668
#
669
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
670
# which is the traceback, because otherwise you will run into the traceback
671
# reference cycle problem, i.e., the traceback holding a reference to the frame,
672
# and the frame (which holds references to all the objects in its temporary scope)
673
# holding a reference to the traceback.
674

675

676
class KeyErrorMessage(str):
    r"""A ``str`` whose ``repr`` is the string itself (no quotes or escaping)."""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Capture an exception and its formatted traceback so it can be
    re-raised in another thread."""

    def __init__(self, exc_info=None, where="in background"):
        # Only the exception type and the formatted traceback text are kept.
        # exc_info itself must not be stored, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Re-raise the wrapped exception in the current thread.

        The message looks like ``"Caught ValueError in DataLoader worker
        process 2.\nOriginal Traceback ..."``. If the exception type cannot be
        constructed from that single message, a ``RuntimeError`` carrying the
        message is raised instead.
        """
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument, which makes the traceback
            # unreadable; this will not change in Python
            # (https://bugs.python.org/issue2651), so wrap the message.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions take a non-str first argument but explicitly
            # have a message field.
            raise self.exc_type(message=msg)

        try:
            exception = self.exc_type(msg)
        except TypeError:
            # The exception takes multiple arguments; we don't know how to
            # instantiate it.
            raise RuntimeError(msg) from None
        raise exception
716

717

718
def _get_available_device_type():
719
    if torch.cuda.is_available():
720
        return "cuda"
721
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
722
        return "xpu"
723
    if hasattr(torch, "mtia") and torch.mtia.is_available():
724
        return "mtia"
725
    custom_backend_name = torch._C._get_privateuse1_backend_name()
726
    custom_device_mod = getattr(torch, custom_backend_name, None)
727
    if custom_device_mod and custom_device_mod.is_available():
728
        return custom_backend_name
729
    # add more available device types here
730
    return None
731

732

733
def _get_device_attr(get_member):
    """Apply ``get_member`` to the module of the available accelerator.

    Returns ``get_member(torch.<device_type>)`` for the device type reported
    by ``_get_available_device_type``, or ``None`` when no supported device
    is available.
    """
    device_type = _get_available_device_type()
    if device_type is None:
        return None
    lowered = device_type.lower()
    if lowered in ("cuda", "xpu", "mtia"):
        return get_member(getattr(torch, lowered))
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None
745

746

747
def _get_current_device_index():
    """Return ``current_device()`` of the available accelerator module, or
    ``None`` when no supported accelerator is available."""

    def current(module):
        return module.current_device()

    return _get_device_attr(current)
750

751

752
def _get_all_device_indices():
    """Return ``[0, ..., device_count() - 1]`` for the available accelerator,
    or ``None`` when no supported accelerator is available."""

    def all_indices(module):
        return list(range(module.device_count()))

    return _get_device_attr(all_indices)
755

756

757
def _get_devices_properties(device_ids):
    """Return ``get_device_properties(i)`` for each ``i`` in ``device_ids``.

    Each entry is ``None`` when no supported accelerator is available.
    """
    properties = []
    for device_id in device_ids:
        properties.append(
            _get_device_attr(
                lambda module, idx=device_id: module.get_device_properties(idx)
            )
        )
    return properties
760

761

762
def get_current_device_index() -> int:
    r"""Return the index of the current default CUDA device.

    Returns -1 when no CUDA device is available. Unlike
    ``_get_current_device_index``, this function uses no lambdas.
    Arguments: ``None``
    """
    if torch.cuda.device_count() <= 0:
        return -1
    return torch.cuda.current_device()
771

772

773
def _get_device_index(
774
    device: Any,
775
    optional: bool = False,
776
    allow_cpu: bool = False,
777
) -> int:
778
    r"""Gets the device index from :attr:`device`, which can be a torch.device
779
    object, a Python integer, or ``None``.
780

781
    If :attr:`device` is a torch.device object, returns the device index if it
782
    has index. Note that for a device without a specified index,
783
    i.e., ``torch.device('xxx')``, this will return the current default
784
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
785
    CPU devices will be accepted and ``-1`` will be returned in this case.
786

787
    If :attr:`device` is a Python integer, it is returned as is.
788

789
    If :attr:`device` is ``None``, this will return the current default
790
    device of the supported runtime platform if :attr:`optional` is ``True``.
791
    i.e., the current default CUDA device will be returned if CUDA runtime is supported.
792
    """
793
    if isinstance(device, str):
794
        device = torch.device(device)
795
    device_idx: Optional[int] = None
796
    if isinstance(device, torch.device):
797
        if not allow_cpu and device.type == "cpu":
798
            raise ValueError(f"Expected a non cpu device, but got: {device}")
799
        device_idx = -1 if device.type == "cpu" else device.index
800
    if isinstance(device, int):
801
        device_idx = device
802
    if device_idx is None:
803
        if optional:
804
            # The eager API _get_current_device_index uses `lambda` functions which are
805
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
806
            # the current device index is `get_current_device_index()` which can
807
            # be scripted. We use is_scripting to check the mode we are in and call the
808
            # appropriate API.
809
            if torch.jit.is_scripting():
810
                device_idx = get_current_device_index()
811
            else:
812
                device_idx = _get_current_device_index()
813
        else:
814
            raise ValueError(
815
                f"Expected a torch.device with a specified index or an integer, but got:{device}"
816
            )
817
    return device_idx
818

819

820
def _handle_complex(tensor):
821
    """
822
    Returns a real view of a tensor if complex dtype else just the tensor
823
    need to check if a UninitializedParameter because otherwise checking is_complex is an error for a LazyModule
824
    """
825
    return (
826
        torch.view_as_real(tensor)
827
        if not isinstance(tensor, torch.nn.UninitializedParameter)
828
        and tensor.is_complex()
829
        else tensor
830
    )
831

832

833
def _element_size(dtype):
834
    """
835
    Returns the element size for a dtype, in bytes
836
    """
837
    if not isinstance(dtype, torch.dtype):
838
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")
839

840
    if dtype.is_complex:
841
        return torch.finfo(dtype).bits >> 2
842
    elif dtype.is_floating_point:
843
        return torch.finfo(dtype).bits >> 3
844
    elif dtype == torch.bool:
845
        # NOTE: torch.bool is not supported in torch.iinfo()
846
        return 1
847
    else:
848
        return torch.iinfo(dtype).bits >> 3
849

850

851
class _ClassPropertyDescriptor:
852
    def __init__(self, fget, fset=None):
853
        self.fget = fget
854

855
    def __get__(self, instance, owner=None):
856
        if owner is None:
857
            owner = type(instance)
858
        return self.fget.__get__(instance, owner)()
859

860

861
def classproperty(func):
    """Decorator exposing ``func`` as a property evaluated against the class.

    A plain function is first wrapped in ``classmethod``; an object that is
    already a ``classmethod`` or ``staticmethod`` is used as-is.
    """
    already_wrapped = isinstance(func, (classmethod, staticmethod))
    getter = func if already_wrapped else classmethod(func)
    return _ClassPropertyDescriptor(getter)
865

866

867
def is_compiling() -> bool:
868
    """
869
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().
870

871
    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
872
    """
873
    return torch.compiler.is_compiling()
874

875

876
def _functionalize_sync(t):
877
    # This code lives in python instead of C++ since conditioning on a certain python subclass
878
    # is much more of a pain in C++.
879
    from torch._subclasses.functional_tensor import FunctionalTensor
880

881
    if isinstance(t, FunctionalTensor):
882
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
883
        # when we sync our inner tensor.
884
        # Why?
885
        # (1) If there are input mutations in the graph, then they will be re-applied during
886
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
887
        # (2) _sync() causes us to regenerate our updated the tensor from the updated base,
888
        #     which dispatches to a bunch of view ops
889
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
890
        #     (since the sync was called from C++), not the python FunctionalTensor
891
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
892
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
893
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor).
894
        maybe_functional_mode = torch._C._unset_dispatch_mode(
895
            torch._C._TorchDispatchModeKey.FUNCTIONAL
896
        )
897
        try:
898
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
899
        finally:
900
            if maybe_functional_mode is not None:
901
                torch._C._set_dispatch_mode(maybe_functional_mode)
902
    else:
903
        torch._functionalize_sync(t)  # type: ignore[attr-defined]
904

905

906
@functools.lru_cache(2)
907
def _get_device_module(device_type: str):
908
    device_module = getattr(torch, device_type, None)
909
    if device_module is None:
910
        raise RuntimeError(
911
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
912
        )
913
    return device_module
914

915

916
def _dummy_type(name: str) -> type:
917
    def get_err_fn(is_init: bool):
918
        def err_fn(obj, *args, **kwargs):
919
            if is_init:
920
                class_name = obj.__class__.__name__
921
            else:
922
                class_name = obj.__name__
923
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
924

925
        return err_fn
926

927
    return type(
928
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
929
    )
930

931

932
class _LazySeedTracker:
933
    # Since seeding is memory-less, only track the latest seed.
934
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
935
    # the seed on current device. We track the order of **latest**
936
    # calls between these two API.
937
    def __init__(self):
938
        self.manual_seed_all_cb = None
939
        self.manual_seed_cb = None
940
        self.call_order = []
941

942
    def queue_seed_all(self, cb, traceback):
943
        self.manual_seed_all_cb = (cb, traceback)
944
        # update seed_all to be latest
945
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
946

947
    def queue_seed(self, cb, traceback):
948
        self.manual_seed_cb = (cb, traceback)
949
        # update seed to be latest
950
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
951

952
    def get_calls(self) -> List:
953
        return self.call_order
954

955

956
logger = logging.getLogger(__name__)
957
P = ParamSpec("P")
958

959

960
class CallbackRegistry(Generic[P]):
961
    def __init__(self, name: str):
962
        self.name = name
963
        self.callback_list: List[Callable[P, None]] = []
964

965
    def add_callback(self, cb: Callable[P, None]) -> None:
966
        self.callback_list.append(cb)
967

968
    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
969
        for cb in self.callback_list:
970
            try:
971
                cb(*args, **kwargs)
972
            except Exception as e:
973
                logger.exception(
974
                    "Exception in callback for %s registered with gpu trace", self.name
975
                )
976

977

978
# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.

# Maps legacy (Python 2) module names, as they may appear in old pickles, to
# the Python 3 module that replaced them. This is a subset of the upstream
# _compat_pickle table.
IMPORT_MAPPING = {
    "__builtin__": "builtins",
    "copy_reg": "copyreg",
    "Queue": "queue",
    "repr": "reprlib",
    "_abcoll": "collections.abc",
    # Non-mutual mappings.
    # (Upstream _compat_pickle terminology: one-way renames that are not part
    # of its reverse Python 3 -> Python 2 table, e.g. several legacy modules
    # below collapse onto "collections" or "io".)
    "UserDict": "collections",
    "UserList": "collections",
    "UserString": "collections",
    "whichdb": "dbm",
    "StringIO": "io",
    "cStringIO": "io",
}
995

996

997
# This contains rename rules that are easy to handle.  We ignore the more
# complex stuff (e.g. mapping the names in the urllib and types modules).
# These rules should be run before import names are fixed.
# Each entry maps a legacy (module, name) pair, as it may appear in an old
# pickle, to its Python 3 (module, name) replacement.
NAME_MAPPING = {
    ("__builtin__", "xrange"): ("builtins", "range"),
    ("__builtin__", "reduce"): ("functools", "reduce"),
    ("__builtin__", "intern"): ("sys", "intern"),
    ("__builtin__", "unichr"): ("builtins", "chr"),
    ("__builtin__", "unicode"): ("builtins", "str"),
    ("__builtin__", "long"): ("builtins", "int"),
    ("itertools", "izip"): ("builtins", "zip"),
    ("itertools", "imap"): ("builtins", "map"),
    ("itertools", "ifilter"): ("builtins", "filter"),
    ("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
    ("itertools", "izip_longest"): ("itertools", "zip_longest"),
    ("UserDict", "IterableUserDict"): ("collections", "UserDict"),
    ("UserList", "UserList"): ("collections", "UserList"),
    ("UserString", "UserString"): ("collections", "UserString"),
    # Non-mutual mappings.
    # (Upstream _compat_pickle terminology: one-way renames excluded from its
    # reverse table; e.g. "basestring" shares the target builtins.str with
    # "unicode" above.)
    ("__builtin__", "basestring"): ("builtins", "str"),
    ("exceptions", "StandardError"): ("builtins", "Exception"),
    ("UserDict", "UserDict"): ("collections", "UserDict"),
}
1020

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.