pytorch-image-models

Форк
0
19 строк · 743.0 Байт
1
""" Linear layer (alternate definition)
2
"""
3
import torch
4
import torch.nn.functional as F
5
from torch import nn as nn
6

7

8
class Linear(nn.Linear):
9
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
10

11
    Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
12
    weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
13
    """
14
    def forward(self, input: torch.Tensor) -> torch.Tensor:
15
        if torch.jit.is_scripting():
16
            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
17
            return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
18
        else:
19
            return F.linear(input, self.weight, self.bias)
20

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.