pytorch

Форк
0
/
test_operators.py 
24 строки · 683.0 Байт
1
import torch.library
2
from torch import Tensor
3
from torch.autograd import Function
4

5
_test_lib_def = torch.library.Library("_inductor_test", "DEF")
6
_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag)
7

8
_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
9
for dispatch_key in ("CPU", "CUDA", "Meta"):
10
    _test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
11

12

13
class Realize(Function):
14
    @staticmethod
15
    def forward(ctx, x):
16
        return torch.ops._inductor_test.realize(x)
17

18
    @staticmethod
19
    def backward(ctx, grad_output):
20
        return grad_output
21

22

23
def realize(x: Tensor) -> Tensor:
24
    return Realize.apply(x)
25

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.