# torch/autograd/_functions — legacy tensor autograd Functions (file-viewer header removed; it was not valid Python)
1import operator
2import warnings
3from functools import reduce
4
5import torch
6import torch._utils
7from ..function import Function
8
9
class Type(Function):
    """Deprecated autograd Function casting a tensor to ``dest_type``.

    ``forward`` records the input's Python type and device so ``backward``
    can cast the incoming gradient back to the original type (and, for CUDA
    inputs, on the original device).
    """

    @staticmethod
    def forward(ctx, i, dest_type):
        # Deprecation notice (message preserved verbatim from the original).
        warnings.warn(
            "torch.autograd._functions.Type is deprecated as of PyTorch 2.1, please use "
            "torch.tensor.to(dtype=dtype) instead."
        )
        # Stash what backward needs: the original type, and the CUDA device
        # index (-1 is the sentinel for a CPU tensor).
        ctx.input_type = type(i)
        ctx.input_device = i.get_device() if i.is_cuda else -1
        return i.type(dest_type)

    @staticmethod
    def backward(ctx, grad_output):
        # Cast the gradient back to the input's original type; for CUDA
        # inputs, do so with the original device made current.
        if ctx.input_device != -1:
            with torch.cuda.device(ctx.input_device):
                return grad_output.type(ctx.input_type), None
        return grad_output.type(ctx.input_type), None
28
29
30# TODO: deprecate this
class Resize(Function):
    """Legacy autograd Function reshaping a tensor to ``sizes``.

    Only the shape may change: the requested size must have exactly the same
    number of elements as the input, otherwise ``forward`` raises
    ``RuntimeError``. ``backward`` reshapes the gradient back to the input's
    original size.
    """

    @staticmethod
    def forward(ctx, tensor, sizes):
        ctx.sizes = sizes
        ctx.numel = reduce(operator.mul, sizes, 1)
        if tensor.numel() != ctx.numel:
            requested = "x".join(map(str, sizes))
            given = "x".join(map(str, tensor.size()))
            # Message text is identical to the original .format() version.
            raise RuntimeError(
                f"requested resize to {requested} ({ctx.numel} elements in total), "
                f"but the given tensor has a size of {given} ({tensor.numel()} elements). "
                "autograd's resize can only change the shape of a given "
                "tensor, while preserving the number of elements. "
            )
        ctx.input_sizes = tensor.size()
        if tensor.is_quantized:
            tensor.copy_(tensor)
            return tensor.contiguous().view(*sizes)
        if tensor.is_contiguous():
            # Contiguous path goes through .new() to get a fresh tensor.
            return tensor.new(tensor).contiguous().view(*sizes)
        return tensor.contiguous().view(*sizes)

    @staticmethod
    def backward(ctx, grad_output):
        # NOTE(review): plain assert (stripped under `python -O`) kept as-is
        # to preserve the original exception type for callers.
        assert grad_output.numel() == ctx.numel
        return grad_output.contiguous().view(ctx.input_sizes), None
64