Me he hecho esta pregunta durante meses. Todas las respuestas en CrossValidated y Quora enumeran buenas propiedades de la función sigmoidea logística, pero parece que adivinamos hábilmente esta función. Lo que me perdí fue la justificación para elegirlo. Finalmente encontré uno en la sección 6.2.2.2 del libro "Deep Learning" de Bengio (2016) . En mis propias palabras:
En resumen, queremos que el logaritmo de la salida del modelo sea adecuado para la optimización basada en gradiente de la probabilidad logarítmica de los datos de entrenamiento.
Motivación
- Queremos un modelo lineal, pero no podemos usar z=wTx+b directamente como z∈(−∞,+∞) .
- Para la clasificación, tiene sentido asumir la distribución de Bernoulli y modelar su parámetro θ en P(Y=1)=θ .
- Entonces, necesitamos mapear z desde (−∞,+∞) hasta [0,1] para hacer la clasificación.
¿Por qué la función logística sigmoidea?
Cortar z con P(Y=1|z)=max{0,min{1,z}} produce un gradiente cero para z fuera de [0,1] . Necesitamos un gradiente fuerte siempre que la predicción del modelo sea incorrecta, porque resolvemos la regresión logística con el descenso del gradiente. Para la regresión logística, no existe una solución de forma cerrada.
La función logística tiene la buena propiedad de asíntotar un gradiente constante cuando la predicción del modelo es incorrecta, dado que usamos la Estimación de máxima verosimilitud para ajustar el modelo. Esto se muestra a continuación:
Para obtener beneficios numéricos, la Estimación de máxima verosimilitud se puede hacer minimizando la probabilidad logarítmica negativa de los datos de entrenamiento. Entonces, nuestra función de costo es:
J(w,b)=1m∑i=1m−logP(Y=yi|xi;w,b)=1m∑i=1m−(yilogP(Y=1|z)+(yi−1)logP(Y=0|z))
Como P(Y=0|z)=1−P(Y=1|z) , podemos centrarnos en el caso Y=1 . Entonces, la pregunta es cómo modelar P(Y=1|z) dado que tenemos z=wTx+b .
Los requisitos obvios para la función f mapeo z a P(Y=1|z) son:
- ∀z∈R:f(z)∈[0,1]
- f(0)=0.5
- f debe ser wrt simétrico rotacionalmente(0,0.5) , es decir,f(−x)=1−f(x) , de modo que voltear los signos de las clases no tiene efecto en la función de costo.
- f debe ser no decreciente, continuo y diferenciable.
Todos estos requisitos se cumplen reescalando las funciones sigmoideas . Ambos f(z)=11+e−z yf(z)=0.5+0.5z1+|z|cumplirlos Sin embargo, las funciones sigmoideas difieren con respecto a su comportamiento durante la optimización basada en gradiente de la probabilidad logarítmica. Podemos ver la diferencia conectando la función logísticaf(z)=11+e−z en nuestra función de costos.
Saturación para Y=1
For P(Y=1|z)=11+e−z and Y=1, the cost of a single misclassified sample (i.e. m=1) is:
J(z)=−log(P(Y=1|z))=−log(11+e−z)=−log(ez1+ez)=−z+log(1+ez)
We can see that there is a linear component −z. Now, we can look at two cases:
- When z is large, the model's prediction was correct, since Y=1. In the cost function, the log(1+ez) term asymptotes to z for large z. Thus, it roughly cancels the −z out leading to a roughly zero cost for this sample and a weak gradient. That makes sense, as the model is already predicting the correct class.
- When z is small (but |z| is large), the model's prediction was not correct, since Y=1. In the cost function, the log(1+ez) term asymptotes to 0 for small z. Thus, the overall cost for this sample is roughly −z, meaning the gradient w.r.t. z is roughly −1. This makes it easy for the model to correct its wrong prediction based on the constant gradient it receives. Even for very small z, there is no saturation going on, which would cause vanishing gradients.
Saturation for Y=0
Above, we focussed on the Y=1 case. For Y=0, the cost function behaves analogously, providing strong gradients only when the model's prediction is wrong.
This is the cost function J(z) for Y=1:
It is the horizontally flipped softplus function. For Y=0, it is the softplus function.
Alternatives
You mentioned the alternatives to the logistic sigmoid function, for example z1+|z|. Normalized to [0,1], this would mean that we model P(Y=1|z)=0.5+0.5z1+|z|.
During MLE, the cost function for Y=1 would then be
J(z)=−log(0.5+0.5z1+|z|),
which looks like this:
You can see, that the gradient of the cost function gets weaker and weaker for z→−∞.