pytorch
1import torch.library
2from torch import Tensor
3from torch.autograd import Function
4
5_test_lib_def = torch.library.Library("_inductor_test", "DEF")
6_test_lib_def.define("realize(Tensor self) -> Tensor", tags=torch.Tag.pt2_compliant_tag)
7
8_test_lib_impl = torch.library.Library("_inductor_test", "IMPL")
9for dispatch_key in ("CPU", "CUDA", "Meta"):
10_test_lib_impl.impl("realize", lambda x: x.clone(), dispatch_key)
11
12
class Realize(Function):
    """Differentiable wrapper around ``torch.ops._inductor_test.realize``.

    The forward pass routes the input through the custom op (registered in
    this module with a clone kernel). The backward pass treats the op as the
    identity and hands the incoming gradient back unchanged.
    """

    @staticmethod
    def forward(ctx, x):
        # No state is saved on ``ctx``: backward does not need any.
        realize_op = torch.ops._inductor_test.realize
        return realize_op(x)

    @staticmethod
    def backward(ctx, grad_output):
        # d(copy of x)/dx is the identity, so the gradient passes straight through.
        grad_x = grad_output
        return grad_x
21
22
def realize(x: Tensor) -> Tensor:
    """Run ``x`` through the ``_inductor_test::realize`` op via autograd.

    Dispatches through :class:`Realize`, so the result is differentiable and
    the gradient with respect to ``x`` is passed through unchanged.
    """
    apply_realize = Realize.apply
    return apply_realize(x)
25