pytorch
65 lines · 2.1 KB
1# mypy: allow-untyped-defs
2import operator
3from functools import reduce
4from typing_extensions import deprecated
5
6import torch
7import torch._utils
8from torch.autograd.function import Function
9
10
class Type(Function):
    """Autograd-aware dtype cast: converts the input to ``dest_type`` in
    forward and casts the incoming gradient back to the input's original
    tensor class (and CUDA device) in backward."""

    @staticmethod
    @deprecated(
        "`torch.autograd._functions.Type` is deprecated as of PyTorch 2.1, "
        "please use `torch.tensor.to(dtype=dtype)` instead.",
        category=FutureWarning,
    )
    def forward(ctx, i, dest_type):
        # Stash what backward needs to undo the cast: the concrete tensor
        # class of the input and, for CUDA tensors, the device it lived on
        # (-1 is the sentinel for "not CUDA").
        ctx.input_type = type(i)
        ctx.input_device = i.get_device() if i.is_cuda else -1
        return i.type(dest_type)

    @staticmethod
    def backward(ctx, grad_output):
        # CPU input: just cast the gradient back to the original type.
        if ctx.input_device == -1:
            return grad_output.type(ctx.input_type), None
        # CUDA input: make the original device current so the cast lands
        # on the same device the input came from.
        with torch.cuda.device(ctx.input_device):
            return grad_output.type(ctx.input_type), None
30
31
# TODO: deprecate this
class Resize(Function):
    """Autograd-aware reshape: views the input with shape ``sizes`` in
    forward (element count must be preserved) and views the gradient back
    to the input's original shape in backward."""

    @staticmethod
    def forward(ctx, tensor, sizes):
        ctx.sizes = sizes
        # Element count implied by the requested shape; the empty-shape
        # product is 1, matching a 0-d tensor.
        ctx.numel = reduce(operator.mul, sizes, 1)
        # A resize may only reinterpret the shape, never grow or shrink
        # the tensor — reject any mismatch up front.
        if tensor.numel() != ctx.numel:
            raise RuntimeError(
                (
                    "requested resize to {} ({} elements in total), "
                    "but the given tensor has a size of {} ({} elements). "
                    "autograd's resize can only change the shape of a given "
                    "tensor, while preserving the number of elements. "
                ).format(
                    "x".join(map(str, sizes)),
                    ctx.numel,
                    "x".join(map(str, tensor.size())),
                    tensor.numel(),
                )
            )
        ctx.input_sizes = tensor.size()
        if tensor.is_quantized:
            # Quantized tensors take a separate path; the self-copy matches
            # the original implementation — presumably to materialize the
            # tensor before viewing (TODO confirm).
            tensor.copy_(tensor)
            return tensor.contiguous().view(*sizes)
        if not tensor.is_contiguous():
            return tensor.contiguous().view(*sizes)
        # Contiguous path: build the result from a fresh tensor of the
        # same type before viewing.
        return tensor.new(tensor).contiguous().view(*sizes)

    @staticmethod
    def backward(ctx, grad_output):
        # Forward guaranteed the element counts match, so the gradient can
        # simply be viewed back to the input's original shape.
        assert grad_output.numel() == ctx.numel
        return grad_output.contiguous().view(ctx.input_sizes), None
66