pytorch

Форк
0
33 строки · 804.0 Байт
1
import torch
2

3
NUM_REPEATS = 1000
4
NUM_REPEAT_OF_REPEATS = 1000
5

6

7
class SubTensor(torch.Tensor):
8
    pass
9

10

11
class WithTorchFunction:
12
    def __init__(self, data, requires_grad=False):
13
        if isinstance(data, torch.Tensor):
14
            self._tensor = data
15
            return
16

17
        self._tensor = torch.tensor(data, requires_grad=requires_grad)
18

19
    @classmethod
20
    def __torch_function__(cls, func, types, args=(), kwargs=None):
21
        if kwargs is None:
22
            kwargs = {}
23

24
        return WithTorchFunction(args[0]._tensor + args[1]._tensor)
25

26

27
class SubWithTorchFunction(torch.Tensor):
28
    @classmethod
29
    def __torch_function__(cls, func, types, args=(), kwargs=None):
30
        if kwargs is None:
31
            kwargs = {}
32

33
        return super().__torch_function__(func, types, args, kwargs)
34

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.