3
import torch._prims_common as utils
6
from torch._decomp import register_decomposition
7
from torch._prims_common import TensorLikeType
8
from torch._prims_common.wrappers import out_wrapper
9
from torch._refs import _broadcast_shapes
40
def _make_conversion_method(name: str, dtype: torch.dtype):
    """Build a tensor conversion function that casts its input to *dtype*.

    The generated callable has the signature
    ``fn(self, memory_format=torch.preserve_format)`` and returns
    ``self.to(dtype, memory_format=memory_format)``.

    Args:
        name: Name given to the generated function (used for ``__name__`` and
            ``__qualname__`` so each generated function is identifiable when
            introspected or shown in tracebacks).
        dtype: Target dtype the generated function converts to.

    Returns:
        The generated conversion function.
    """

    def fn(
        self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
    ) -> TensorLikeType:
        return self.to(dtype, memory_format=memory_format)

    # Without these, every generated function would report itself as
    # ``_make_conversion_method.<locals>.fn``.
    fn.__name__ = name
    fn.__qualname__ = name
    fn.__doc__ = f"Return ``self`` converted to {dtype}."
    return fn
50
# Dtype conversion functions named after the ``torch.Tensor`` conversion
# methods (``Tensor.half``, ``Tensor.long``, ...). Some of these names
# (``bool``, ``float``, ``int``) shadow Python builtins at module scope.

# Floating point.
half = _make_conversion_method("half", torch.float16)
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)
float = _make_conversion_method("float", torch.float32)
double = _make_conversion_method("double", torch.float64)

# Complex.
chalf = _make_conversion_method("chalf", torch.complex32)
cfloat = _make_conversion_method("cfloat", torch.complex64)
cdouble = _make_conversion_method("cdouble", torch.complex128)

# Integer.
char = _make_conversion_method("char", torch.int8)
byte = _make_conversion_method("byte", torch.uint8)
short = _make_conversion_method("short", torch.int16)
int = _make_conversion_method("int", torch.int32)
long = _make_conversion_method("long", torch.int64)

# Boolean.
bool = _make_conversion_method("bool", torch.bool)
77
@register_decomposition(torch._ops.ops.aten.complex)
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
    """Reference implementation of ``torch.complex``.

    Builds a complex tensor whose real part is *real* and whose imaginary
    part is *imag*, with the inputs broadcast to a common shape.

    Args:
        real: Real part; must be float16, float32 or float64.
        imag: Imaginary part; must have the same dtype as *real*.

    Returns:
        A new tensor whose dtype is ``utils.corresponding_complex_dtype(real.dtype)``
        and whose shape is the broadcast of ``real.shape`` and ``imag.shape``,
        allocated with ``real``'s layout and device.

    Raises:
        RuntimeError: if either input is not Half/Float/Double, or if the two
            input dtypes differ.
    """
    allowed_dtypes = (torch.float32, torch.float64, torch.float16)
    if not (real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes):
        raise RuntimeError(
            f"Expected both inputs to be Half, Float or Double tensors but got "
            f"{real.dtype} and {imag.dtype}"
        )
    if real.dtype != imag.dtype:
        raise RuntimeError(
            f"Expected object of scalar type {real.dtype} but got "
            f"scalar type {imag.dtype} for second argument"
        )

    result_dtype = utils.corresponding_complex_dtype(real.dtype)
    common_shape = _broadcast_shapes(real.shape, imag.shape)
    result = real.new_empty(
        common_shape,
        dtype=result_dtype,
        layout=real.layout,
        device=real.device,
    )
    # Assigning through the ``real``/``imag`` views copies each input into the
    # freshly allocated complex storage (copy semantics broadcast the source
    # to ``common_shape``).
    result.real = real
    result.imag = imag
    return result
111
@register_decomposition(torch._ops.ops.aten.polar)
114
@out_wrapper(exact_dtype=True)
115
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
116
result = torch.complex(abs, angle)
117
result.real = abs * torch.cos(angle)
118
result.imag = abs * torch.sin(angle)