pytorch

Форк
0
65 строк · 2.1 Кб
1
# mypy: allow-untyped-defs
2
import operator
3
from functools import reduce
4
from typing_extensions import deprecated
5

6
import torch
7
import torch._utils
8
from torch.autograd.function import Function
9

10

11
class Type(Function):
12
    @staticmethod
13
    @deprecated(
14
        "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, "
15
        "please use `torch.tensor.to(dtype=dtype)` instead.",
16
        category=FutureWarning,
17
    )
18
    def forward(ctx, i, dest_type):
19
        ctx.input_type = type(i)
20
        ctx.input_device = -1 if not i.is_cuda else i.get_device()
21
        return i.type(dest_type)
22

23
    @staticmethod
24
    def backward(ctx, grad_output):
25
        if ctx.input_device == -1:
26
            return grad_output.type(ctx.input_type), None
27
        else:
28
            with torch.cuda.device(ctx.input_device):
29
                return grad_output.type(ctx.input_type), None
30

31

32
# TODO: deprecate this
33
class Resize(Function):
34
    @staticmethod
35
    def forward(ctx, tensor, sizes):
36
        ctx.sizes = sizes
37
        ctx.numel = reduce(operator.mul, sizes, 1)
38
        if tensor.numel() != ctx.numel:
39
            raise RuntimeError(
40
                (
41
                    "requested resize to {} ({} elements in total), "
42
                    "but the given tensor has a size of {} ({} elements). "
43
                    "autograd's resize can only change the shape of a given "
44
                    "tensor, while preserving the number of elements. "
45
                ).format(
46
                    "x".join(map(str, sizes)),
47
                    ctx.numel,
48
                    "x".join(map(str, tensor.size())),
49
                    tensor.numel(),
50
                )
51
            )
52
        ctx.input_sizes = tensor.size()
53
        if tensor.is_quantized:
54
            tensor.copy_(tensor)
55
            return tensor.contiguous().view(*sizes)
56
        if tensor.is_contiguous():
57
            result = tensor.new(tensor).contiguous().view(*sizes)
58
            return result
59
        else:
60
            return tensor.contiguous().view(*sizes)
61

62
    @staticmethod
63
    def backward(ctx, grad_output):
64
        assert grad_output.numel() == ctx.numel
65
        return grad_output.contiguous().view(ctx.input_sizes), None
66

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.