pytorch

Форк
0
/
_binary_ufuncs_impl.py 
85 строк · 1.8 Кб
1
# mypy: ignore-errors
2

3
"""Export torch work functions for binary ufuncs, rename/tweak to match numpy.
4
This listing is further exported to public symbols in the `torch._numpy/_ufuncs.py` module.
5
"""
6

7
import torch
8
from torch import (  # noqa: F401
9
    add,
10
    arctan2,
11
    bitwise_and,
12
    bitwise_left_shift as left_shift,
13
    bitwise_or,
14
    bitwise_right_shift as right_shift,
15
    bitwise_xor,
16
    copysign,
17
    divide,
18
    eq as equal,
19
    float_power,
20
    floor_divide,
21
    fmax,
22
    fmin,
23
    fmod,
24
    gcd,
25
    greater,
26
    greater_equal,
27
    heaviside,
28
    hypot,
29
    lcm,
30
    ldexp,
31
    less,
32
    less_equal,
33
    logaddexp,
34
    logaddexp2,
35
    logical_and,
36
    logical_or,
37
    logical_xor,
38
    maximum,
39
    minimum,
40
    multiply,
41
    nextafter,
42
    not_equal,
43
    pow as power,
44
    remainder,
45
    remainder as mod,
46
    subtract,
47
    true_divide,
48
)
49

50
from . import _dtypes_impl, _util
51

52

53
# work around torch limitations w.r.t. numpy
54
def matmul(x, y):
55
    # work around:
56
    #  - RuntimeError: expected scalar type Int but found Double
57
    #  - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
58
    #  - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
59
    dtype = _dtypes_impl.result_type_impl(x, y)
60
    is_bool = dtype == torch.bool
61
    is_half = (x.dtype == torch.float16 or y.dtype == torch.float16) and (
62
        x.is_cpu or y.is_cpu
63
    )
64

65
    work_dtype = dtype
66
    if is_bool:
67
        work_dtype = torch.uint8
68
    if is_half:
69
        work_dtype = torch.float32
70

71
    x = _util.cast_if_needed(x, work_dtype)
72
    y = _util.cast_if_needed(y, work_dtype)
73

74
    result = torch.matmul(x, y)
75

76
    if work_dtype != dtype:
77
        result = result.to(dtype)
78

79
    return result
80

81

82
# a stub implementation of divmod, should be improved after
83
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
84
def divmod(x, y):
85
    return x // y, x % y
86

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.