import contextlib
import dataclasses
import math
import textwrap
from typing import Any, Dict, Optional

import torch
from torch import inf


@dataclasses.dataclass
class __PrinterOptions:
    precision: int = 4
    threshold: float = 1000
    edgeitems: int = 3
    linewidth: int = 80
    sci_mode: Optional[bool] = None


PRINT_OPTS = __PrinterOptions()


# We could use **kwargs, but this will give better docs
def set_printoptions(
    precision=None,
    threshold=None,
    edgeitems=None,
    linewidth=None,
    profile=None,
    sci_mode=None,
):
    r"""Set options for printing. Items shamelessly taken from NumPy

    Args:
        precision: Number of digits of precision for floating point output
            (default = 4).
        threshold: Total number of array elements which trigger summarization
            rather than full `repr` (default = 1000).
        edgeitems: Number of array items in summary at beginning and end of
            each dimension (default = 3).
        linewidth: The number of characters per line for the purpose of
            inserting line breaks (default = 80). Thresholded matrices will
            ignore this parameter.
        profile: Sane defaults for pretty printing. Can override with any of
            the above options. (any one of `default`, `short`, `full`)
        sci_mode: Enable (True) or disable (False) scientific notation. If
            None (default) is specified, the value is defined by
            `torch._tensor_str._Formatter`. This value is automatically chosen
            by the framework.

    Example::

        >>> # Limit the precision of elements
        >>> torch.set_printoptions(precision=2)
        >>> torch.tensor([1.12345])
        tensor([1.12])
        >>> # Limit the number of elements shown
        >>> torch.set_printoptions(threshold=5)
        >>> torch.arange(10)
        tensor([0, 1, 2, ..., 7, 8, 9])
        >>> # Restore defaults
        >>> torch.set_printoptions(profile='default')
        >>> torch.tensor([1.12345])
        tensor([1.1235])
        >>> torch.arange(10)
        tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    """
    if profile is not None:
        if profile == "default":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80
        elif profile == "short":
            PRINT_OPTS.precision = 2
            PRINT_OPTS.threshold = 1000
            PRINT_OPTS.edgeitems = 2
            PRINT_OPTS.linewidth = 80
        elif profile == "full":
            PRINT_OPTS.precision = 4
            PRINT_OPTS.threshold = inf
            PRINT_OPTS.edgeitems = 3
            PRINT_OPTS.linewidth = 80

    if precision is not None:
        PRINT_OPTS.precision = precision
    if threshold is not None:
        PRINT_OPTS.threshold = threshold
    if edgeitems is not None:
        PRINT_OPTS.edgeitems = edgeitems
    if linewidth is not None:
        PRINT_OPTS.linewidth = linewidth
    PRINT_OPTS.sci_mode = sci_mode


def get_printoptions() -> Dict[str, Any]:
    r"""Gets the current options for printing, as a dictionary that
    can be passed as ``**kwargs`` to set_printoptions().
    """
    return dataclasses.asdict(PRINT_OPTS)
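
# Illustrative round trip (a sketch, not executed at import time): the keys of
# the returned dict match set_printoptions() keyword arguments, so options can
# be saved and restored without naming them individually.
#
#     >>> saved = torch.get_printoptions()   # {'precision': 4, 'threshold': 1000, ...}
#     >>> torch.set_printoptions(precision=10)
#     >>> torch.set_printoptions(**saved)    # restore everything at once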


@contextlib.contextmanager
def printoptions(**kwargs):
    r"""Context manager that temporarily changes the print options. Accepted
    arguments are the same as :func:`set_printoptions`."""
    old_kwargs = get_printoptions()
    set_printoptions(**kwargs)
    try:
        yield
    finally:
        set_printoptions(**old_kwargs)
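
# Usage sketch for the context manager above (output is illustrative): the
# try/finally guarantees the saved options are restored even if the body raises.
#
#     >>> with torch.printoptions(precision=2, threshold=5):
#     ...     print(torch.arange(10.0))
#     tensor([0., 1., 2.,  ..., 7., 8., 9.])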


def tensor_totype(t):
    # Widen to a float type that supports the range math below; MPS tensors
    # stay in float32 because the MPS backend has no float64.
    dtype = torch.float if t.is_mps else torch.double
    return t.to(dtype=dtype)
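
# Quick sanity sketch for the helper above: CPU/CUDA tensors are widened to
# float64 for the range computations in _Formatter, while MPS tensors stay in
# float32.
#
#     tensor_totype(torch.ones(2, dtype=torch.float16)).dtype  # torch.float64 (on CPU)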


class _Formatter:
    def __init__(self, tensor):
        self.floating_dtype = tensor.dtype.is_floating_point
        self.int_mode = True
        self.sci_mode = False
        self.max_width = 1

        with torch.no_grad():
            tensor_view = tensor.reshape(-1)

        if not self.floating_dtype:
            for value in tensor_view:
                value_str = f"{value}"
                self.max_width = max(self.max_width, len(value_str))
        else:
            nonzero_finite_vals = torch.masked_select(
                tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
            )

            if nonzero_finite_vals.numel() == 0:
                # no valid number, do nothing
                return

            # Convert to double (float on MPS) for easy range calculations.
            nonzero_finite_abs = tensor_totype(nonzero_finite_vals.abs())
            nonzero_finite_min = tensor_totype(nonzero_finite_abs.min())
            nonzero_finite_max = tensor_totype(nonzero_finite_abs.max())

            for value in nonzero_finite_vals:
                if value != torch.ceil(value):
                    self.int_mode = False
                    break

            if self.int_mode:
                # In int_mode all floats are integral; a trailing "." is added
                # when formatting, so reserve one extra character for it.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{value:.0f}"
                        self.max_width = max(self.max_width, len(value_str) + 1)
            else:
                # Check if scientific representation should be used.
                if (
                    nonzero_finite_max / nonzero_finite_min > 1000.0
                    or nonzero_finite_max > 1.0e8
                    or nonzero_finite_min < 1.0e-4
                ):
                    self.sci_mode = True
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}e}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))
                else:
                    for value in nonzero_finite_vals:
                        value_str = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
                        self.max_width = max(self.max_width, len(value_str))

        if PRINT_OPTS.sci_mode is not None:
            self.sci_mode = PRINT_OPTS.sci_mode

    def width(self):
        return self.max_width

    def format(self, value):
        if self.floating_dtype:
            if self.sci_mode:
                ret = f"{{:{self.max_width}.{PRINT_OPTS.precision}e}}".format(value)
            elif self.int_mode:
                ret = f"{value:.0f}"
                if not (math.isinf(value) or math.isnan(value)):
                    ret += "."
            else:
                ret = f"{{:.{PRINT_OPTS.precision}f}}".format(value)
        else:
            ret = f"{value}"
        return (self.max_width - len(ret)) * " " + ret
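
# A rough feel for the three formatting modes above (widths are illustrative):
#
#     _Formatter(torch.tensor([1.0, 2.0])).format(2.0)    # "2."          (int_mode)
#     _Formatter(torch.tensor([1e-5])).format(1e-5)       # "1.0000e-05"  (sci_mode)
#     _Formatter(torch.tensor([1.5, 2.25])).format(1.5)   # "1.5000"      (fixed point)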


def _scalar_str(self, formatter1, formatter2=None):
    if formatter2 is not None:
        real_str = _scalar_str(self.real, formatter1)
        imag_str = (_scalar_str(self.imag, formatter2) + "j").lstrip()
        # handles negative numbers, +0.0, -0.0
        if imag_str[0] == "+" or imag_str[0] == "-":
            return real_str + imag_str
        else:
            return real_str + "+" + imag_str
    else:
        return formatter1.format(self.item())
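
# The lstrip()/sign check above exists because the imaginary formatter pads
# with leading spaces and a leading "-" doubles as the separator; illustrative
# results look like "1.+2.j" and "1.-2.j".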


def _vector_str(self, indent, summarize, formatter1, formatter2=None):
    # length includes spaces and comma between elements
    element_length = formatter1.width() + 2
    if formatter2 is not None:
        # width of the imaginary part plus an extra "j" for complex values
        element_length += formatter2.width() + 1

    elements_per_line = max(
        1, int(math.floor((PRINT_OPTS.linewidth - indent) / (element_length)))
    )

    def _val_formatter(val, formatter1=formatter1, formatter2=formatter2):
        if formatter2 is not None:
            real_str = formatter1.format(val.real)
            imag_str = (formatter2.format(val.imag) + "j").lstrip()
            # handles negative numbers, +0.0, -0.0
            if imag_str[0] == "+" or imag_str[0] == "-":
                return real_str + imag_str
            else:
                return real_str + "+" + imag_str
        else:
            return formatter1.format(val)

    if summarize and not PRINT_OPTS.edgeitems:
        data = ["..."]
    elif summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        data = (
            [_val_formatter(val) for val in self[: PRINT_OPTS.edgeitems].tolist()]
            + [" ..."]
            + [_val_formatter(val) for val in self[-PRINT_OPTS.edgeitems :].tolist()]
        )
    else:
        data = [_val_formatter(val) for val in self.tolist()]

    data_lines = [
        data[i : i + elements_per_line] for i in range(0, len(data), elements_per_line)
    ]
    lines = [", ".join(line) for line in data_lines]
    return "[" + ("," + "\n" + " " * (indent + 1)).join(lines) + "]"


def _tensor_str_with_formatter(self, indent, summarize, formatter1, formatter2=None):
    dim = self.dim()

    if dim == 0:
        return _scalar_str(self, formatter1, formatter2)

    if dim == 1:
        return _vector_str(self, indent, summarize, formatter1, formatter2)

    if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems:
        slices = (
            [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(0, PRINT_OPTS.edgeitems)
            ]
            + ["..."]
            + [
                _tensor_str_with_formatter(
                    self[i], indent + 1, summarize, formatter1, formatter2
                )
                for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))
            ]
        )
    else:
        slices = [
            _tensor_str_with_formatter(
                self[i], indent + 1, summarize, formatter1, formatter2
            )
            for i in range(0, self.size(0))
        ]

    tensor_str = ("," + "\n" * (dim - 1) + " " * (indent + 1)).join(slices)
    return "[" + tensor_str + "]"


def _tensor_str(self, indent):
    if self.numel() == 0:
        return "[]"

    if self.has_names():
        # Some formatting codepaths don't fully support named tensors, so
        # send an unnamed tensor through as a workaround.
        self = self.rename(None)

    summarize = self.numel() > PRINT_OPTS.threshold

    if self._is_zerotensor():
        self = self.clone()

    # handle the negative bit
    if self.is_neg():
        self = self.resolve_neg()

    # low-bit float dtypes are upcast to float32 for formatting
    if self.dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float8_e5m2,
        torch.float8_e5m2fnuz,
        torch.float8_e4m3fn,
        torch.float8_e4m3fnuz,
    ]:
        self = self.float()

    if self.dtype is torch.complex32:
        self = self.cfloat()

    if self.dtype.is_complex:
        # handle the conjugate bit
        self = self.resolve_conj()
        real_formatter = _Formatter(
            get_summarized_data(self.real) if summarize else self.real
        )
        imag_formatter = _Formatter(
            get_summarized_data(self.imag) if summarize else self.imag
        )
        return _tensor_str_with_formatter(
            self, indent, summarize, real_formatter, imag_formatter
        )
    else:
        formatter = _Formatter(get_summarized_data(self) if summarize else self)
        return _tensor_str_with_formatter(self, indent, summarize, formatter)


def _add_suffixes(tensor_str, suffixes, indent, force_newline):
    tensor_strs = [tensor_str]
    last_line_len = len(tensor_str) - tensor_str.rfind("\n") + 1
    for suffix in suffixes:
        suffix_len = len(suffix)
        if force_newline or last_line_len + suffix_len + 2 > PRINT_OPTS.linewidth:
            tensor_strs.append(",\n" + " " * indent + suffix)
            last_line_len = indent + suffix_len
            force_newline = False
        else:
            tensor_strs.append(", " + suffix)
            last_line_len += suffix_len + 2
    tensor_strs.append(")")
    return "".join(tensor_strs)


def get_summarized_data(self):
    dim = self.dim()
    if dim == 0:
        return self
    if dim == 1:
        if self.size(0) > 2 * PRINT_OPTS.edgeitems:
            return torch.cat(
                (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :])
            )
        else:
            return self
    if not PRINT_OPTS.edgeitems:
        return self.new_empty([0] * self.dim())
    elif self.size(0) > 2 * PRINT_OPTS.edgeitems:
        start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)]
        end = [self[i] for i in range(len(self) - PRINT_OPTS.edgeitems, len(self))]
        return torch.stack([get_summarized_data(x) for x in (start + end)])
    else:
        return torch.stack([get_summarized_data(x) for x in self])
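
# Sketch of the recursion above: only elements that can actually appear in the
# printout survive, e.g. a (100, 100) tensor with edgeitems=3 is reduced to a
# (6, 6) corner block before any formatting work happens.
#
#     get_summarized_data(torch.arange(10.0))  # tensor([0., 1., 2., 7., 8., 9.])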


def _str_intern(inp, *, tensor_contents=None):
    if torch._C._functorch.is_functorch_wrapped_tensor(inp):
        return _functorch_wrapper_str_intern(inp, tensor_contents=tensor_contents)
    is_plain_tensor = type(inp) is torch.Tensor or type(inp) is torch.nn.Parameter
    if inp.is_nested:
        prefix = "nested_tensor("
    elif is_plain_tensor:
        prefix = "tensor("
    else:
        prefix = f"{type(inp).__name__}("
    indent = len(prefix)
    suffixes = []
    custom_contents_provided = tensor_contents is not None
    if custom_contents_provided:
        tensor_str = tensor_contents

    # Extract the primal value so forward-mode AD is effectively disabled
    # within this function.
    self, tangent = torch.autograd.forward_ad.unpack_dual(inp)

    # Only print the device when it doesn't match the default tensor type's
    # device (and always for mps and non-current cuda devices).
    if (
        self.device.type != torch._C._get_default_device()
        or (
            self.device.type == "cuda"
            and torch.cuda.current_device() != self.device.index
        )
        or (self.device.type == "mps")
    ):
        suffixes.append("device='" + str(self.device) + "'")

    # Tensor printing performs tensor operations (slicing, indexing, etc.); on
    # xla/lazy/ipu/mtia these trigger compilations, so copy to cpu first.
    if self.device.type in ["xla", "lazy", "ipu", "mtia"]:
        self = self.to("cpu")

    _default_complex_dtype = (
        torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
    )
    has_default_dtype = self.dtype in (
        torch.get_default_dtype(),
        _default_complex_dtype,
        torch.int64,
        torch.bool,
    )
    if self.is_sparse:
        suffixes.append("size=" + str(tuple(self.shape)))
        from torch._subclasses.fake_tensor import FakeTensor

        is_meta = self.is_meta or isinstance(self, FakeTensor)
        suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            indices_prefix = "indices=tensor("
            indices = self._indices().detach()
            if is_meta:
                indices_str = "..."
            else:
                indices_str = _tensor_str(indices, indent + len(indices_prefix))
            if indices.numel() == 0 or is_meta:
                indices_str += ", size=" + str(tuple(indices.shape))
            values_prefix = "values=tensor("
            values = self._values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                indices_prefix
                + indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        from torch._subclasses.fake_tensor import FakeTensor

        suffixes.append("size=" + str(tuple(self.shape)))
        is_meta = self.is_meta or isinstance(self, FakeTensor)
        suffixes.append("nnz=" + str(self._nnz()))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        if not custom_contents_provided:
            compressed_indices_method, plain_indices_method = {
                torch.sparse_csr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_csc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
                torch.sparse_bsr: (torch.Tensor.crow_indices, torch.Tensor.col_indices),
                torch.sparse_bsc: (torch.Tensor.ccol_indices, torch.Tensor.row_indices),
            }[self.layout]
            if self.layout in {torch.sparse_csr, torch.sparse_bsr}:
                cdimname, pdimname = "row", "column"
            else:
                cdimname, pdimname = "column", "row"
            compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
            compressed_indices = compressed_indices_method(self).detach()
            if is_meta:
                compressed_indices_str = "..."
            else:
                compressed_indices_str = _tensor_str(
                    compressed_indices, indent + len(compressed_indices_prefix)
                )
            if compressed_indices.numel() == 0 or is_meta:
                compressed_indices_str += ", size=" + str(
                    tuple(compressed_indices.shape)
                )
            plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
            plain_indices = plain_indices_method(self).detach()
            if is_meta:
                plain_indices_str = "..."
            else:
                plain_indices_str = _tensor_str(
                    plain_indices, indent + len(plain_indices_prefix)
                )
            if plain_indices.numel() == 0 or is_meta:
                plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
            values_prefix = "values=tensor("
            values = self.values().detach()
            if is_meta:
                values_str = "..."
            else:
                values_str = _tensor_str(values, indent + len(values_prefix))
            if values.numel() == 0 or is_meta:
                values_str += ", size=" + str(tuple(values.shape))
            tensor_str = (
                compressed_indices_prefix
                + compressed_indices_str
                + "),\n"
                + " " * indent
                + plain_indices_prefix
                + plain_indices_str
                + "),\n"
                + " " * indent
                + values_prefix
                + values_str
                + ")"
            )
    elif self.is_quantized:
        suffixes.append("size=" + str(tuple(self.shape)))
        if not has_default_dtype:
            suffixes.append("dtype=" + str(self.dtype))
        suffixes.append("quantization_scheme=" + str(self.qscheme()))
        if (
            self.qscheme() == torch.per_tensor_affine
            or self.qscheme() == torch.per_tensor_symmetric
        ):
            suffixes.append("scale=" + str(self.q_scale()))
            suffixes.append("zero_point=" + str(self.q_zero_point()))
        elif (
            self.qscheme() == torch.per_channel_affine
            or self.qscheme() == torch.per_channel_symmetric
            or self.qscheme() == torch.per_channel_affine_float_qparams
        ):
            suffixes.append("scale=" + str(self.q_per_channel_scales()))
            suffixes.append("zero_point=" + str(self.q_per_channel_zero_points()))
            suffixes.append("axis=" + str(self.q_per_channel_axis()))
        if not custom_contents_provided:
            tensor_str = _tensor_str(self.dequantize(), indent)
    elif self.is_nested:
        if not custom_contents_provided:

            def indented_str(s, indent):
                return "\n".join(f"  {line}" for line in s.split("\n"))

            strs = ",\n".join(
                indented_str(str(t), indent + 1)
                for t in torch.ops.aten.unbind.int(self, 0)
            )
            tensor_str = f"[\n{strs}\n]"
    elif torch._is_functional_tensor(self):
        prefix = "_to_functional_tensor("
        tensor_str = repr(torch._from_functional_tensor(self))
    else:
        # Circular import problem, so we import it here
        from torch._subclasses.fake_tensor import FakeTensor

        if self.is_meta or isinstance(self, FakeTensor):
            suffixes.append("size=" + str(tuple(self.shape)))
            if self.dtype != torch.get_default_dtype():
                suffixes.append("dtype=" + str(self.dtype))
            if not custom_contents_provided:
                tensor_str = "..."
        else:
            if self.numel() == 0 and not self.is_sparse:
                # Explicitly print the shape if it is not (0,), to match NumPy behavior
                if self.dim() != 1:
                    suffixes.append("size=" + str(tuple(self.shape)))

                # In an empty tensor, there are no elements to infer if the dtype
                # should be int64, so it must be shown explicitly.
                if self.dtype != torch.get_default_dtype():
                    suffixes.append("dtype=" + str(self.dtype))
                if not custom_contents_provided:
                    tensor_str = "[]"
            else:
                if not PRINT_OPTS.edgeitems:
                    suffixes.append("size=" + str(tuple(self.shape)))

                if not has_default_dtype:
                    suffixes.append("dtype=" + str(self.dtype))

                if not custom_contents_provided:
                    if self.layout != torch.strided:
                        tensor_str = _tensor_str(self.to_dense(), indent)
                    else:
                        tensor_str = _tensor_str(self, indent)

    if self.layout != torch.strided:
        suffixes.append("layout=" + str(self.layout))

    # Use inp here to get the original grad_fn and not the one generated by
    # the forward grad unpacking.
    grad_fn_name = None
    try:
        grad_fn = inp.grad_fn
    except RuntimeError:
        # Accessing the grad_fn calls rebasing logic which would raise an
        # error for a view created in no-grad mode and modified in-place in
        # no-grad mode. See https://github.com/pytorch/pytorch/issues/99968
        grad_fn_name = "Invalid"

    if grad_fn_name is None and grad_fn is not None:
        grad_fn_name = type(grad_fn).__name__
        if grad_fn_name == "CppFunction":
            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]

    if grad_fn_name is not None:
        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")

    if self.has_names():
        suffixes.append(f"names={self.names}")

    if tangent is not None:
        suffixes.append(f"tangent={tangent}")

    string_repr = _add_suffixes(
        prefix + tensor_str, suffixes, indent, force_newline=self.is_sparse
    )

    # Parameter subclasses that are not plain tensors still get the
    # Parameter(...) wrapper around their repr.
    if isinstance(self, torch.nn.Parameter) and not is_plain_tensor:
        string_repr = f"Parameter({string_repr})"

    return string_repr


def _functorch_wrapper_str_intern(tensor, *, tensor_contents=None):
    level = torch._C._functorch.maybe_get_level(tensor)
    assert level != -1

    if torch._C._functorch.is_functionaltensor(tensor):
        # We're unwrapping the FunctionalTensorWrapper, so make sure it is
        # up to date first.
        torch._sync(tensor)

    value = torch._C._functorch.get_unwrapped(tensor)
    value_repr = repr(value)

    indented_value_repr = textwrap.indent(value_repr, " " * 4)
    if torch._C._functorch.is_batchedtensor(tensor):
        bdim = torch._C._functorch.maybe_get_bdim(tensor)
        assert bdim != -1
        return (
            f"BatchedTensor(lvl={level}, bdim={bdim}, value=\n"
            f"{indented_value_repr}\n"
            f")"
        )
    if torch._C._functorch.is_gradtrackingtensor(tensor):
        return (
            f"GradTrackingTensor(lvl={level}, value=\n" f"{indented_value_repr}\n" f")"
        )
    if torch._C._functorch.is_functionaltensor(tensor):
        return f"FunctionalTensor(lvl={level}, value=\\\n{value_repr})"

    raise ValueError("We don't know how to print this, please file us an issue")


def _str(self, *, tensor_contents=None):
    with torch.no_grad(), torch.utils._python_dispatch._disable_current_modes():
        guard = torch._C._DisableFuncTorch()  # keep functorch disabled while formatting
        return _str_intern(self, tensor_contents=tensor_contents)
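
# Wiring sketch (simplified; the real binding lives in torch/_tensor.py, which
# this module does not control): Tensor.__repr__ funnels into _str, so the
# guards above ensure printing never records autograd history, dispatch modes,
# or functorch transforms.
#
#     class Tensor:
#         def __repr__(self, *, tensor_contents=None):
#             return torch._tensor_str._str(self, tensor_contents=tensor_contents)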