WaveNet no es realmente una convolución dilatada, ¿verdad?


10

En el reciente artículo de WaveNet , los autores se refieren a su modelo como capas apiladas de convoluciones dilatadas. También producen los siguientes cuadros, que explican la diferencia entre convoluciones "regulares" y convoluciones dilatadas.

Las convoluciones regulares se ven así. Convoluciones no dilatadas Esta es una convolución con un tamaño de filtro de 2 y un paso de 1, repetido para 4 capas.

Luego muestran una arquitectura utilizada por su modelo, a la que se refieren como convoluciones dilatadas. Se parece a esto. Convoluciones Dilatadas WaveNet Dicen que cada capa tiene dilataciones crecientes de (1, 2, 4, 8). Pero para mí esto parece una convolución regular con un tamaño de filtro de 2 y un paso de 2, repetido para 4 capas.

Según tengo entendido, una convolución dilatada, con un tamaño de filtro de 2, zancada de 1 y dilataciones crecientes de (1, 2, 4, 8), se vería así. Convolución Dilatada Real

En el diagrama WaveNet, ninguno de los filtros salta una entrada disponible. No hay agujeros En mi diagrama, cada filtro omite (d - 1) las entradas disponibles. Así es como se supone que funciona la dilatación, ¿no?

Entonces mi pregunta es, ¿cuál (si alguna) de las siguientes proposiciones son correctas?

  1. No entiendo las convoluciones dilatadas y / o regulares.
  2. Deepmind en realidad no implementó una convolución dilatada, sino más bien una convolución zancada, pero usó mal la palabra dilatación.
  3. Deepmind implementó una convolución dilatada, pero no implementó la tabla correctamente.

No soy lo suficientemente fluido en el código de TensorFlow para entender exactamente qué está haciendo su código, pero publiqué una pregunta relacionada en Stack Exchange , que contiene el bit de código que podría responder a esta pregunta.


Encontré su pregunta y respuesta a continuación bastante interesante. Dado que el artículo de WaveNet no explica esa equivalencia de zancada y tasa de dilatación, decidí resumir los conceptos clave en una publicación de blog: theblog.github.io/post/… . Puede resultarle interesante si todavía está trabajando con neuronas autorregresivas redes
Kilian Batzner

Respuestas:


8

Del documento de wavenet:

"A dilated convolution (also called a trous, or convolution with 
holes) is a convolution where the filter is applied over an area larger 
than its length by skipping input values with a certain step. It is 
equivalent to a convolution with a larger filter derived from the 
original filter by dilating it with zeros, but is significantly more 
efficient. A dilated convolution  effectively allows the network to 
operate on a coarser scale than with a normal convolution. This is 
similar to pooling or strided  convolutions, but 
here the output has the same size as the input. As a special case, 
dilated convolution with dilation 1 yields the standard convolution. 
Fig. 3 depicts dilated causal convolutions for dilations 1, 2, 4, and 
8."

Las animaciones muestran una zancada fija y un factor de dilatación que aumenta en cada capa. Fig. 3 animada de la publicación del blog wavenet de Google


6

El centavo solo cayó en este para mí. De esas 3 proposiciones, la correcta es 4: no entendí el documento de WaveNet.

Mi problema era que estaba interpretando que el diagrama de WaveNet cubría una sola muestra, para ejecutarse en diferentes muestras dispuestas en una estructura 2D, siendo 1 dimensión el tamaño de la muestra y la otra el recuento de lotes.

Sin embargo, WaveNet solo está ejecutando todo ese filtro en una serie de tiempo 1D con un paso de 1. Esto obviamente tiene una huella de memoria mucho menor pero logra lo mismo.

Si intentas hacer el mismo truco usando una estructura estriada, la dimensión de salida sería incorrecta.

Para resumir, hacerlo de manera progresiva con una muestra en 2D x estructura por lotes proporciona el mismo modelo, pero con un uso de memoria mucho mayor.

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.