pytorch

Форк
0
/
_conversions.py 
119 строк · 3.5 Кб
1
# mypy: allow-untyped-defs
2
import torch
3
import torch._prims_common as utils
4

5
# Utilities should come BEFORE this import
6
from torch._decomp import register_decomposition
7
from torch._prims_common import TensorLikeType
8
from torch._prims_common.wrappers import out_wrapper
9
from torch._refs import _broadcast_shapes
10

11

12
# Data conversion references.
13
#
14
# Note: this module breaks the usual _refs to torch naming scheme where
15
# _refs.foo.bar is a ref for torch.foo.bar.  The following definitions are not
16
# part of _refs/__init__.py to avoid name clashes with Python builtin types
17
# (like int).
18

19
__all__ = [
20
    # dtypes
21
    "bfloat16",
22
    "bool",
23
    "byte",
24
    "cdouble",
25
    "cfloat",
26
    "chalf",
27
    "char",
28
    "double",
29
    "float",
30
    "half",
31
    "int",
32
    "long",
33
    "short",
34
    # misc
35
    "complex",
36
    "polar",
37
]
38

39

40
def _make_conversion_method(name: str, dtype: torch.dtype):
41
    def fn(
42
        self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
43
    ) -> TensorLikeType:
44
        return self.to(dtype, memory_format=memory_format)  # type: ignore[call-overload]
45

46
    fn.__name__ = name
47
    return fn
48

49

50
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)
51

52
bool = _make_conversion_method("bool", torch.bool)
53

54
byte = _make_conversion_method("byte", torch.uint8)
55

56
cdouble = _make_conversion_method("cdouble", torch.cdouble)
57

58
cfloat = _make_conversion_method("cfloat", torch.cfloat)
59

60
chalf = _make_conversion_method("chalf", torch.complex32)
61

62
char = _make_conversion_method("char", torch.int8)
63

64
double = _make_conversion_method("double", torch.double)
65

66
float = _make_conversion_method("float", torch.float)
67

68
half = _make_conversion_method("half", torch.half)
69

70
int = _make_conversion_method("int", torch.int)
71

72
long = _make_conversion_method("long", torch.long)
73

74
short = _make_conversion_method("short", torch.short)
75

76

77
@register_decomposition(torch._ops.ops.aten.complex)
78
# Note: complex has type promotion tests disabled due to different semantics.
79
# exact_dtype is for compat with complex_check_dtype from core.
80
@out_wrapper(exact_dtype=True)
81
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
82
    allowed_dtypes = (torch.float32, torch.float64, torch.float16)
83
    torch._check(
84
        real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
85
        lambda: (
86
            f"Expected both inputs to be Half, Float or Double tensors but got "
87
            f"{real.dtype} and {imag.dtype}"
88
        ),
89
    )
90
    torch._check(
91
        real.dtype == imag.dtype,
92
        lambda: (
93
            f"Expected object of scalar type {real.dtype} but got "
94
            f"scalar type {imag.dtype} for second argument"
95
        ),
96
    )
97
    result_dtype = utils.corresponding_complex_dtype(real.dtype)  # type: ignore[arg-type]
98
    common_shape = _broadcast_shapes(real.shape, imag.shape)
99
    result = real.new_empty(
100
        common_shape,
101
        dtype=result_dtype,
102
        layout=real.layout,
103
        device=real.device,
104
        # pin_memory=real.is_pinned(),  # NYI
105
    )
106
    result.real = real
107
    result.imag = imag
108
    return result
109

110

111
@register_decomposition(torch._ops.ops.aten.polar)
112
# Note: polar has type promotion tests disabled due to different semantics.
113
# exact_dtype is for compat with complex_check_dtype from core.
114
@out_wrapper(exact_dtype=True)
115
def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
116
    result = torch.complex(abs, angle)
117
    result.real = abs * torch.cos(angle)
118
    result.imag = abs * torch.sin(angle)
119
    return result
120

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.