¿Cómo decirle a Keras que deje de entrenar en función del valor de la pérdida?


82

Actualmente utilizo el siguiente código:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

Le dice a Keras que deje de entrenar cuando la pérdida no mejoró durante 2 épocas. Pero quiero dejar de entrenar después de que la pérdida se vuelva más pequeña que un "THR" constante:

if val_loss < THR:
    break

He visto en la documentación que existe la posibilidad de hacer su propia devolución de llamada: http://keras.io/callbacks/ Pero no se encontró cómo detener el proceso de entrenamiento. Necesito un consejo.

Respuestas:


85

Encontré la respuesta. Busqué en las fuentes de Keras y encontré el código para EarlyStopping. Hice mi propia devolución de llamada, basada en ella:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

Y uso:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

1
Solo si será útil para alguien, en mi caso usé monitor = 'loss', funcionó bien.
QtRoS

15
Parece que Keras se ha actualizado. La función de devolución de llamada EarlyStopping tiene min_delta integrado ahora. Ya no es necesario piratear el código fuente, ¡yay! stackoverflow.com/a/41459368/3345375
jkdev

3
Al volver a leer la pregunta y las respuestas, necesito corregirme: min_delta significa "Detente temprano si no hay suficiente mejora por época (o por varias épocas)". Sin embargo, el OP preguntó cómo "detenerse temprano cuando la pérdida está por debajo de cierto nivel".
jkdev

NameError: el nombre 'Callback' no está definido ... ¿Cómo lo arreglaré?
alyssaeliyah

2
Eliyah prueba esto: from keras.callbacks import Callback
ZFTurbo

26

La devolución de llamada keras.callbacks.EarlyStopping tiene un argumento min_delta. De la documentación de Keras:

min_delta: el cambio mínimo en la cantidad monitoreada para calificar como una mejora, es decir, un cambio absoluto menor que min_delta, contará como ninguna mejora.


3
Como referencia, aquí están los documentos de una versión anterior de Keras (1.1.0) en la que el argumento min_delta aún no se incluyó: faroit.github.io/keras-docs/1.1.0/callbacks/#earlystopping
jkdev

¿Cómo podría hacer que no se detuviera hasta que min_deltapersista en múltiples épocas?
zyxue

Hay otro parámetro para EarlyStopping llamado paciencia: número de épocas sin mejora después de las cuales se detendrá el entrenamiento.
devin

13

Una solución es llamar model.fit(nb_epoch=1, ...)dentro de un bucle for, luego puede poner una declaración de interrupción dentro del bucle for y hacer cualquier otro flujo de control personalizado que desee.


Sería bueno si hicieran una devolución de llamada que incluya una sola función que pueda hacer eso.
Honestidad

7

Resolví el mismo problema usando devolución de llamada personalizada.

En el siguiente código de devolución de llamada personalizado, asigne a THR el valor en el que desea detener el entrenamiento y agregue la devolución de llamada a su modelo.

from keras.callbacks import Callback

class stopAtLossValue(Callback):

        def on_batch_end(self, batch, logs={}):
            THR = 0.03 #Assign THR with the value at which you want to stop training.
            if logs.get('loss') <= THR:
                 self.model.stop_training = True

2

Mientras cursaba la especialización práctica de TensorFlow , aprendí una técnica muy elegante. Solo un poco modificado de la respuesta aceptada.

Pongamos el ejemplo con nuestros datos MNIST favoritos.

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy')> 0.90): # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

Entonces, aquí configuro el metrics=['accuracy'], y por lo tanto en la clase de devolución de llamada se establece la condición 'accuracy'> 0.90.

Puede elegir cualquier métrica y monitorear el entrenamiento como este ejemplo. Lo más importante es que puede establecer diferentes condiciones para diferentes métricas y usarlas simultáneamente.

¡Ojalá esto ayude!


el nombre de la función debe ser on_epoch_end
xarion

0

Para mí, el modelo solo dejaría de entrenarse si agregué una declaración de retorno después de establecer el parámetro stop_training en True porque estaba llamando después de self.model.evaluate. Por lo tanto, asegúrese de poner stop_training = True al final de la función o agregue una declaración de retorno.

def on_epoch_end(self, batch, logs):
        self.epoch += 1
        self.stoppingCounter += 1
        print('\nstopping counter \n',self.stoppingCounter)

        #Stop training if there hasn't been any improvement in 'Patience' epochs
        if self.stoppingCounter >= self.patience:
            self.model.stop_training = True
            return

        # Test on additional set if there is one
        if self.testingOnAdditionalSet:
            evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
            self.validationLoss2.append(evaluation[0])
            self.validationAcc2.append(evaluation[1])enter code here

0

Si está usando un ciclo de entrenamiento personalizado, puede usar a collections.deque, que es una lista "continua" que se puede agregar, y los elementos de la izquierda aparecen cuando la lista es más larga que maxlen. Aquí está la línea:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history)
            break

Aquí tienes un ejemplo completo:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.