Aquí hay un ejemplo de Maximización de Expectativas (EM) utilizado para estimar la media y la desviación estándar. El código está en Python, pero debería ser fácil de seguir, incluso si no está familiarizado con el idioma.
La motivación para EM
Los puntos rojo y azul que se muestran a continuación se extraen de dos distribuciones normales diferentes, cada una con una media particular y una desviación estándar:
Para calcular aproximaciones razonables de la media "verdadera" y los parámetros de desviación estándar para la distribución roja, podríamos mirar fácilmente los puntos rojos y registrar la posición de cada uno, y luego usar las fórmulas familiares (y de manera similar para el grupo azul) .
Ahora considere el caso en el que sabemos que hay dos grupos de puntos, pero no podemos ver qué punto pertenece a qué grupo. En otras palabras, los colores están ocultos:
No es del todo obvio cómo dividir los puntos en dos grupos. Ahora no podemos mirar las posiciones y calcular estimaciones para los parámetros de la distribución roja o la distribución azul.
Aquí es donde se puede usar EM para resolver el problema.
Usando EM para estimar parámetros
Aquí está el código utilizado para generar los puntos que se muestran arriba. Puede ver las medias reales y las desviaciones estándar de las distribuciones normales de las que se extrajeron los puntos. Las variables red
y blue
mantienen las posiciones de cada punto en los grupos rojo y azul respectivamente:
import numpy as np
from scipy import stats
np.random.seed(110) # for reproducible random results
# set parameters
red_mean = 3
red_std = 0.8
blue_mean = 7
blue_std = 2
# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)
both_colours = np.sort(np.concatenate((red, blue)))
Si pudiéramos ver el color de cada punto, intentaríamos recuperar las medias y las desviaciones estándar utilizando las funciones de la biblioteca:
>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195
Pero como los colores están ocultos para nosotros, comenzaremos el proceso EM ...
Primero, simplemente adivinamos los valores para los parámetros de cada grupo ( paso 1 ). Estas conjeturas no tienen que ser buenas:
# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9
# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7
Suposiciones bastante malas: parece que los medios están muy lejos de cualquier "centro" de un grupo de puntos.
Para continuar con EM y mejorar estas conjeturas, calculamos la probabilidad de que cada punto de datos (independientemente de su color secreto) aparezca bajo estas conjeturas para la desviación media y estándar ( paso 2 ).
La variable both_colours
contiene cada punto de datos. La función stats.norm
calcula la probabilidad del punto bajo una distribución normal con los parámetros dados:
likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)
Esto nos dice, por ejemplo, que con nuestras conjeturas actuales, el punto de datos en 1.761 es mucho más probable que sea rojo (0.189) que azul (0.00003).
Podemos convertir estos dos valores de probabilidad en pesos ( paso 3 ) para que sumen 1 como sigue:
likelihood_total = likelihood_of_red + likelihood_of_blue
red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total
Con nuestras estimaciones actuales y nuestros pesos recién calculados, ahora podemos calcular nuevas estimaciones, probablemente mejores, para los parámetros ( paso 4 ). Necesitamos una función para la media y una función para la desviación estándar:
def estimate_mean(data, weight):
return np.sum(data * weight) / np.sum(weight)
def estimate_std(data, weight, mean):
variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
return np.sqrt(variance)
Estos se parecen mucho a las funciones habituales de la media y la desviación estándar de los datos. La diferencia es el uso de un weight
parámetro que asigna un peso a cada punto de datos.
Esta ponderación es la clave de EM. Cuanto mayor sea el peso de un color en un punto de datos, más influirá el punto de datos en las próximas estimaciones para los parámetros de ese color. En última instancia, esto tiene el efecto de tirar de cada parámetro en la dirección correcta.
Las nuevas conjeturas se calculan con estas funciones:
# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)
# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)
El proceso EM se repite con estas nuevas conjeturas desde el paso 2 en adelante. Podemos repetir los pasos para un número dado de iteraciones (digamos 20), o hasta que veamos que los parámetros convergen.
Después de cinco iteraciones, vemos que nuestras malas conjeturas iniciales comienzan a mejorar:
Después de 20 iteraciones, el proceso EM ha convergido más o menos:
A modo de comparación, aquí están los resultados del proceso EM en comparación con los valores calculados donde la información de color no está oculta:
| EM guess | Actual
----------+----------+--------
Red mean | 2.910 | 2.802
Red std | 0.854 | 0.871
Blue mean | 6.838 | 6.932
Blue std | 2.227 | 2.195
Nota: esta respuesta fue adaptada de mi respuesta en Stack Overflow aquí .