¿Cómo calcula Keras la precisión?


26

¿Cómo calcula Keras la precisión de las probabilidades de clase? Digamos, por ejemplo, tenemos 100 muestras en el conjunto de prueba que pueden pertenecer a una de dos clases. También tenemos una lista de las probabilidades de clase. ¿Qué umbral utiliza Keras para asignar una muestra a cualquiera de las dos clases?


¿estás usando model.evaluate en keras?
Hima Varsha

Sí, estoy usando model.evaluate. Más específicamente, model.evaluate_generator.
Raghuram


Posiblemente relacionado @SO: ¿Cómo evalúa Keras la precisión? )
desertnaut

Respuestas:


24

Para la clasificación binaria, el código para la métrica de precisión es:

K.mean(K.equal(y_true, K.round(y_pred)))

lo que sugiere que 0.5 es el umbral para distinguir entre clases. y_true debería ser, por supuesto, 1 hots en este caso.

Es un poco diferente para la clasificación categórica:

K.mean(K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)))

lo que significa "con qué frecuencia las predicciones tienen un máximo en el mismo lugar que los valores verdaderos"

También hay una opción para la precisión categórica top-k, que es similar a la anterior, pero calcula con qué frecuencia la clase objetivo está dentro de las predicciones top-k.


Gracias por la respuesta. ¿Eso significa que incluso para la clasificación binaria, las etiquetas deben estar codificadas en caliente?
Raghuram

@ Raghuram No, para la clasificación binaria solo necesita 0 o 1 como clase, no es necesario codificarlos en caliente. Como K.mean (K.equal (y_true, K.round (y_pred))) coincidirá con 2 valores flotantes para cada caso, por lo que debe ser 0 o 1 y no [0,1], [1,0].
Divyanshu Kalra

Para una precisión categórica, use categorical_accuracy.
Shital Shah

1
para un problema de varias clases (con más de dos clases), ¿hay alguna diferencia entre usar "precisión" versus "precisión_categoría"
Quetzalcóatl
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.