RNN: ¿Cuándo aplicar BPTT y / o actualizar pesos?


15

Estoy tratando de comprender la aplicación de alto nivel de los RNN para el etiquetado de secuencias a través (entre otros) del documento de Graves de 2005 sobre la clasificación de fonemas.

Para resumir el problema: tenemos un gran conjunto de capacitación que consta de archivos de audio (de entrada) de oraciones individuales y (salida) horas de inicio etiquetadas por expertos, horas de finalización y etiquetas para fonemas individuales (incluidos algunos fonemas "especiales" como el silencio, de modo que cada muestra en cada archivo de audio esté etiquetada con algún símbolo de fonema).

El objetivo del trabajo es aplicar un RNN con celdas de memoria LSTM en la capa oculta a este problema. (Aplica varias variantes y varias otras técnicas como comparación. Por el momento, SOLO estoy interesado en el LSTM unidireccional, para simplificar las cosas).

Creo que entiendo la arquitectura de la red: una capa de entrada correspondiente a ventanas de 10 ms de los archivos de audio, preprocesada de manera estándar para el trabajo de audio; una capa oculta de celdas LSTM y una capa de salida con una codificación única de todos los 61 símbolos telefónicos posibles.

Creo que entiendo las ecuaciones (intrincadas pero sencillas) del paso hacia adelante y hacia atrás a través de las unidades LSTM. Son solo cálculo y la regla de la cadena.

Lo que no entiendo, después de leer este documento y varios similares varias veces, es cuándo aplicar exactamente el algoritmo de retropropagación y cuándo actualizar exactamente los diversos pesos en las neuronas.

Existen dos métodos plausibles:

1) Backprop y actualización en marco

Load a sentence.  
Divide into frames/timesteps.  
For each frame:
- Apply forward step
- Determine error function
- Apply backpropagation to this frame's error
- Update weights accordingly
At end of sentence, reset memory
load another sentence and continue.

o,

2) Backprop y actualización basada en oraciones:

Load a sentence.  
Divide into frames/timesteps.  
For each frame:
- Apply forward step
- Determine error function
At end of sentence:
- Apply backprop to average of sentence error function
- Update weights accordingly
- Reset memory
Load another sentence and continue.

Tenga en cuenta que esta es una pregunta general sobre el entrenamiento RNN utilizando el papel de Graves como un ejemplo puntiagudo (y personalmente relevante): cuando se entrena RNN en secuencias, ¿se aplica backprop en cada paso de tiempo? ¿Se ajustan los pesos cada vez? O, en una analogía laxa al entrenamiento por lotes en arquitecturas estrictamente avanzadas, ¿se acumulan y promedian los errores sobre una secuencia particular antes de aplicar las actualizaciones de backprop y peso?

¿O estoy aún más confundido de lo que pienso?

Respuestas:


25

Asumiré que estamos hablando de redes neuronales recurrentes (RNN) que producen una salida en cada paso de tiempo (si la salida solo está disponible al final de la secuencia, solo tiene sentido ejecutar backprop al final). Los RNN en este entorno a menudo se entrenan utilizando la retropropagación truncada a través del tiempo (BPTT), operando secuencialmente en 'fragmentos' de una secuencia. El procedimiento se ve así:

  1. Pase directo: avance por los próximos pasos de tiempo , calculando los estados de entrada, oculto y de salida.k1
  2. Calcule la pérdida, sumada en los pasos de tiempo anteriores (ver más abajo).
  3. Paso hacia atrás: calcule el gradiente de la pérdida wrt todos los parámetros, acumulando durante los pasos de tiempo anteriores (esto requiere haber almacenado todas las activaciones para estos pasos de tiempo). Recorte los degradados para evitar el problema de la explosión del degradado (ocurre raramente).k2
  4. Actualice los parámetros (esto ocurre una vez por porción, no de forma incremental en cada paso de tiempo).
  5. Si procesa varios fragmentos de una secuencia más larga, almacene el estado oculto en el último paso de tiempo (se usará para inicializar el estado oculto para el comienzo del próximo fragmento). Si hemos llegado al final de la secuencia, restablezca la memoria / estado oculto y avance al comienzo de la siguiente secuencia (o al comienzo de la misma secuencia, si solo hay una).
  6. Repita desde el paso 1.

La forma en que se suma la pérdida depende de y . Por ejemplo, cuando , la pérdida se suma en los últimos pasos de tiempo , pero el procedimiento es diferente cuando (ver Williams y Peng 1990).k1k2k1=k2k1=k2k2>k1

El cálculo de gradiente y las actualizaciones se realizan cada pasos de tiempo porque es computacionalmente más barato que la actualización en cada paso de tiempo. Actualizar varias veces por secuencia (es decir, establecer menos que la longitud de la secuencia) puede acelerar el entrenamiento porque las actualizaciones de peso son más frecuentes.k1k1

La retropropagación se realiza solo para pasos de tiempo porque es computacionalmente más barato que propagarse al comienzo de la secuencia (lo que requeriría almacenar y procesar repetidamente todos los pasos de tiempo). Los gradientes calculados de esta manera son una aproximación al gradiente 'verdadero' calculado en todos los pasos de tiempo. Pero, debido al problema de gradiente de desaparición, los gradientes tenderán a acercarse a cero después de cierto número de pasos de tiempo; propagar más allá de este límite no daría ningún beneficio. Establecer demasiado corto puede limitar la escala temporal sobre la cual la red puede aprender. Sin embargo, la memoria de la red no se limita a pasos de tiempo porque las unidades ocultas pueden almacenar información más allá de este período (p. Ej.k2k2k2)

Además de las consideraciones computacionales, la configuración adecuada para y depende de las estadísticas de los datos (por ejemplo, la escala temporal de las estructuras que son relevantes para producir buenos resultados). Probablemente también dependan de los detalles de la red. Por ejemplo, hay una serie de arquitecturas, trucos de inicialización, etc. diseñados para mitigar el problema del gradiente en descomposición.k1k2

Su opción 1 ('backprop en marco') corresponde a establecer en y en el número de pasos de tiempo desde el comienzo de la oración hasta el punto actual. La opción 2 ('backprop en cuanto a la oración') corresponde a establecer y en la longitud de la oración. Ambos son enfoques válidos (con consideraciones computacionales / de rendimiento como anteriormente; # 1 sería bastante computacionalmente intensivo para secuencias más largas). Ninguno de estos enfoques se llamaría 'truncado' porque la propagación hacia atrás ocurre en toda la secuencia. Son posibles otras configuraciones de y ; Voy a enumerar algunos ejemplos a continuación.k11k2k1k2k1k2

Referencias que describen BPTT truncado (procedimiento, motivación, cuestiones prácticas):

  • Sutskever (2013) . Entrenamiento de redes neuronales recurrentes.
  • Mikolov (2012) . Modelos estadísticos de lenguaje basados ​​en redes neuronales.
    • Usando los RNN de vainilla para procesar datos de texto como una secuencia de palabras, recomienda configurar en 10-20 palabras y en 5 palabrask1k2
    • Realizar múltiples actualizaciones por secuencia (es decir, menos que la longitud de la secuencia) funciona mejor que actualizar al final de la secuenciak1
    • Realizar actualizaciones una vez por fragmento es mejor que incrementalmente (lo que puede ser inestable)
  • Williams y Peng (1990) . Un algoritmo eficiente basado en gradientes para el entrenamiento en línea de trayectorias de red recurrentes.
    • Propuesta original (?) Del algoritmo
    • Discuten la elección de y (que llaman y ). Solo consideran .k1k2hhk2k1
    • Nota: Ellos usan la frase "BPTT (h; h ')" o' el algoritmo mejorado 'para referirse a lo que las otras referencias llaman' BPTT truncado '. Usan la frase 'BPTT truncado' para referirse al caso especial donde .k1=1

Otros ejemplos que usan BPTT truncado:

  • (Karpathy 2015). char-rnn.
    • Descripción y código
    • Vanilla RNN procesa documentos de texto de un carácter a la vez. Entrenado para predecir el próximo personaje. caracteres. La red solía generar texto nuevo al estilo del documento de capacitación, con resultados divertidos.k1=k2=25
  • Graves (2014) . Generando secuencias con redes neuronales recurrentes.
    • Consulte la sección sobre generación de artículos simulados de Wikipedia Red LSTM que procesa datos de texto como secuencia de bytes. Entrenado para predecir el siguiente byte. bytes. La memoria LSTM se reinicia cada bytes.10 , 000k1=k2=10010,000
  • Sak y col. (2014) . Arquitecturas de redes neuronales recurrentes basadas en memoria a largo plazo para reconocimiento de voz con vocabulario amplio
    • Redes LSTM modificadas, secuencias de procesamiento de características acústicas. .k1=k2=20
  • Ollivier y col. (2015) . Capacitación de redes recurrentes en línea sin retroceso.
    • El objetivo de este artículo era proponer un algoritmo de aprendizaje diferente, pero lo compararon con BPTT truncado. Usó RNN de vainilla para predecir secuencias de símbolos. Solo lo menciono aquí para decir que usaron .k1=k2=15
  • Hochreiter y Schmidhuber (1997) . Memoria a largo plazo a largo plazo.
    • Describen un procedimiento modificado para LSTM

Esta es una respuesta sobresaliente, y desearía tener la posición en este foro para otorgarle una recompensa sustancial. Especialmente útil es la discusión concreta de k1 vs k2 para contextualizar mis dos casos contra el uso más general, y ejemplos numéricos de los mismos.
Novak
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.