pytorch-image-models
19 строк · 743.0 Байт
1""" Linear layer (alternate definition)
2"""
3import torch
4import torch.nn.functional as F
5from torch import nn as nn
6
7
class Linear(nn.Linear):
    r"""Drop-in :class:`torch.nn.Linear` computing :math:`y = xA^T + b`.

    Under TorchScript, the layer casts its weight and optional bias to the
    dtype of the incoming tensor before calling
    :func:`torch.nn.functional.linear`. This works around an issue with
    ``torch.addmm`` when AMP and TorchScript are combined.

    In eager mode, the parameters are handed to ``F.linear`` unchanged.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # `input` shadows the builtin, but the name must match nn.Linear.forward.
        if not torch.jit.is_scripting():
            # Eager path: no cast needed, defer straight to F.linear.
            return F.linear(input, self.weight, self.bias)

        # Scripted path: align parameter dtypes with the activation dtype
        # so the underlying addmm sees matching dtypes.
        target_dtype = input.dtype
        cast_bias = None if self.bias is None else self.bias.to(dtype=target_dtype)
        return F.linear(input, self.weight.to(dtype=target_dtype), bias=cast_bias)
20