pytorch

Форк
0
63 строки · 2.1 Кб
1
import operator
2
import warnings
3
from functools import reduce
4

5
import torch
6
import torch._utils
7
from ..function import Function
8

9

10
class Type(Function):
11
    @staticmethod
12
    def forward(ctx, i, dest_type):
13
        warnings.warn(
14
            "torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use "
15
            "torch.tensor.to(dtype=dtype) instead."
16
        )
17
        ctx.input_type = type(i)
18
        ctx.input_device = -1 if not i.is_cuda else i.get_device()
19
        return i.type(dest_type)
20

21
    @staticmethod
22
    def backward(ctx, grad_output):
23
        if ctx.input_device == -1:
24
            return grad_output.type(ctx.input_type), None
25
        else:
26
            with torch.cuda.device(ctx.input_device):
27
                return grad_output.type(ctx.input_type), None
28

29

30
# TODO: deprecate this
31
class Resize(Function):
32
    @staticmethod
33
    def forward(ctx, tensor, sizes):
34
        ctx.sizes = sizes
35
        ctx.numel = reduce(operator.mul, sizes, 1)
36
        if tensor.numel() != ctx.numel:
37
            raise RuntimeError(
38
                (
39
                    "requested resize to {} ({} elements in total), "
40
                    "but the given tensor has a size of {} ({} elements). "
41
                    "autograd's resize can only change the shape of a given "
42
                    "tensor, while preserving the number of elements. "
43
                ).format(
44
                    "x".join(map(str, sizes)),
45
                    ctx.numel,
46
                    "x".join(map(str, tensor.size())),
47
                    tensor.numel(),
48
                )
49
            )
50
        ctx.input_sizes = tensor.size()
51
        if tensor.is_quantized:
52
            tensor.copy_(tensor)
53
            return tensor.contiguous().view(*sizes)
54
        if tensor.is_contiguous():
55
            result = tensor.new(tensor).contiguous().view(*sizes)
56
            return result
57
        else:
58
            return tensor.contiguous().view(*sizes)
59

60
    @staticmethod
61
    def backward(ctx, grad_output):
62
        assert grad_output.numel() == ctx.numel
63
        return grad_output.contiguous().view(ctx.input_sizes), None
64

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.