Captura de patrones iniciales cuando se utiliza la retropropagación truncada a través del tiempo (RNN / LSTM)


12

Digamos que uso un RNN / LSTM para hacer análisis de sentimientos, que es un enfoque de muchos a uno (vea este blog ). La red se entrena a través de una retropropagación truncada a través del tiempo (BPTT), donde la red se desenrolla durante solo 30 últimos pasos, como de costumbre.

En mi caso, cada una de mis secciones de texto que quiero clasificar son mucho más largas que los 30 pasos que se están desenrollando (~ 100 palabras). Según mi conocimiento, BPTT solo se ejecuta una sola vez para una sola sección de texto, que es cuando ha pasado por toda la sección de texto y ha calculado el objetivo de clasificación binaria, , que luego compara con la función de pérdida para encontrar el error.y

Los gradientes nunca se calcularán con respecto a las primeras palabras de cada sección de texto. ¿Cómo puede el RNN / LSTM ajustar sus pesos para capturar patrones específicos que solo ocurren dentro de las primeras palabras? Por ejemplo, digamos que todas las oraciones marcadas como comienzan con "I love this" y todas las oraciones marcadas como n e g a t i v e comienzan con "I hate this". ¿Cómo capturaría el RNN / LSTM eso cuando solo se desenrolla durante los últimos 30 pasos cuando llega al final de una secuencia larga de 100 pasos?positivenegative


por lo general, la abreviatura es TBPTT para propagación hacia atrás truncada a través del tiempo.
Charlie Parker

Respuestas:


11

Es cierto que limitar su propagación de gradiente a 30 pasos de tiempo evitará que aprenda todo lo posible en su conjunto de datos. Sin embargo, depende en gran medida de su conjunto de datos si eso evitará que aprenda cosas importantes sobre las características de su modelo.

Limitar el gradiente durante el entrenamiento es más como limitar la ventana sobre la cual su modelo puede asimilar las características de entrada y el estado oculto con gran confianza. Debido a que en el momento de la prueba aplica su modelo a toda la secuencia de entrada, aún podrá incorporar información sobre todas las características de entrada en su estado oculto. Es posible que no sepa exactamente cómo preservar esa información hasta que haga su predicción final para la oración, pero puede haber algunas conexiones (ciertamente más débiles) que aún podría hacer.

Piensa primero en un ejemplo artificial. Suponga que su red debe generar un 1 si hay un 1 en cualquier parte de su entrada, y un 0 en caso contrario. Supongamos que entrena la red en secuencias de longitud 20 y limita el gradiente a 10 pasos. Si el conjunto de datos de entrenamiento nunca contiene un 1 en los últimos 10 pasos de una entrada, entonces la red tendrá un problema con las entradas de prueba de cualquier configuración. Sin embargo, si el conjunto de entrenamiento tiene algunos ejemplos como [1 0 0 ... 0 0 0] y otros como [0 0 0 ... 1 0 0], entonces la red podrá detectar la "presencia de una característica de 1 "en cualquier parte de su entrada.

Volver al análisis de sentimientos entonces. Digamos que durante el entrenamiento, su modelo encuentra una oración negativa larga como "Odio esto porque ... vueltas y vueltas" con, digamos, 50 palabras en puntos suspensivos. Al limitar la propagación del gradiente a 30 pasos de tiempo, el modelo no conectará el "Odio esto porque" a la etiqueta de salida, por lo que no captará "I", "odio" o "esto" de esta capacitación. ejemplo. Pero recogerá las palabras que están dentro de 30 pasos de tiempo desde el final de la oración. Si su conjunto de entrenamiento contiene otros ejemplos que contienen esas mismas palabras, posiblemente junto con "odio", entonces tiene la posibilidad de retomar el vínculo entre "odio" y la etiqueta de sentimiento negativo. Además, si tiene ejemplos de entrenamiento más cortos, diga: "¡Odiamos esto porque es terrible!" entonces su modelo podrá conectar las funciones "odio" y "esto" a la etiqueta de destino. Si tiene suficientes ejemplos de capacitación, entonces el modelo debería poder aprender la conexión de manera efectiva.

En el momento de la prueba, digamos que presentas al modelo con otra oración larga como "¡Odio esto porque ... en el gecko!" La entrada del modelo comenzará con "Odio esto", que se pasará al estado oculto del modelo de alguna forma. Este estado oculto se utiliza para influir en los futuros estados ocultos del modelo, por lo que, aunque puede haber 50 palabras antes del final de la oración, el estado oculto de esas palabras iniciales tiene una posibilidad teórica de influir en la salida, a pesar de que nunca fue entrenado en muestras que contenían una distancia tan grande entre el "Odio esto" y el final de la oración.


0

@ Imjohns3 tiene razón, si procesa secuencias largas (tamaño N) y limita la propagación hacia atrás a los últimos K pasos, la red no aprenderá patrones al principio.

He trabajado con textos largos y uso el enfoque donde calculo la pérdida y hago la propagación hacia atrás después de cada K pasos. Supongamos que mi secuencia tenía N = 1000 tokens, mi proceso RNN primero K = 100, luego trato de hacer predicciones (pérdida de cálculo) y propagación hacia atrás. Luego, mientras mantiene el estado RNN, rompa la cadena de gradiente (en pytorch-> detach) y comience otros k = 100 pasos.

Un buen ejemplo de esta técnica se puede encontrar aquí: https://github.com/ksopyla/pytorch_neural_networks/blob/master/RNN/lstm_imdb_tbptt.py

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.