pytorch / _tensor_str.py (708 lines · 26.3 KB)
# mypy: allow-untyped-defs
import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional

import torch
from torch import inf


@dataclasses.dataclass
class __PrinterOptions:
    precision: int = 4
    threshold: float = 1000
    edgeitems: int = 3
    linewidth: int = 80
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    """
    if profile is not None:
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    PRINT_OPTS.sci_mode = sci_mode
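# A brief illustration of the profile presets handled above (hypothetical
# interactive session; outputs assume the preset values set in this function):
# >>> torch.set_printoptions(profile="short")   # precision=2, edgeitems=2
# >>> torch.tensor([1.12345, 2.71828])
# tensor([1.12, 2.72])
# >>> torch.set_printoptions(profile="full")    # threshold=inf, so tensors are never summarized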


def get_printoptions() -> Dict[str, Any]:
    r"""Gets the current options for printing, as a dictionary that
    can be passed as ``**kwargs`` to set_printoptions().
    """
    return dataclasses.asdict(PRINT_OPTS)


@contextlib.contextmanager
def printoptions(**kwargs):
    r"""Context manager that temporarily changes the print options.  Accepted
    arguments are the same as :func:`set_printoptions`."""
    old_kwargs = get_printoptions()
    set_printoptions(**kwargs)
    try:
        yield
    finally:
        set_printoptions(**old_kwargs)
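# A minimal usage sketch for the context manager above: the previous options are
# restored on exit even if printing raises (hypothetical session, default options):
# >>> t = torch.tensor([1.12345])
# >>> with printoptions(precision=2):
# ...     print(t)
# tensor([1.12])
# >>> print(t)
# tensor([1.1235])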


def tensor_totype(t):
    # Use float32 on backends without float64 support (MPS, and XPU devices that
    # report no fp64); otherwise use float64 for the computations below.
    dtype = (
        torch.float
        if (
            t.is_mps
            or (t.is_xpu and not torch.xpu.get_device_properties(t.device).has_fp64)
        )
        else torch.double
    )
    return t.to(dtype=dtype)


class _Formatter:
    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = f"{value}"
                self.max_width = max(self.max_width, len(value_str))

        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                return

            # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # in int_mode for floats, all numbers are integers; we append a decimal point to
                # finite values to indicate a floating dtype, so add 1 to the len to account for it.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{value:.0f}"
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                    or nonzero_finite_min < 1.0e-4
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))

        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        return self.max_width

    def format(self, value):
        if self.floating_dtype:
            if self.sci_mode:
                ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
            elif self.int_mode:
                ret = f"{value:.0f}"
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
        else:
            ret = f"{value}"
        return (self.max_width - len(ret)) * " " + ret
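# A small illustration of the sci_mode heuristic above (assuming default print
# options): nonzero finite values spanning more than three orders of magnitude,
# exceeding 1e8, or dipping below 1e-4 switch the formatter to scientific notation.
# >>> torch.tensor([0.00001, 10.0])
# tensor([1.0000e-05, 1.0000e+01])
# >>> torch.tensor([1.0, 10.0])
# tensor([ 1., 10.])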


def _scalar_str(self, formatter1, formatter2=None):
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0
        if imag_str[0] == "+" or imag_str[0] == "-":
            return real_str + imag_str
        else:
            return real_str + "+" + imag_str
    else:
        return formatter1.format(self.item())


def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width for imag_formatter + an extra j for complex
        element_length += formatter2.width() + 1

    elements_per_line = max(
        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
    )

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = (formatter2.format(val.imag) + "j").lstrip()
            # handles negative numbers, +0.0, -0.0
            if imag_str[0] == "+" or imag_str[0] == "-":
                return real_str + imag_str
            else:
                return real_str + "+" + imag_str
        else:
            return formatter1.format(val)

    if summarize and not PRINT_OPTS.edgeitems:
        # Deal with edge case that negative zero is zero
        data = ["..."]
    elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = (
            [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
            + [" ..."]
            + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
        )
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [
        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
    ]
    lines = [", ".join(line) for line in data_lines]
    return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"


# formatter2 is only used for printing complex tensors.
# For complex tensors, formatter1 and formatter2 are the formatters for tensor.real
# and tensor.imag respectively
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = (
            [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(0, PRINT_OPTS.edgeitems)
            ]
            + ["..."]
            + [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
            ]
        )
    else:
        slices = [
            _tensor_str_with_formatter(
                self[i], indent + 1, summarize, formatter1, formatter2
            )
            for i in range(0, self.size(0))
        ]

    tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
    return "[" + tensor_str + "]"


def _tensor_str(self, indent):
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # There are two main codepaths (possibly more) that tensor printing goes through:
        # - tensor data can fit comfortably on screen
        # - tensor data needs to be summarized
        # Some of the codepaths don't fully support named tensors, so we send in
        # an unnamed tensor to the formatting code as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    if self._is_zerotensor():
        self = self.clone()

    # handle the negative bit
    if self.is_neg():
        self = self.resolve_neg()

    if self.dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    ]:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    if self.dtype.is_complex:
        # handle the conjugate bit
        self = self.resolve_conj()
        real_formatter = _Formatter(
            get_summarized_data(self.real) if summarize else self.real
        )
        imag_formatter = _Formatter(
            get_summarized_data(self.imag) if summarize else self.imag
        )
        return _tensor_str_with_formatter(
            self, indent, summarize, real_formatter, imag_formatter
        )
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)


def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    tensor_strs = [tensor_str]
    last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(",\n" + " " * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(", " + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(")")
    return "".join(tensor_strs)


def get_summarized_data(self):
    dim = self.dim()
    if dim == 0:
        return self
    if dim == 1:
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat(
                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
            )
        else:
            return self
    if not PRINT_OPTS.edgeitems:
        return self.new_empty([0] * self.dim())
    elif self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])
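# A small sketch of what get_summarized_data keeps (assuming the default
# edgeitems=3): only the leading and trailing slices along each dimension survive,
# so a (100, 100) tensor is reduced to a (6, 6) tensor before formatting.
# >>> t = torch.arange(10000.).reshape(100, 100)
# >>> get_summarized_data(t).shape
# torch.Size([6, 6])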


def _str_intern(inp, *, tensor_contents=None):
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # This is used to extract the primal value and thus disable the forward AD
    # within this function.
    # TODO(albanD) This needs to be updated when more than one level is supported
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Note [Print tensor device]:
    # The general logic here is that we only print the device when it doesn't match
    # the device of the default tensor type.
    # Currently torch.set_default_tensor_type() only supports CPU/CUDA, thus
    # torch._C._get_default_device() only returns either cpu or cuda.
    # In other cases, we don't have a way to set them as default yet,
    # and we should always print out device for them.
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations such as slicing and indexing to put
    # the data into a representable format. On ipu/xla/lazy/mtia tensors these
    # operations trigger compilations, so we copy the tensor to cpu before printing
    # to avoid them.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    # TODO: add an API to map real -> complex dtypes
    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            if is_meta:
                indices_str = "..."
            else:
                indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if is_meta or indices.numel() == 0:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if is_meta or values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        if not is_meta:
            suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            if is_meta:
                compressed_indices_str = "..."
            else:
                compressed_indices_str = _tensor_str(
                    compressed_indices, indent + len(compressed_indices_prefix)
                )
            if compressed_indices.numel() == 0 or is_meta:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            if is_meta:
                plain_indices_str = "..."
            else:
                plain_indices_str = _tensor_str(
                    plain_indices, indent + len(plain_indices_prefix)
                )
            if plain_indices.numel() == 0 or is_meta:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        if not custom_contents_provided:

            def indented_str(s, indent):
                return "\n".join(f"  {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            # TODO: This implies that an ellipsis is valid syntax for allocating
            # a meta tensor or FakeTensor, which it could be, but it isn't right now
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing the grad_fn calls rebasing logic which would cause an error
        # if that tensor is a view created in no-grad mode modified in-place in
        # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
        grad_fn_name = "Invalid"

    if grad_fn_name is None and grad_fn is not None:  # type: ignore[possibly-undefined]
        grad_fn_name = type(grad_fn).__name__
        if grad_fn_name == "CppFunction":
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    string_repr = _add_suffixes(
        prefix + tensor_str,  # type: ignore[possibly-undefined]
        suffixes,
        indent,
        force_newline=self.is_sparse,
    )

    # Check if this instance is flagged as a parameter and change the repr accordingly.
    # Unfortunately, this function has to be aware of this detail.
    # NB: This is currently skipped for plain tensor parameters to maintain BC. In the future,
    # this should be done for those as well to produce a valid repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr


def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    level = torch._C._functorch.maybe_get_level(tensor)
    assert level != -1

    if torch._C._functorch.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    value = torch._C._functorch.get_unwrapped(tensor)
    value_repr = repr(value)

    indented_value_repr = textwrap.indent(value_repr, " " * 4)
    if torch._C._functorch.is_batchedtensor(tensor):
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        assert bdim != -1
        return (
            f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
            f"{indented_value_repr}\n"
            f")"
        )
    if torch._C._functorch.is_gradtrackingtensor(tensor):
        return (
            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
        )
    if torch._C._functorch.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")


def _str(self, *, tensor_contents=None):
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        guard = torch._C._DisableFuncTorch()
        return _str_intern(self, tensor_contents=tensor_contents)
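# Tensor.__repr__ (in torch/_tensor.py) routes through _str above. A minimal sketch
# of the resulting behavior (output is illustrative, assuming default print options):
# >>> torch.tensor([1.0, 2.0, 3.0], requires_grad=True) * 2
# tensor([2., 4., 6.], grad_fn=<MulBackward0>)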