_tensor_str.py
import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional

import torch
from torch import inf


@dataclasses.dataclass
class __PrinterOptions:
    precision: int = 4
    threshold: float = 1000
    edgeitems: int = 3
    linewidth: int = 80
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    """
    if profile is not None:
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    PRINT_OPTS.sci_mode = sci_mode
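
# Usage sketch for the ``sci_mode`` option handled above (illustrative only;
# assumes an interactive session with ``torch`` imported and the default
# precision of 4):
#
#     >>> torch.set_printoptions(sci_mode=True)
#     >>> torch.tensor([0.01, 100.0])
#     tensor([1.0000e-02, 1.0000e+02])
#     >>> torch.set_printoptions(sci_mode=None)  # let _Formatter pick the mode again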


def get_printoptions() -> Dict[str, Any]:
    r"""Gets the current options for printing, as a dictionary that
    can be passed as ``**kwargs`` to set_printoptions().
    """
    return dataclasses.asdict(PRINT_OPTS)
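
# A minimal save/restore sketch for the dictionary round-trip described in the
# docstring above (illustrative only):
#
#     >>> saved = get_printoptions()
#     >>> set_printoptions(precision=8)
#     >>> set_printoptions(**saved)  # restore the previous settings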


@contextlib.contextmanager
def printoptions(**kwargs):
    r"""Context manager that temporarily changes the print options.  Accepted
    arguments are the same as :func:`set_printoptions`."""
    old_kwargs = get_printoptions()
    set_printoptions(**kwargs)
    try:
        yield
    finally:
        set_printoptions(**old_kwargs)
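
# Illustrative use of the context manager defined above; the previous options
# are restored on exit even if the body raises:
#
#     >>> t = torch.tensor([1.12345])
#     >>> with printoptions(precision=2):
#     ...     print(t)
#     tensor([1.12])
#     >>> print(t)  # back to the default 4-digit precision
#     tensor([1.1235])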


def tensor_totype(t):
    dtype = torch.float if t.is_mps else torch.double
    return t.to(dtype=dtype)


class _Formatter:
    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = f"{value}"
                self.max_width = max(self.max_width, len(value_str))

        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                return

            # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # In int_mode for floats, all numbers are integers, and we append a decimal point
                # to each finite value to indicate that the tensor is of floating type. Add 1 to
                # the length to account for this.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{value:.0f}"
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                    or nonzero_finite_min < 1.0e-4
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))

        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        return self.max_width

    def format(self, value):
        if self.floating_dtype:
            if self.sci_mode:
                ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
            elif self.int_mode:
                ret = f"{value:.0f}"
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
        else:
            ret = f"{value}"
        return (self.max_width - len(ret)) * " " + ret
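
# A rough sketch of the heuristic implemented above (illustrative, default
# options assumed): a wide dynamic range, very large values, or magnitudes
# below 1e-4 switch the formatter into scientific notation.
#
#     >>> torch.tensor([0.00001, 1.0])   # min < 1e-4 -> sci_mode
#     tensor([1.0000e-05, 1.0000e+00])
#     >>> torch.tensor([1.0, 2.0])       # integral values -> int_mode
#     tensor([1., 2.])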


def _scalar_str(self, formatter1, formatter2=None):
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0
        if imag_str[0] == "+" or imag_str[0] == "-":
            return real_str + imag_str
        else:
            return real_str + "+" + imag_str
    else:
        return formatter1.format(self.item())


def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width for imag_formatter + an extra j for complex
        element_length += formatter2.width() + 1

    elements_per_line = max(
        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
    )

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = (formatter2.format(val.imag) + "j").lstrip()
            # handles negative numbers, +0.0, -0.0
            if imag_str[0] == "+" or imag_str[0] == "-":
                return real_str + imag_str
            else:
                return real_str + "+" + imag_str
        else:
            return formatter1.format(val)

    if summarize and not PRINT_OPTS.edgeitems:
        # edgeitems == 0: summarize everything down to a single ellipsis
        data = ["..."]
    elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = (
            [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
            + [" ..."]
            + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
        )
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [
        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
    ]
    lines = [", ".join(line) for line in data_lines]
    return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"


# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respectively
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = (
            [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(0, PRINT_OPTS.edgeitems)
            ]
            + ["..."]
            + [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
            ]
        )
    else:
        slices = [
            _tensor_str_with_formatter(
                self[i], indent + 1, summarize, formatter1, formatter2
            )
            for i in range(0, self.size(0))
        ]

    tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
    return "[" + tensor_str + "]"


def _tensor_str(self, indent):
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing goes through:
        # - tensor data can fit comfortably on screen
        # - tensor data needs to be summarized
        # Some of the codepaths don't fully support named tensors, so we send in
        # an unnamed tensor to the formatting code as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    if self._is_zerotensor():
        self = self.clone()

    # handle the negative bit
    if self.is_neg():
        self = self.resolve_neg()

    if self.dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    ]:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    if self.dtype.is_complex:
        # handle the conjugate bit
        self = self.resolve_conj()
        real_formatter = _Formatter(
            get_summarized_data(self.real) if summarize else self.real
        )
        imag_formatter = _Formatter(
            get_summarized_data(self.imag) if summarize else self.imag
        )
        return _tensor_str_with_formatter(
            self, indent, summarize, real_formatter, imag_formatter
        )
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)
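
# Illustrative consequence of the upcasts above: reduced-precision floats are
# converted to float32 purely for formatting, while the original dtype is still
# reported in the repr suffix (default print options assumed):
#
#     >>> torch.tensor([1.5, 2.5], dtype=torch.bfloat16)
#     tensor([1.5000, 2.5000], dtype=torch.bfloat16)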


def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    tensor_strs = [tensor_str]
    last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(",\n" + " " * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(", " + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(")")
    return "".join(tensor_strs)


def get_summarized_data(self):
    dim = self.dim()
    if dim == 0:
        return self
    if dim == 1:
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat(
                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
            )
        else:
            return self
    if not PRINT_OPTS.edgeitems:
        return self.new_empty([0] * self.dim())
    elif self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])
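
# Sketch of the summarization path (illustrative; default threshold=1000 and
# edgeitems=3 assumed): only the edge items are kept, and _vector_str inserts
# the literal "..." marker between them.
#
#     >>> get_summarized_data(torch.arange(10))
#     tensor([0, 1, 2, 7, 8, 9])
#     >>> torch.set_printoptions(threshold=5)
#     >>> torch.arange(10)
#     tensor([0, 1, 2, ..., 7, 8, 9])
#     >>> torch.set_printoptions(profile="default")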


def _str_intern(inp, *, tensor_contents=None):
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # The general logic here is that we only print the device when it doesn't match
    # the device specified in the default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations (slicing, indexing, etc.) to put the data in a
    # representable format. On ipu/xla/lazy/mtia tensors these operations trigger compilations, so
    # we copy the tensor to CPU before printing to avoid them.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            if is_meta:
                indices_str = "..."
            else:
                indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if indices.numel() == 0 or is_meta:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            if is_meta:
                compressed_indices_str = "..."
            else:
                compressed_indices_str = _tensor_str(
                    compressed_indices, indent + len(compressed_indices_prefix)
                )
            if compressed_indices.numel() == 0 or is_meta:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            if is_meta:
                plain_indices_str = "..."
            else:
                plain_indices_str = _tensor_str(
                    plain_indices, indent + len(plain_indices_prefix)
                )
            if plain_indices.numel() == 0 or is_meta:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        if not custom_contents_provided:

            def indented_str(s, indent):
                return "\n".join(f"  {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that ellipses is valid syntax for allocating
            # a meta tensor or FakeTensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing grad_fn triggers rebasing logic, which raises an error if the
        # tensor is a view that was created in no-grad mode and then modified
        # in-place in no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
        grad_fn_name = "Invalid"

    if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
        grad_fn_name = type(grad_fn).__name__
        if grad_fn_name == "CppFunction":
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    string_repr = _add_suffixes(
        prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse  # type: ignore[possibly-undefined]
    )

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr
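
# Illustrative reprs produced by the suffix logic above (assuming the default
# CPU device and float32 default dtype; exact suffixes depend on the tensor):
#
#     >>> torch.tensor([1.0, 2.0], requires_grad=True)
#     tensor([1., 2.], requires_grad=True)
#     >>> (torch.tensor([1.0, 2.0], requires_grad=True) * 3).sum()
#     tensor(9., grad_fn=<SumBackward0>)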


def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    level = torch._C._functorch.maybe_get_level(tensor)
    assert level != -1

    if torch._C._functorch.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    value = torch._C._functorch.get_unwrapped(tensor)
    value_repr = repr(value)

    indented_value_repr = textwrap.indent(value_repr, " " * 4)
    if torch._C._functorch.is_batchedtensor(tensor):
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        assert bdim != -1
        return (
            f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
            f"{indented_value_repr}\n"
            f")"
        )
    if torch._C._functorch.is_gradtrackingtensor(tensor):
        return (
            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
        )
    if torch._C._functorch.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")
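
# Rough sketch of the wrapper reprs above (illustrative; printing from inside
# torch.vmap goes through the BatchedTensor branch, and the exact lvl/bdim
# values depend on the transform nesting):
#
#     >>> def f(x):
#     ...     print(x)
#     ...     return x
#     >>> _ = torch.vmap(f)(torch.zeros(3, 2))
#     BatchedTensor(lvl=1, bdim=0, value=
#         tensor([[0., 0.],
#                 [0., 0.],
#                 [0., 0.]])
#     )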


def _str(self, *, tensor_contents=None):
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        guard = torch._C._DisableFuncTorch()
        return _str_intern(self, tensor_contents=tensor_contents)