import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional

import torch
from torch import inf


@dataclasses.dataclass
class __PrinterOptions:
    precision: int = 4
    threshold: float = 1000
    edgeitems: int = 3
    linewidth: int = 80
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    """
    if profile is not None:
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    PRINT_OPTS.sci_mode = sci_mode
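
# Note: `sci_mode=None` defers the notation choice to `_Formatter` below;
# passing an explicit bool pins it globally. Illustrative example:
#
#     torch.set_printoptions(sci_mode=True)
#     print(torch.tensor([0.01]))  # tensor([1.0000e-02])

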
def get_printoptions() -> Dict[str, Any]:
    r"""Gets the current options for printing, as a dictionary that
    can be passed as ``**kwargs`` to set_printoptions().
    """
    return dataclasses.asdict(PRINT_OPTS)


@contextlib.contextmanager
def printoptions(**kwargs):
    r"""Context manager that temporarily changes the print options. Accepted
    arguments are same as :func:`set_printoptions`."""
    old_kwargs = get_printoptions()
    set_printoptions(**kwargs)
    try:
        yield
    finally:
        set_printoptions(**old_kwargs)
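
# Usage sketch (illustrative): the previous options are restored on exit,
# even if the body raises.
#
#     with printoptions(precision=10):
#         print(torch.tensor([1.0]) / 3)
#     # earlier print options are back in effect here

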
def tensor_totype(t):
    # Lift to double for min/max/ratio computations, falling back to float on
    # backends without fp64 support (the `is_mps` clause is reconstructed; the
    # XPU clause is from the original).
    dtype = (
        torch.float
        if t.is_mps
        or (t.is_xpu and not torch.xpu.get_device_properties(t.device).has_fp64)
        else torch.double
    )
    return t.to(dtype=dtype)


class _Formatter:
    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = f"{value}"
                self.max_width = max(self.max_width, len(value_str))
        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, nothing to adjust
                return

            # Convert to double (or float) for easy min/max/ratio computations
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # All values are integral; switch to scientific notation when
                # the dynamic range or magnitude gets too large.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{value:.0f}"
                        # leave room for the trailing "." appended in format()
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                    or nonzero_finite_min < 1.0e-4
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))

        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        return self.max_width

    def format(self, value):
        if self.floating_dtype:
            if self.sci_mode:
                ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
            elif self.int_mode:
                ret = f"{value:.0f}"
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
        else:
            ret = f"{value}"
        return (self.max_width - len(ret)) * " " + ret
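
    # Example (illustrative): for torch.tensor([0., 10., 100.]) all values are
    # integral, so int_mode applies and each element is right-padded to the
    # common width 4: format(10.0) -> " 10.", and the tensor prints as
    # tensor([  0.,  10., 100.]).

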
def _scalar_str(self, formatter1, formatter2=None):
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0
        if imag_str[0] == "+" or imag_str[0] == "-":
            return real_str + imag_str
        else:
            return real_str + "+" + imag_str
    else:
        return formatter1.format(self.item())


def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width of the imaginary part plus an extra "j" for complex values
        element_length += formatter2.width() + 1

    elements_per_line = max(
        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
    )

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = (formatter2.format(val.imag) + "j").lstrip()
            # handles negative numbers, +0.0, -0.0
            if imag_str[0] == "+" or imag_str[0] == "-":
                return real_str + imag_str
            else:
                return real_str + "+" + imag_str
        else:
            return formatter1.format(val)

    if summarize and not PRINT_OPTS.edgeitems:
        # with edgeitems == 0 there is nothing to show but the ellipsis
        data = ["..."]
    elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = (
            [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
            + [" ..."]
            + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
        )
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [
        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
    ]
    lines = [", ".join(line) for line in data_lines]
    return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"
def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = (
            [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(0, PRINT_OPTS.edgeitems)
            ]
            + ["..."]
            + [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
            ]
        )
    else:
        slices = [
            _tensor_str_with_formatter(
                self[i], indent + 1, summarize, formatter1, formatter2
            )
            for i in range(0, self.size(0))
        ]

    tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
    return "[" + tensor_str + "]"
def _tensor_str(self, indent):
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # Some formatting codepaths don't fully support named tensors, so
        # work on an unnamed view as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    if self._is_zerotensor():
        self = self.clone()

    # handle the negative bit
    if self.is_neg():
        self = self.resolve_neg()

    # Low-precision floats are upcast before formatting (this dtype list is
    # partially reconstructed; the fnuz variants are from the original).
    if self.dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    ]:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    if self.dtype.is_complex:
        # handle the conjugate bit
        self = self.resolve_conj()
        real_formatter = _Formatter(
            get_summarized_data(self.real) if summarize else self.real
        )
        imag_formatter = _Formatter(
            get_summarized_data(self.imag) if summarize else self.imag
        )
        return _tensor_str_with_formatter(
            self, indent, summarize, real_formatter, imag_formatter
        )
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)


def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    tensor_strs = [tensor_str]
    last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(",\n" + " " * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(", " + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(")")
    return "".join(tensor_strs)
def get_summarized_data(self):
    dim = self.dim()
    if dim == 0:
        return self
    if dim == 1:
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat(
                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
            )
        else:
            return self
    if not PRINT_OPTS.edgeitems:
        return self.new_empty([0] * self.dim())
    elif self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])
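
# Illustrative: with the default edgeitems=3, a 1-D tensor of 100 elements is
# reduced to 6 (first three and last three) before formatting; higher
# dimensions are reduced recursively along each axis.

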
def _str_intern(inp, *, tensor_contents=None):
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # Extract the primal value, disabling forward AD within this function.
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Only print the device when it doesn't match the default tensor type's
    # device (and always for mps / non-current cuda devices).
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations like slicing and indexing
    # that some backends don't support, so convert to CPU first.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        is_meta = self.is_meta or isinstance(self, FakeTensor)
        suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            if is_meta:
                indices_str = "..."
            else:
                indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if is_meta or indices.numel() == 0:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if is_meta or values.numel() == 0:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            if is_meta:
                compressed_indices_str = "..."
            else:
                compressed_indices_str = _tensor_str(
                    compressed_indices, indent + len(compressed_indices_prefix)
                )
            if compressed_indices.numel() == 0 or is_meta:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            if is_meta:
                plain_indices_str = "..."
            else:
                plain_indices_str = _tensor_str(
                    plain_indices, indent + len(plain_indices_prefix)
                )
            if plain_indices.numel() == 0 or is_meta:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        if not custom_contents_provided:

            def indented_str(s, indent):
                return "\n".join(f"  {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # An empty tensor has no elements from which to infer the
                # dtype, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use `inp` here to get the original grad_fn and not the one generated by
    # the forward-grad unpacking above.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing grad_fn calls rebasing logic, which can error out for
        # views created and modified in-place in no-grad mode.
        grad_fn_name = "Invalid"

    if grad_fn_name is None and grad_fn is not None:
        grad_fn_name = type(grad_fn).__name__
        if grad_fn_name == "CppFunction":
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    string_repr = _add_suffixes(
        prefix + tensor_str,
        suffixes,
        indent,
        force_newline=self.is_sparse,
    )

    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr


def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    level = torch._C._functorch.maybe_get_level(tensor)
    assert level != -1

    if torch._C._functorch.is_functionaltensor(tensor):
        # Sync the underlying FunctionalTensorWrapper before unwrapping so the
        # printed value is up to date.
        torch._sync(tensor)

    value = torch._C._functorch.get_unwrapped(tensor)
    value_repr = repr(value)

    indented_value_repr = textwrap.indent(value_repr, " " * 4)
    if torch._C._functorch.is_batchedtensor(tensor):
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        assert bdim != -1
        return (
            f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
            f"{indented_value_repr}\n"
            f")"
        )
    if torch._C._functorch.is_gradtrackingtensor(tensor):
        return (
            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
        )
    if torch._C._functorch.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")


def _str(self, *, tensor_contents=None):
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        # Hold the guard for the duration of the call to keep functorch
        # interpretation disabled while formatting.
        guard = torch._C._DisableFuncTorch()
        return _str_intern(self, tensor_contents=tensor_contents)
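
# Note: torch.Tensor.__repr__ routes through _str (see torch/_tensor.py), so
# `print(torch.ones(2))` produces "tensor([1., 1.])" via the machinery above.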