La atención es un método para agregar un conjunto de vectores en un solo vector, a menudo a través de un vector de búsqueda . Por lo general, son las entradas al modelo o los estados ocultos de pasos de tiempo anteriores, o los estados ocultos un nivel hacia abajo (en el caso de LSTM apilados).vyotuvyo
El resultado a menudo se llama el vector de contexto , ya que contiene el contexto relevante para el paso de tiempo actual.do
Este vector de contexto adicional se alimenta al RNN / LSTM (puede simplemente concatenarse con la entrada original). Por lo tanto, el contexto puede usarse para ayudar con la predicción.do
La forma más sencilla de hacer esto es calcular el vector de probabilidad y donde es la concatenación de todos los anteriores . Un vector de búsqueda común es el estado oculto actual .p = softmax ( VTu )c = ∑yopagsyovyoVvyotuht
Hay muchas variaciones en esto, y puedes hacer las cosas tan complicadas como quieras. Por ejemplo, en lugar de usar como logits, uno puede elegir , donde es una red neuronal arbitraria.vTyotuF( vyo, U )F
Un mecanismo de atención común para los modelos de secuencia a secuencia utiliza , donde son los estados ocultos del codificador y es el oculto actual estado del decodificador y ambos s son parámetros.p=softmax(qTtanh(W1vi+W2ht))vhtqW
Algunos documentos que muestran diferentes variaciones en la idea de atención:
Las redes de punteros prestan atención a las entradas de referencia para resolver problemas de optimización combinatoria.
Las redes de entidades recurrentes mantienen estados de memoria separados para diferentes entidades (personas / objetos) mientras leen texto, y actualizan el estado de memoria correcto con atención.
Los modelos de transformadores también hacen un uso extensivo de la atención. Su formulación de atención es un poco más general y también involucra vectores clave : los pesos de atención se calculan realmente entre las teclas y la búsqueda, y el contexto se construye con .kipvi
Aquí hay una implementación rápida de una forma de atención, aunque no puedo garantizar la corrección más allá del hecho de que pasó algunas pruebas simples.
RNN básico:
def rnn(inputs_split):
bias = tf.get_variable('bias', shape = [hidden_dim, 1])
weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])
hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
for i, input in enumerate(inputs_split):
input = tf.reshape(input, (batch, in_dim, 1))
last_state = hidden_states[-1]
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
hidden_states.append(hidden)
return hidden_states[-1]
Con atención, agregamos solo unas pocas líneas antes de que se calcule el nuevo estado oculto:
if len(hidden_states) > 1:
logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
probs = tf.nn.softmax(logits)
probs = tf.reshape(probs, (batch, -1, 1, 1))
context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
else:
context = tf.zeros_like(last_state)
last_state = tf.concat([last_state, context], axis = 1)
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
el código completo