Comprender las unidades LSTM frente a las células


32

He estado estudiando LSTM por un tiempo. Entiendo a alto nivel cómo funciona todo. Sin embargo, al implementarlos usando Tensorflow, he notado que BasicLSTMCell requiere un número de unidades (es decir num_units) parámetro.

A partir de esta explicación muy detallada de los LSTM, he deducido que una sola unidad LSTM es una de las siguientes

Unidad LSTM

que en realidad es una unidad GRU.

Supongo que el parámetro num_unitsde BasicLSTMCellse refiere a cuántos de estos queremos conectar entre sí en una capa.

Eso deja la pregunta: ¿qué es una "célula" en este contexto? ¿Es una "célula" equivalente a una capa en una red neuronal de alimentación normal?


Todavía estoy confundido, estaba leyendo colah.github.io/posts/2015-08-Understanding-LSTMs y lo entiendo bien. ¿Cómo se aplica el término celda con respecto a ese artículo? Parece que una celda LSTM en el artículo es un vector como en Tensorflow, ¿verdad?
Pinocho

Respuestas:


17

La terminología es lamentablemente inconsistente. num_unitsen TensorFlow es el número de estados ocultos, es decir, la dimensión de en las ecuaciones que proporcionó.ht

Además, desde https://github.com/tensorflow/tensorflow/blob/master/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.rnn_cell.RNNCell.md :

La definición de célula en este paquete difiere de la definición utilizada en la literatura. En la literatura, la celda se refiere a un objeto con una salida escalar única. La definición en este paquete se refiere a una matriz horizontal de tales unidades.

La "capa LSTM" es probablemente más explícita, por ejemplo :

def lstm_layer(tparams, state_below, options, prefix='lstm', mask=None):
    nsteps = state_below.shape[0]
    if state_below.ndim == 3:
        n_samples = state_below.shape[1]
    else:
        n_samples = 1

    assert mask is not None
    […]

Ah, ya veo, entonces una "celda" es una num_unitmatriz horizontal de tamaño de celdas LSTM interconectadas. Tiene sentido. Entonces, ¿sería análogo a una capa oculta en una red de alimentación estándar?

* Unidades de estado LSTM

@rec Eso es correcto
Franck Dernoncourt

1
@Sycorax, por ejemplo, si la entrada de la red neuronal es una serie de tiempo con 10 pasos de tiempo, la dimensión horizontal tiene 10 elementos.
Franck Dernoncourt

1
Todavía estoy confundido, estaba leyendo colah.github.io/posts/2015-08-Understanding-LSTMs y lo entiendo bien. ¿Cómo se aplica el término celda con respecto a ese artículo? Parece que una celda LSTM en el artículo es un vector como en Tensorflow, ¿verdad?
Pinocho

4

La mayoría de los diagramas LSTM / RNN solo muestran las celdas ocultas pero nunca las unidades de esas celdas. De ahí la confusión. Cada capa oculta tiene celdas ocultas, tanto como el número de pasos de tiempo. Y además, cada celda oculta está compuesta de múltiples unidades ocultas, como en el diagrama a continuación. Por lo tanto, la dimensionalidad de una matriz de capa oculta en RNN es (número de pasos de tiempo, número de unidades ocultas).

ingrese la descripción de la imagen aquí


4

Aunque el problema es casi el mismo que respondí en esta respuesta , me gustaría ilustrar este problema, que también me confundió un poco hoy en el modelo seq2seq (gracias a la respuesta de @Franck Dernoncourt), en el gráfico. En este diagrama de codificador simple:

ingrese la descripción de la imagen aquí

hi


Creo que num_units = nen esta figura
notilas

-1

En mi opinión, celda significa un nodo como celda oculta, que también se llama nodo oculto, para el modelo LSTM multicapa, el número de celdas se puede calcular mediante time_steps * num_layers, y el número de unidades es igual a time_steps


-1

Que las unidades en Keras es la dimensión del espacio de salida, que es igual a la duración del retraso (time_step) a la que recurre la red.

keras.layers.LSTM(units, activation='tanh', ....)

https://keras.io/layers/recurrent/

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.