¿Qué es la activación de GELU?


18

Estaba revisando el documento BERT que usa GELU (Unidad lineal de error gaussiano) que establece la ecuación como que a su vez se aproxima a

GELU(x)=xP(Xx)=xΦ(x).
0.5x(1+tanh[2/π(x+0.044715x3)])

¿Podría simplificar la ecuación y explicar cómo se ha aproximado?

Respuestas:


19

Función GELU

Podemos expandir la distribución acumulativa de N(0,1) , es decir, Φ(x) , de la siguiente manera:

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

Tenga en cuenta que esta es una definición , no una ecuación (o una relación). Los autores han proporcionado algunas justificaciones para esta propuesta, por ejemplo, una analogía estocástica , aunque matemáticamente, esta es solo una definición.

Aquí está la trama de GELU:

Aproximación Tanh

Para este tipo de aproximaciones numéricas, la idea clave es encontrar una función similar (principalmente basada en la experiencia), parametrizarla y luego ajustarla a un conjunto de puntos de la función original.

Sabiendo que está muy cerca deerf(x)tanh(x)

y la primera derivada de coincide con la de en , que es , procedemos a ajustar (o con más términos) a un conjunto de puntos .erf(x2)tanh(2πx)x=02π

tanh(2π(x+ax2+bx3+cx4+dx5))
(xi,erf(xi2))

He ajustado esta función a 20 muestras entre ( usando este sitio ), y aquí están los coeficientes:(1.5,1.5)

Al establecer , se estimó en . Con más muestras de un rango más amplio (ese sitio solo permitió 20), el coeficiente estará más cerca del del papel . Finalmente llegamosa=c=d=0b0.04495641b0.044715

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

con error cuadrático medio para .108x[10,10]

Tenga en cuenta que si no utilizamos la relación entre las primeras derivadas, el término se habría incluido en los parámetros de la siguiente manera que es menos bella (menos analítica, más numérica).2π

0.5x(1+tanh(0.797885x+0.035677x3))

Utilizando la paridad

Como lo sugiere @BookYourLuck , podemos utilizar la paridad de funciones para restringir el espacio de los polinomios en los que buscamos. Es decir, dado que es una función extraña, es decir, , y también es una función extraña, función polinómica dentro también debe ser impar (solo debe tener poderes impares de ) para tener erff(x)=f(x)tanhpol(x)tanhx

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

Anteriormente, tuvimos la suerte de terminar con (casi) cero los coeficientes de potencias pares y , sin embargo, en general, esto podría dar lugar a aproximaciones de baja calidad que, por ejemplo, tienen un término como que se cancela mediante términos adicionales (pares o impares) en lugar de simplemente optar por .x2x40.23x20x2

Aproximación sigmoidea

Una relación similar se mantiene entre y (sigmoid), que se propone en el documento como otra aproximación, con error cuadrático medio para .erf(x)2(σ(x)12)104x[10,10]

Aquí hay un código de Python para generar puntos de datos, ajustar las funciones y calcular los errores cuadrados medios:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

Salida:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
¿Por qué se necesita la aproximación? ¿No podrían simplemente usar la función erf?
SebiSebi

8

Φ(x)=12erfc(x2)=12(1+erf(x2))
erf
erf(x2)tanh(2π(x+ax3))
a0.044715

x[1,1]x

tanh(x)=xx33+o(x3)
erf(x)=2π(xx33)+o(x3).
tanh(2π(x+ax3))=2π(x+(a23π)x3)+o(x3)
erf(x2)=2π(xx36)+o(x3).
x3
a0.04553992412
0.044715

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.