¿Qué función de pérdida usar para las clases desequilibradas (usando PyTorch)?


17

Tengo un conjunto de datos con 3 clases con los siguientes elementos:

  • Clase 1: 900 elementos
  • Clase 2: 15000 elementos.
  • Clase 3: 800 elementos

Necesito predecir la clase 1 y la clase 3, que indican desviaciones importantes de la norma. La clase 2 es el caso "normal" predeterminado que no me importa.

¿Qué tipo de función de pérdida usaría aquí? Estaba pensando en usar CrossEntropyLoss, pero dado que hay un desequilibrio de clase, supongo que esto debería ponderarse. ¿Cómo funciona eso en la práctica? ¿Te gusta esto (usando PyTorch)?

summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)

¿O debería invertirse el peso? es decir, 1 / peso?

¿Es este el enfoque correcto para comenzar o hay otros / mejores métodos que podría usar?

Gracias

Respuestas:


13

¿Qué tipo de función de pérdida usaría aquí?

La entropía cruzada es la función de ir a la pérdida para tareas de clasificación, ya sea balanceadas o no balanceadas. Es la primera opción cuando todavía no se crea ninguna preferencia a partir del conocimiento del dominio.

Esto tendría que ser ponderado, supongo. ¿Cómo funciona eso en la práctica?

Si. Peso de la claseC es el tamaño de la clase más grande dividido por el tamaño de la clase C.

Por ejemplo, si la clase 1 tiene 900, la clase 2 tiene 15000 y la clase 3 tiene 800 muestras, entonces sus pesos serían 16.67, 1.0 y 18.75 respectivamente.

También puede usar la clase más pequeña como nominador, que da 0.889, 0.053 y 1.0 respectivamente. Esto es solo una reescalada, los pesos relativos son los mismos.

¿Es este el enfoque correcto para comenzar o hay otros / mejores métodos que podría usar?

Sí, este es el enfoque correcto.

EDITAR :

Gracias a @Muppet, también podemos usar sobremuestreo de clase, que es equivalente a usar pesos de clase . Esto se logra WeightedRandomSampleren PyTorch, utilizando los mismos pesos mencionados anteriormente.


2
Solo quería agregar que usar WeightedRandomSampler de PyTorch también ayudó, en caso de que alguien más esté viendo esto.
Muppet

0

Cuando dices: también puedes usar la clase más pequeña como nominador, que da 0.889, 0.053 y 1.0 respectivamente. Esto es solo una reescalada, los pesos relativos son los mismos.

Pero esta solución está en contradicción con la primera que diste, ¿cómo funciona?

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.