¿Cuál es una explicación intuitiva de la técnica de maximización de expectativas? [cerrado]


109

La maximización de expectativas (EM) es una especie de método probabilístico para clasificar datos. Por favor corríjame si me equivoco si no es un clasificador.

¿Cuál es una explicación intuitiva de esta técnica EM? ¿Qué hay expectationaquí y qué está siendo maximized?


12
¿Qué es el algoritmo de maximización de expectativas? , Nature Biotechnology 26 , 897–899 (2008) tiene una bonita imagen que ilustra cómo funciona el algoritmo.
chl

@chl En la parte b de la bonita imagen , ¿cómo obtuvieron los valores de la distribución de probabilidad en Z (es decir, 0.45xA, 0.55xB, etc.)?
Noob Saibot


3
Enlace actualizado a la imagen que mencionó @chl.
n1k31t4

Respuestas:


120

Nota: el código detrás de esta respuesta se puede encontrar aquí .


Supongamos que tenemos algunos datos muestreados de dos grupos diferentes, rojo y azul:

ingrese la descripción de la imagen aquí

Aquí, podemos ver qué punto de datos pertenece al grupo rojo o azul. Esto facilita la búsqueda de los parámetros que caracterizan a cada grupo. Por ejemplo, la media del grupo rojo es alrededor de 3, la media del grupo azul es alrededor de 7 (y podríamos encontrar la media exacta si quisiéramos).

Esto se conoce, en general, como estimación de máxima verosimilitud. . Dados algunos datos, calculamos el valor de un parámetro (o parámetros) que mejor explica esos datos.

Ahora imagine que no podemos ver qué valor se muestreó de qué grupo. Todo nos parece morado:

ingrese la descripción de la imagen aquí

Aquí tenemos el conocimiento de que hay dos grupos de valores, pero no sabemos a qué grupo pertenece un valor en particular.

¿Todavía podemos estimar las medias para el grupo rojo y el grupo azul que mejor se ajustan a estos datos?

¡Sí, a menudo podemos! La maximización de expectativas nos da una forma de hacerlo. La idea muy general detrás del algoritmo es esta:

  1. Comience con una estimación inicial de lo que podría ser cada parámetro.
  2. Calcule la probabilidad de que cada parámetro produzca el punto de datos.
  3. Calcule los pesos para cada punto de datos indicando si es más rojo o más azul en función de la probabilidad de que lo produzca un parámetro. Combine los pesos con los datos ( expectativa ).
  4. Calcule una mejor estimación de los parámetros utilizando los datos ajustados por peso ( maximización ).
  5. Repita los pasos 2 a 4 hasta que la estimación del parámetro converja (el proceso deja de producir una estimación diferente).

Estos pasos necesitan una explicación más detallada, por lo que analizaré el problema descrito anteriormente.

Ejemplo: estimación de la desviación estándar y media

Usaré Python en este ejemplo, pero el código debería ser bastante fácil de entender si no está familiarizado con este lenguaje.

Supongamos que tenemos dos grupos, rojo y azul, con los valores distribuidos como en la imagen de arriba. Específicamente, cada grupo contiene un valor extraído de una distribución normal con los siguientes parámetros:

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

Aquí hay una imagen de estos grupos rojo y azul nuevamente (para evitar que tenga que desplazarse hacia arriba):

ingrese la descripción de la imagen aquí

Cuando podemos ver el color de cada punto (es decir, a qué grupo pertenece), es muy fácil estimar la media y la desviación estándar de cada grupo. Simplemente pasamos los valores rojo y azul a las funciones integradas en NumPy. Por ejemplo:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

Pero, ¿y si no podemos ver los colores de los puntos? Es decir, en lugar de rojo o azul, todos los puntos se han coloreado de púrpura.

Para intentar recuperar los parámetros de desviación estándar y media de los grupos rojo y azul, podemos utilizar la maximización de expectativas.

Nuestro primer paso ( paso 1 anterior) es adivinar los valores de los parámetros para la media y la desviación estándar de cada grupo. No tenemos que adivinar inteligentemente; podemos elegir los números que queramos:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

Estas estimaciones de parámetros producen curvas de campana que se ven así:

ingrese la descripción de la imagen aquí

Estas son malas estimaciones. Ambos medios (las líneas verticales punteadas) se ven lejos de cualquier tipo de "medio" para grupos sensibles de puntos, por ejemplo. Queremos mejorar estas estimaciones.

El siguiente paso ( paso 2 ) es calcular la probabilidad de que cada punto de datos aparezca bajo las suposiciones de parámetros actuales:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

Aquí, simplemente hemos puesto cada punto de datos en la función de densidad de probabilidad para una distribución normal usando nuestras suposiciones actuales en la desviación estándar y media para el rojo y el azul. Esto nos dice, por ejemplo, que con nuestras suposiciones actuales, es mucho más probable que el punto de datos en 1,761 sea ​​rojo (0,189) que azul (0,00003).

Para cada punto de datos, podemos convertir estos dos valores de probabilidad en pesos ( paso 3 ) para que sumen 1 de la siguiente manera:

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

Con nuestras estimaciones actuales y nuestras ponderaciones recién calculadas, ahora podemos calcular nuevas estimaciones para la media y la desviación estándar de los grupos rojo y azul ( paso 4 ).

Calculamos dos veces la media y la desviación estándar usando todos los puntos de datos, pero con diferentes ponderaciones: una vez para las ponderaciones rojas y una vez para las azules.

La clave de la intuición es que cuanto mayor es el peso de un color en un punto de datos, más influye el punto de datos en las próximas estimaciones de los parámetros de ese color. Esto tiene el efecto de "tirar" de los parámetros en la dirección correcta.

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

Tenemos nuevas estimaciones para los parámetros. Para mejorarlos de nuevo, podemos volver al paso 2 y repetir el proceso. Hacemos esto hasta que las estimaciones converjan, o después de que se hayan realizado algunas iteraciones ( paso 5 ).

Para nuestros datos, las primeras cinco iteraciones de este proceso se ven así (las iteraciones recientes tienen una apariencia más fuerte):

ingrese la descripción de la imagen aquí

Vemos que las medias ya están convergiendo en algunos valores, y las formas de las curvas (gobernadas por la desviación estándar) también se están volviendo más estables.

Si continuamos durante 20 iteraciones, terminamos con lo siguiente:

ingrese la descripción de la imagen aquí

El proceso EM ha convergido a los siguientes valores, que resultan muy cercanos a los valores reales (donde podemos ver los colores, sin variables ocultas):

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

En el código anterior, es posible que haya notado que la nueva estimación de la desviación estándar se calculó utilizando la estimación de la iteración anterior para la media. En última instancia, no importa si primero calculamos un nuevo valor para la media, ya que solo estamos encontrando la varianza (ponderada) de los valores alrededor de algún punto central. Seguiremos viendo converger las estimaciones de los parámetros.


¿Qué pasa si ni siquiera sabemos el número de distribuciones normales de las que proviene esto? Aquí ha tomado un ejemplo de k = 2 distribuciones, ¿podemos también estimar k, y los k conjuntos de parámetros?
stackit

1
@stackit: No estoy seguro de que haya una forma general sencilla de calcular el valor más probable de k como parte del proceso EM en este caso. El problema principal es que necesitaríamos comenzar EM con estimaciones para cada uno de los parámetros que queremos encontrar, y eso implica que necesitamos saber / estimar k antes de comenzar. Sin embargo, es posible estimar aquí la proporción de puntos que pertenecen a un grupo a través de EM. Quizás si sobrestimamos k, la proporción de todos los grupos menos dos caería a casi cero. No he experimentado con esto, así que no sé qué tan bien funcionaría en la práctica.
Alex Riley

1
@AlexRiley ¿Puede decir un poco más sobre las fórmulas para calcular las nuevas estimaciones de desviación estándar y media?
Lemon

2
@AlexRiley Gracias por la explicación. ¿Por qué se calculan las nuevas estimaciones de la desviación estándar utilizando la estimación anterior de la media? ¿Qué pasa si las nuevas estimaciones de la media se encuentran primero?
GoodDeeds

1
@Lemon GoodDeeds Kaushal: disculpas por mi tardía respuesta a tus preguntas. Intenté editar la respuesta para abordar los puntos que ha planteado. También he hecho que todo el código utilizado en esta respuesta esté accesible en un cuaderno aquí (que también incluye explicaciones más detalladas de algunos puntos que mencioné).
Alex Riley

36

EM es un algoritmo para maximizar una función de probabilidad cuando algunas de las variables en su modelo no se observan (es decir, cuando tiene variables latentes).

Podría preguntar, si solo estamos tratando de maximizar una función, ¿por qué no usamos la maquinaria existente para maximizar una función? Bueno, si intenta maximizar esto tomando derivadas y poniéndolas a cero, encontrará que en muchos casos las condiciones de primer orden no tienen solución. Hay un problema del huevo y la gallina en el sentido de que para resolver los parámetros de su modelo necesita conocer la distribución de sus datos no observados; pero la distribución de sus datos no observados es una función de los parámetros de su modelo.

EM intenta evitar esto adivinando iterativamente una distribución para los datos no observados, luego estimando los parámetros del modelo maximizando algo que es un límite inferior en la función de probabilidad real y repitiendo hasta la convergencia:

El algoritmo EM

Comience con adivinar los valores de los parámetros de su modelo

Paso E: para cada punto de datos que tenga valores perdidos, use la ecuación de su modelo para resolver la distribución de los datos faltantes dada su estimación actual de los parámetros del modelo y dados los datos observados (tenga en cuenta que está resolviendo una distribución para cada valor, no para el valor esperado). Ahora que tenemos una distribución para cada valor perdido, podemos calcular la expectativa de la función de verosimilitud con respecto a las variables no observadas. Si nuestra conjetura para el parámetro del modelo fue correcta, esta probabilidad esperada será la probabilidad real de nuestros datos observados; si los parámetros no son correctos, será solo un límite inferior.

Paso M: ahora que tenemos una función de probabilidad esperada sin variables no observadas en ella, maximice la función como lo haría en el caso completamente observado, para obtener una nueva estimación de los parámetros de su modelo.

Repita hasta convergencia.


5
No entiendo tu E-paso. Parte del problema es que mientras estoy aprendiendo estas cosas, no puedo encontrar personas que usen la misma terminología. Entonces, ¿qué quieres decir con ecuación modelo? No sé a qué te refieres con resolver una distribución de probabilidad.
user678392

27

A continuación, se incluye una receta sencilla para comprender el algoritmo de maximización de expectativas:

1- Lea este artículo tutorial de EM de Do y Batzoglou.

2- Es posible que tenga signos de interrogación en la cabeza, eche un vistazo a las explicaciones en esta página de intercambio de pilas de matemáticas .

3- Mira este código que escribí en Python que explica el ejemplo en el tutorial de EM del artículo 1:

Advertencia: el código puede ser desordenado / subóptimo, ya que no soy un desarrollador de Python. Pero cumple su función.

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1

Encuentro que su programa dará como resultado A y B a 0.66, también lo implemento usando scala, también encuentro que el resultado es 0.66, ¿pueden ayudar a verificar eso?
zjffdu

Usando una hoja de cálculo, solo encuentro sus 0,66 resultados si mis conjeturas iniciales son iguales. De lo contrario, puedo reproducir el resultado del tutorial.
soakley

@zjffdu, ¿cuántas iteraciones ejecuta el EM antes de devolverle 0.66? Si inicializa con valores iguales, es posible que se atasque en un máximo local y verá que el número de iteraciones es extremadamente bajo (ya que no hay mejora).
Zhubarb

También puede consultar esta diapositiva de Andrew Ng y la nota del curso de Harvard
Minh Phan

16

Técnicamente, el término "EM" está un poco subespecificado, pero supongo que se refiere a la técnica de análisis de conglomerados del modelado de mezcla gaussiana, que es un ejemplo del principio general de EM.

En realidad, el análisis de conglomerados EM no es un clasificador . Sé que algunas personas consideran que la agrupación es una "clasificación no supervisada", pero en realidad el análisis de agrupaciones es algo bastante diferente.

La diferencia clave, y el gran malentendido de clasificación que la gente siempre tiene con el análisis de conglomerados es que: en el análisis de conglomerados, no existe una "solución correcta" . Es un método de descubrimiento de conocimiento , ¡en realidad está destinado a encontrar algo nuevo ! Esto hace que la evaluación sea muy complicada. A menudo se evalúa utilizando una clasificación conocida como referencia, pero eso no siempre es apropiado: la clasificación que tiene puede reflejar o no lo que hay en los datos.

Déjame darte un ejemplo: tienes un gran conjunto de datos de clientes, incluidos datos de género. Un método que divide este conjunto de datos en "masculino" y "femenino" es óptimo cuando se compara con las clases existentes. En una forma de "predicción" de pensar, esto es bueno, ya que para los nuevos usuarios ahora puede predecir su género. En una forma de pensar de "descubrimiento de conocimiento", esto es realmente malo, porque quería descubrir alguna estructura nueva en los datos. Sin embargo, un método que dividiría los datos en personas mayores y niños obtendría la peor puntuación posible con respecto a la clase masculina / femenina. Sin embargo, ese sería un excelente resultado de agrupación (si no se proporcionara la edad).

Ahora volvamos a EM. Esencialmente, asume que sus datos están compuestos por múltiples distribuciones normales multivariadas (tenga en cuenta que esta es una suposición muy sólida, en particular cuando se fija el número de clústeres). Luego trata de encontrar un modelo óptimo local para esto mejorando alternativamente el modelo y la asignación de objetos al modelo .

Para obtener los mejores resultados en un contexto de clasificación, elija el número de clústeres más grande que el número de clases, o incluso aplique el clúster solo a clases individuales (¡para averiguar si hay alguna estructura dentro de la clase!).

Supongamos que desea entrenar a un clasificador para distinguir "automóviles", "bicicletas" y "camiones". Es de poca utilidad suponer que los datos constan de exactamente 3 distribuciones normales. Sin embargo, puede suponer que hay más de un tipo de automóvil (y camiones y bicicletas). Entonces, en lugar de entrenar a un clasificador para estas tres clases, agrupa autos, camiones y bicicletas en 10 grupos cada uno (o tal vez 10 autos, 3 camiones y 3 bicicletas, lo que sea), luego entrena a un clasificador para diferenciar estas 30 clases, y luego fusionar el resultado de la clase con las clases originales. También puede descubrir que hay un grupo que es particularmente difícil de clasificar, por ejemplo, Triciclos. Son algo coches y algo bicicletas. O camiones de reparto, que se parecen más a coches de gran tamaño que a camiones.


¿Cómo se subespecifica EM?
sam boosalis

Hay más de una versión. Técnicamente, también puede llamar "EM" al estilo de Lloyd. Debes especificar qué modelo usas.
Ha QUIT - Anony-Mousse

2

Si las demás respuestas son buenas, intentaré proporcionar otra perspectiva y abordar la parte intuitiva de la pregunta.

El algoritmo EM (Expectation-Maximization) es una variante de una clase de algoritmos iterativos que utilizan la dualidad

Extracto (el énfasis es mío):

En matemáticas, una dualidad, en general, traduce conceptos, teoremas o estructuras matemáticas en otros conceptos, teoremas o estructuras, de manera uno a uno, a menudo (pero no siempre) mediante una operación de involución: si el dual de A es B, entonces el dual de B es A. Tales involuciones veces tienen puntos fijos , de modo que el dual de A es A mismo

Por lo general, un B dual de un objeto A está relacionado con A de alguna manera que conserva algunos simetría o compatibilidad . Por ejemplo AB = const

Ejemplos de algoritmos iterativos que emplean dualidad (en el sentido anterior) son:

  1. Algoritmo euclidiano para el mayor divisor común y sus variantes
  2. Algoritmo y variantes de Gram-Schmidt Vector Basis
  3. Media aritmética - Desigualdad de media geométrica y sus variantes
  4. Algoritmo de maximización de expectativas y sus variantes (consulte también aquí para obtener una vista geométrica de la información )
  5. (.. otros algoritmos similares ..)

De manera similar, el algoritmo EM también puede verse como dos pasos de maximización duales :

.. [EM] se considera que maximiza una función conjunta de los parámetros y de la distribución sobre las variables no observadas. El paso E maximiza esta función con respecto a la distribución sobre las variables no observadas; el paso M con respecto a los parámetros.

En un algoritmo iterativo que usa dualidad, existe la suposición explícita (o implícita) de un punto de convergencia de equilibrio (o fijo) (para EM, esto se demuestra usando la desigualdad de Jensen)

Entonces, el esquema de tales algoritmos es:

  1. Paso similar a E: Encuentre la mejor solución x con respecto a que y dada se mantenga constante.
  2. Paso similar a M (dual): Encuentre la mejor solución y con respecto a x (como se calculó en el paso anterior) que se mantiene constante.
  3. Criterio de paso de terminación / convergencia: repita los pasos 1, 2 con los valores actualizados de x , y hasta que se alcance la convergencia (o el número especificado de iteraciones)

Nota que cuando un algoritmo converge tales a un óptimo (global), se ha encontrado una configuración que es mejor en los dos sentidos (es decir, tanto en el x de dominio / parámetros y las Y de dominio / parámetros). Sin embargo, el algoritmo solo puede encontrar un óptimo local y no el óptimo global .

Yo diría que esta es la descripción intuitiva del esquema del algoritmo.

Para los argumentos y aplicaciones estadísticos, otras respuestas han dado buenas explicaciones (verifique también las referencias en esta respuesta)


2

La respuesta aceptada hace referencia al documento Chuong EM , que hace un trabajo decente al explicar EM. También hay un video de youtube que explica el artículo con más detalle.

Para recapitular, aquí está el escenario:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails

Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.

We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

En el caso de la pregunta del primer ensayo, intuitivamente pensaríamos que B lo generó ya que la proporción de caras coincide muy bien con el sesgo de B ... pero ese valor fue solo una suposición, por lo que no podemos estar seguros.

Con eso en mente, me gusta pensar en la solución EM de esta manera:

  • Cada prueba de lanzamientos puede 'votar' sobre qué moneda le gusta más
    • Esto se basa en qué tan bien se ajusta cada moneda a su distribución
    • O, desde el punto de vista de la moneda, existe una gran expectativa de ver esta prueba en relación con la otra moneda (según las probabilidades de registro ).
  • Dependiendo de cuánto le guste a cada prueba cada moneda, puede actualizar la suposición del parámetro de esa moneda (sesgo).
    • Cuanto más le gusta una moneda a una prueba, más actualiza el sesgo de la moneda para reflejar el suyo.
    • Esencialmente, los sesgos de la moneda se actualizan combinando estas actualizaciones ponderadas en todas las pruebas, un proceso llamado ( maximización ), que se refiere a tratar de obtener las mejores conjeturas para el sesgo de cada moneda dado un conjunto de pruebas.

Esto puede ser una simplificación excesiva (o incluso fundamentalmente incorrecto en algunos niveles), ¡pero espero que esto ayude a un nivel intuitivo!


1

EM se utiliza para maximizar la probabilidad de un modelo Q con variables latentes Z.

Es una optimización iterativa.

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-paso: dada la estimación actual de Z calcular la función de loglikelihood esperada

m-paso: encuentra theta que maximiza esta Q

Ejemplo de GMM:

e-step: estimar las asignaciones de etiquetas para cada punto de datos dada la estimación actual del parámetro gmm

m-step: maximizar un nuevo theta dadas las nuevas asignaciones de etiquetas

K-means también es un algoritmo EM y hay muchas animaciones explicativas en K-means.


1

Usando el mismo artículo de Do y Batzoglou citado en la respuesta de Zhubarb, implementé EM para ese problema en Java . Los comentarios a su respuesta muestran que el algoritmo se atasca en un óptimo local, lo que también ocurre con mi implementación si los parámetros thetaA y thetaB son los mismos.

A continuación se muestra la salida estándar de mi código, que muestra la convergencia de los parámetros.

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

A continuación se muestra mi implementación Java de EM para resolver el problema en (Do y Batzoglou, 2008). La parte central de la implementación es el ciclo para ejecutar EM hasta que los parámetros converjan.

private Parameters _parameters;

public Parameters run()
{
    while (true)
    {
        expectation();

        Parameters estimatedParameters = maximization();

        if (_parameters.converged(estimatedParameters)) {
            break;
        }

        _parameters = estimatedParameters;
    }

    return _parameters;
}

A continuación se muestra el código completo.

import java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.