3
"""Export torch work functions for binary ufuncs, rename/tweak to match numpy.
4
This listing is further exported to public symbols in the `torch._numpy/_ufuncs.py` module.
8
from torch import ( # noqa: F401
12
bitwise_left_shift as left_shift,
14
bitwise_right_shift as right_shift,
50
from . import _dtypes_impl, _util
53
# work around torch limitations w.r.t. numpy
56
# - RuntimeError: expected scalar type Int but found Double
57
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
58
# - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
59
dtype = _dtypes_impl.result_type_impl(x, y)
60
is_bool = dtype == torch.bool
61
is_half = (x.dtype == torch.float16 or y.dtype == torch.float16) and (
67
work_dtype = torch.uint8
69
work_dtype = torch.float32
71
x = _util.cast_if_needed(x, work_dtype)
72
y = _util.cast_if_needed(y, work_dtype)
74
result = torch.matmul(x, y)
76
if work_dtype != dtype:
77
result = result.to(dtype)
82
# a stub implementation of divmod, should be improved after
83
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch