pytorch
1# Owner(s): ["module: onnx"]
2
3import torch
4
5
# Autograd function that is a replica of the autograd function in
# test_utility_funs.py (test_autograd_module_name)
class CustomFunction(torch.autograd.Function):
    """Custom autograd function that clamps its input to a minimum of zero.

    ``forward`` returns ``input`` clamped at ``min=0`` and stashes ``input``
    on the context; ``backward`` passes the incoming gradient through
    unchanged except where the stashed input was negative, where the
    gradient is set to zero.
    """

    @staticmethod
    def forward(ctx, input):
        """Save ``input`` for the backward pass and return it clamped at 0."""
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` zeroed wherever the input was < 0."""
        (saved_input,) = ctx.saved_tensors
        negative_mask = saved_input < 0
        # Out-of-place fill: yields a new tensor, leaving grad_output untouched.
        return grad_output.masked_fill(negative_mask, 0)
20