La convergencia de GradienTape es mucho más lenta que Keras.model.fit


8

Actualmente estoy tratando de obtener la API TF2.0 , pero cuando comparé el GradientTape con un keras regular.Model.fit noté:

  1. Funcionó más lento (probablemente debido a la Ejecución Eager)

  2. Convergente mucho más lento (y no estoy seguro de por qué).

+--------+--------------+--------------+------------------+
|  Epoch | GradientTape | GradientTape | keras.Model.fit  |
|        |              |  shuffling   |                  |
+--------+--------------+--------------+------------------+
|    1   |     0.905    |     0.918    |      0.8793      |
+--------+--------------+--------------+------------------+
|    2   |     0.352    |     0.634    |      0.2226      |
+--------+--------------+--------------+------------------+
|    3   |     0.285    |     0.518    |      0.1192      |
+--------+--------------+--------------+------------------+
|    4   |     0.282    |     0.458    |      0.1029      |
+--------+--------------+--------------+------------------+
|    5   |     0.275    |     0.421    |      0.0940      |
+--------+--------------+--------------+------------------+

Aquí está el ciclo de entrenamiento que utilicé con GradientTape :


optimizer = keras.optimizers.Adam()
glove_model = GloveModel(vocab_size=len(labels))
train_loss = keras.metrics.Mean(name='train_loss')

@tf.function
def train_step(examples, labels):
    with tf.GradientTape() as tape:
        predictions = glove_model(examples)
        loss = glove_model.glove_loss(labels, predictions)

    gradients = tape.gradient(loss, glove_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, glove_model.trainable_variables))

    train_loss(loss)



total_step = 0
for epoch in range(epochs_number):

    pbar = tqdm(train_ds.enumerate(), total=int(len(index_data) / batch_size) + 1)

    for ix, (examples, labels) in pbar:

        train_step(examples, labels)


    print(f"Epoch {epoch + 1}, Loss {train_loss.result()}")

    # Reset the metrics for the next epoch
    train_loss.reset_states()

Y aquí está el entrenamiento Keras.Model.fit :

glove_model.compile(optimizer, glove_model.glove_loss)
glove_model.fit(train_ds, epochs=epochs_number)

Aquí está la fuente tf.data.Dataset

train_ds = data.Dataset.from_tensor_slices(
    (np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1)]), index_data)
).shuffle(100000).batch(batch_size, drop_remainder=True)

Y aquí está el modelo.

class GloveModel(keras.Model):

    def __init__(self, vocab_size, dim=100, a=3/4, x_max=100):
        super(GloveModel, self).__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.a = a
        self.x_max = x_max

        self.target_embedding = layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.dim, input_length=1, name="target_embedding"
        )
        self.target_bias = layers.Embedding(
            input_dim=self.vocab_size, output_dim=1, input_length=1, name="target_bias"
        )

        self.context_embedding = layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.dim, input_length=1, name="context_embedding"
        )
        self.context_bias = layers.Embedding(
            input_dim=self.vocab_size, output_dim=1, input_length=1, name="context_bias"
        )

        self.dot_product = layers.Dot(axes=-1, name="dot")

        self.prediction = layers.Add(name="add")
        self.step = 0

    def call(self, inputs):

        target_ix = inputs[:, 0]
        context_ix = inputs[:, 1]

        target_embedding = self.target_embedding(target_ix)
        target_bias = self.target_bias(target_ix)

        context_embedding = self.context_embedding(context_ix)
        context_bias = self.context_bias(context_ix)

        dot_product = self.dot_product([target_embedding, context_embedding])
        prediction = self.prediction([dot_product, target_bias, context_bias])

        return prediction

    def glove_loss(self, y_true, y_pred):

        weight = tf.math.minimum(
            tf.math.pow(y_true/self.x_max, self.a), 1.0
        )
        loss_value = tf.math.reduce_mean(weight * tf.math.pow(y_pred - tf.math.log(y_true), 2.0))

        return loss_value



Intenté múltiples configuraciones y optimizadores, pero nada parece cambiar la tasa de convergencia.


1
Una cosa a tener en cuenta es la mezcla de datos antes de cada época.
THN

Tengo exactamente la misma combinación entre el método de ajuste y GradientTape porque uso la API tf.Data.
Benjamin Breton

1
Creo que no son exactamente lo mismo. ¿Puedes mostrar el código de tu tfds? Tenga en cuenta que keras por .fitdefecto se baraja antes de cada época. Puede probar desactivando la combinación aleatoria en keras y comparar su tasa de convergencia.
THN

@THN Te lo enviaré, pero ya realizo una combinación aleatoria con la API tf.Dataset, por lo que no debería cambiar nada, ¿verdad?
Benjamin Breton

@THN Agregué el tf.data.Dataset
Benjamin Breton el

Respuestas:


2

Dataset.shuffle()solo baraja cada minibatch, por lo que cada época tiene el mismo orden. Keras .fit()usa algunas magias para barajar todo el conjunto de datos antes de cada época. Para hacer esto en TF, debe usar el conjunto de datos .repeat(epochs_number)y .shuffle(..., reshuffle_each_iteration=True):

train_ds = data.Dataset.from_tensor_slices(
    (np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1)]), index_data)
    ).shuffle(100000, reshuffle_each_iteration=True
    ).batch(batch_size, drop_remainder=True
    ).repeat(epochs_number)

for ix, (examples, labels) in train_ds.enumerate():
    train_step(examples, labels)
    current_epoch = ix // (len(index_data) // batch_size)

Esta solución alternativa no es hermosa ni natural, por el momento puede usarla para mezclar cada época. Es un problema conocido y se solucionará, en el futuro puede usarlo en for epoch in range(epochs_number)lugar de hacerlo .repeat().


Lo siento, agregué su código, pero la convergencia es aún más lenta. Agregué los resultados en la columna GradientTape shuffle. No tiene sentido para mí ...
Benjamin Breton

@BenjaminBreton En este punto, dudo que haya otros errores al acecho en su código. Quizás sea mejor vincular a su repositorio para mostrar el código completo. Si está seguro de que sus experimentos se llevan a cabo correctamente, debe abrir un problema en el repositorio de tensorflow.
THN

Muchas gracias por su ayuda @THN Publiqué el problema en el repositorio TF2.0 github.com/tensorflow/tensorflow/issues/33898 . Intentaré reproducir el error con un modelo diferente.
Benjamin Breton el

1
Resulta que tenías razón @THN barajé usando numpy y resolvió el problema. Publicaré una respuesta integral
Benjamin Breton, el

0

El problema vino de la barajado utilizando la tf.Dataset método. Solo barajó el conjunto de datos un cubo a la vez. El uso del Keras.Model.fit arrojó mejores resultados porque probablemente agrega otro barajado.

Agregué un barajado numpy.random.shuffley mejoró el rendimiento con ambos métodos de entrenamiento:

La generación del conjunto de datos es ahora:

numpy_data = np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1), index_data.reshape(-1, 1)])

np.random.shuffle(numpy_data)

indexes = np.array(numpy_data[:, :2], dtype=np.uint32)
labels = np.array(numpy_data[:, 2].reshape(-1, 1), dtype=np.float32)

train_ds = data.Dataset.from_tensor_slices(
    (indexes, labels)
).shuffle(100000).batch(batch_size, drop_remainder=True)

Y los resultados son:

+--------+--------------+------------------+
|  Epoch | GradientTape |  keras.Model.fit |
+--------+--------------+------------------+
|    1   |     0.294    |      0.294       |
+--------+--------------+------------------+
|    2   |     0.111    |      0.110       |
+--------+--------------+------------------+
|    3   |     0.089    |      0.089       |
+--------+--------------+------------------+
|    4   |     0.074    |      0.075       |
+--------+--------------+------------------+
|    5   |     0.063    |      0.063       |
+--------+--------------+------------------+

El tipo de entrenamiento por época es aproximadamente el mismo en 2 minutos por época .

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.