# pytorch (original listing: 33 lines, 804.0 bytes)
import torch

# Repetition counts exported by this module. Presumably NUM_REPEATS is the
# number of calls per timed sample and NUM_REPEAT_OF_REPEATS the number of
# samples -- TODO confirm against the benchmark drivers that import these.
NUM_REPEATS = 1_000
NUM_REPEAT_OF_REPEATS = 1_000


class SubTensor(torch.Tensor):
    """Plain ``torch.Tensor`` subclass that adds no behaviour of its own."""


class WithTorchFunction:
    """Non-tensor wrapper that overrides torch functions via ``__torch_function__``.

    Holds a plain ``torch.Tensor`` in ``_tensor`` and does not subclass
    ``torch.Tensor``. Torch functions called with an instance among their
    arguments are routed to :meth:`__torch_function__`.
    """

    def __init__(self, data, requires_grad=False):
        """Wrap *data*.

        Args:
            data: A ``torch.Tensor`` (stored as-is, not copied) or any value
                accepted by ``torch.tensor``.
            requires_grad: Forwarded to ``torch.tensor`` when *data* is not
                already a tensor; ignored when *data* is a tensor.
        """
        if isinstance(data, torch.Tensor):
            # Reuse the caller's tensor directly; requires_grad is not applied
            # here, so the caller's tensor is never mutated.
            self._tensor = data
            return

        self._tensor = torch.tensor(data, requires_grad=requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Unwrap wrapped arguments, call *func*, and re-wrap tensor results.

        Args:
            func: The torch function being overridden (e.g. ``torch.add``).
            types: Types of the overloaded arguments (not used here).
            args: Positional arguments of the original call.
            kwargs: Keyword arguments of the original call, or ``None``.

        Returns:
            An instance of ``cls`` wrapping the result when *func* returns a
            ``torch.Tensor``; otherwise *func*'s result unchanged.
        """
        if kwargs is None:
            kwargs = {}

        def unwrap(value):
            # Replace wrappers (also inside lists/tuples, e.g. torch.cat
            # inputs) by their inner tensor so the inner call sees plain
            # tensors and does not dispatch back into this method.
            if isinstance(value, WithTorchFunction):
                return value._tensor
            if isinstance(value, list):
                return [unwrap(v) for v in value]
            if isinstance(value, tuple):
                return tuple(unwrap(v) for v in value)
            return value

        plain_args = tuple(unwrap(a) for a in args)
        plain_kwargs = {key: unwrap(val) for key, val in kwargs.items()}
        # Honour the requested function and its keyword arguments (e.g.
        # ``alpha=``) instead of assuming a two-operand addition.
        result = func(*plain_args, **plain_kwargs)
        return cls(result) if isinstance(result, torch.Tensor) else result


class SubWithTorchFunction(torch.Tensor):
    """Tensor subclass that defines ``__torch_function__`` only to forward it.

    Every intercepted call is handed straight to the default
    ``torch.Tensor.__torch_function__`` implementation.
    """

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Delegate to ``torch.Tensor.__torch_function__``, passing ``{}`` for missing kwargs."""
        call_kwargs = {} if kwargs is None else kwargs
        return super().__torch_function__(func, types, args, call_kwargs)