pytorch

Форк
0
/
autograd_helper.py 
19 строк · 518.0 Байт
1
# Owner(s): ["module: onnx"]
2

3
import torch
4

5

6
# Autograd funtion that is a replica of the autograd funtion in
7
# test_utility_funs.py (test_autograd_module_name)
8
class CustomFunction(torch.autograd.Function):
9
    @staticmethod
10
    def forward(ctx, input):
11
        ctx.save_for_backward(input)
12
        return input.clamp(min=0)
13

14
    @staticmethod
15
    def backward(ctx, grad_output):
16
        (input,) = ctx.saved_tensors
17
        grad_input = grad_output.clone()
18
        grad_input[input < 0] = 0
19
        return grad_input
20

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.