Caída repentina de precisión al entrenar LSTM o GRU en Keras


8

Mi red neuronal recurrente (LSTM, resp. GRU) se comporta de una manera que no puedo explicar. El entrenamiento comienza y se entrena bien (los resultados se ven bastante bien) cuando de repente cae la precisión (y la pérdida aumenta rápidamente) , tanto las métricas de entrenamiento como de prueba. A veces, la red simplemente se vuelve loca y devuelve salidas aleatorias y, a veces (como en el último de los tres ejemplos dados) comienza a devolver la misma salida a todas las entradas .

imagen

¿Tienes alguna explicación para este comportamiento ? Cualquier opinión es bienvenida. Por favor, vea la descripción de la tarea y las figuras a continuación.

La tarea: a partir de una palabra predecir su vector word2vec La entrada: Tenemos un modelo propio de word2vec (normalizado) y alimentamos la red con una palabra (letra por letra). Rellenamos las palabras (ver el ejemplo a continuación). Ejemplo: tenemos una palabra fútbol y queremos predecir su vector word2vec que tiene 100 dimensiones de ancho. Entonces la entrada es $football$$$$$$$$$$.

Tres ejemplos del comportamiento:

Capa simple LSTM

model = Sequential([
    LSTM(1024, input_shape=encoder.shape, return_sequences=False),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

imagen

GRU de una sola capa

model = Sequential([
    GRU(1024, input_shape=encoder.shape, return_sequences=False),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

imagen

Doble capa LSTM

model = Sequential([
    LSTM(512, input_shape=encoder.shape, return_sequences=True),
    TimeDistributed(Dense(512, activation="sigmoid")),
    LSTM(512, return_sequences=False),
    Dense(256, activation="tanh"),
    Dense(w2v_size, activation="linear")
])

model.compile(optimizer='adam', loss="mse", metrics=["accuracy"])

imagen

También hemos experimentado este tipo de comportamiento en otro proyecto antes que utilizaba una arquitectura similar pero su objetivo y datos eran diferentes. Por lo tanto, la razón no debe ocultarse en los datos o en el objetivo particular, sino más bien en la arquitectura.


¿descubriste qué estaba causando el problema?
Antoine

Lamentablemente no realmente. Cambiamos a una arquitectura diferente y luego no tuvimos la oportunidad de volver a esto. Sin embargo, tenemos algunas pistas. Suponemos que algo causó el cambio de uno o más de los parámetros nan.
Marek

nanparámetro no resultaría en una pérdida no nan. Supongo que sus gradientes explotan, me sucedió algo similar en redes normalizadas sin lotes.
Lugi

Esa es también una de las cosas que tratamos de examinar usando TensorBoard, pero la explosión de gradiente nunca se ha probado en nuestro caso. La idea fue que nanapareció en uno de los cálculos y luego se convirtió en otro valor que causó que la red se volviera loca. Pero es solo una suposición salvaje. Gracias por su opinión.
Marek

Respuestas:


2

Aquí está mi sugerencia para señalar el problema:

1) Mire la curva de aprendizaje del entrenamiento: ¿Cómo se establece la curva de aprendizaje en el tren? ¿Aprende el conjunto de entrenamiento? Si no, primero trabaje en eso para asegurarse de que puede encajar demasiado en el conjunto de entrenamiento.

2) Verifique sus datos para asegurarse de que no contenga NaN (capacitación, validación, prueba)

3) Verifique los gradientes y los pesos para asegurarse de que no haya NaN.

4) Disminuya la tasa de aprendizaje mientras entrena para asegurarse de que no se deba a una gran actualización repentina que se quedó en un mínimo agudo.

5) Para asegurarse de que todo esté bien, verifique las predicciones de su red para que su red no esté haciendo predicciones constantes o repetitivas.

6) Verifique si sus datos en su lote están equilibrados con respecto a todas las clases.

7) normalice sus datos para que sean cero unidades medias de varianza. Inicialice los pesos de la misma manera. Ayudará a la capacitación.

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.