CSS-LM
import logging
import math

import torch
import torch.nn.functional as F


logger = logging.getLogger(__name__)


def swish(x):
    return x * torch.sigmoid(x)


def _gelu_python(x):
    """ Original implementation of the gelu activation function in the Google BERT repo
        when initially created. For information: OpenAI GPT's gelu is slightly different
        (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        This is now written in C in torch.nn.functional.
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_new(x):
    """ Implementation of the gelu activation function currently in the Google BERT repo
        (identical to OpenAI GPT). Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


# Use the native implementation of gelu where available (see the docstring of
# _gelu_python above). Note that this lexicographic version comparison is fragile
# (e.g. "1.10.0" < "1.4.0" evaluates to True), but the fallback is equivalent.
if torch.__version__ < "1.4.0":
    gelu = _gelu_python
else:
    gelu = F.gelu


def gelu_fast(x):
    # Tanh approximation of gelu with sqrt(2 / pi) precomputed as 0.7978845608.
    return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))


ACT2FN = {
    "relu": F.relu,
    "swish": swish,
    "gelu": gelu,
    "tanh": torch.tanh,
    "gelu_new": gelu_new,
    "gelu_fast": gelu_fast,
}


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
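

# --- Usage sketch (illustrative addition, not part of the original file) ---
# Shows how activations are looked up by name via get_activation, and checks
# that the exact erf-based gelu and its tanh approximations agree closely.
if __name__ == "__main__":
    x = torch.randn(2, 3)

    act = get_activation("gelu_new")
    print(act(x))  # element-wise gelu (tanh approximation)

    # The tanh approximations track the exact gelu to well within 1e-2 here,
    # and gelu_new / gelu_fast are the same formula up to float rounding.
    print(torch.allclose(_gelu_python(x), gelu_new(x), atol=1e-2))  # True
    print(torch.allclose(gelu_new(x), gelu_fast(x), atol=1e-6))     # True

    # Unknown names raise a KeyError that lists the valid keys.
    try:
        get_activation("mish")
    except KeyError as err:
        print(err)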