Tengo un conjunto de datos con 3 clases con los siguientes elementos:
- Clase 1: 900 elementos
- Clase 2: 15000 elementos.
- Clase 3: 800 elementos
Necesito predecir la clase 1 y la clase 3, que indican desviaciones importantes de la norma. La clase 2 es el caso "normal" predeterminado que no me importa.
¿Qué tipo de función de pérdida usaría aquí? Estaba pensando en usar CrossEntropyLoss, pero dado que hay un desequilibrio de clase, supongo que esto debería ponderarse. ¿Cómo funciona eso en la práctica? ¿Te gusta esto (usando PyTorch)?
summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)
¿O debería invertirse el peso? es decir, 1 / peso?
¿Es este el enfoque correcto para comenzar o hay otros / mejores métodos que podría usar?
Gracias