Tensorflow: ¿como guardar / restaurar un modelo?


553

Después de entrenar un modelo en Tensorflow:

  1. ¿Cómo se guarda el modelo entrenado?
  2. ¿Cómo restaurar más tarde este modelo guardado?

¿Pudiste restaurar las variables utilizadas en el modelo de inicio? También estoy intentando exactamente el mismo problema, pero no puedo escribir un conjunto de variables que se usaron durante el entrenamiento del modelo de inicio (del cual tengo un archivo ckpt)
exAres

No he probado con el modelo de inicio. ¿Tiene la estructura de red del modelo con sus nombres? Debe replicar la red y luego cargar los pesos y los sesgos (el archivo ckpt) como explica Ryan. Tal vez algo ha cambiado desde Nov'15 y hay un enfoque más directo ahora, no estoy seguro
mathetes

Oh, está bien. He cargado otros modelos de tensorflow previamente entrenados anteriormente, pero estaba buscando especificaciones variables del modelo de inicio. Gracias.
exAres

1
Si restaura para continuar entrenando, solo use los puntos de control Saver. Si guarda el modelo para hacer referencia, solo las API de Tensorflow SavedModel.
HY G

Además, si está utilizando LSTM, tendrá un mapa de cadena a una lista de caracteres, ¡asegúrese de guardar y cargar esa lista en el mismo orden! Esto no se cubre al guardar los pesos del modelo y la red de gráficos del modelo y hará que parezca que su modelo no se cargó cuando cambia las sesiones o los datos cambian.
devssh

Respuestas:


119

Docs

De los documentos:

Salvar

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Restaurar

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Tensorflow 2

Esto todavía es beta, por lo que desaconsejaría por ahora. Si todavía quieres seguir ese camino, aquí está la tf.saved_modelguía de uso

Tensorflow <2

simple_save

Muchas buenas respuestas, para completar agregaré mis 2 centavos: simple_save . También un ejemplo de código independiente que usa la tf.data.DatasetAPI.

Python 3; Tensorflow 1.14

import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

Restaurando:

graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

Ejemplo independiente

Publicación original del blog

El siguiente código genera datos aleatorios por el bien de la demostración.

  1. Comenzamos creando los marcadores de posición. Retendrán los datos en tiempo de ejecución. A partir de ellos, creamos el Datasety luego su Iterator. Obtenemos el tensor generado por el iterador, llamado input_tensorque servirá como entrada a nuestro modelo.
  2. El modelo en sí está construido a partir de input_tensor: un RNN bidireccional basado en GRU seguido de un clasificador denso. Porque, porque no.
  3. La pérdida es una softmax_cross_entropy_with_logits, optimizada con Adam. Después de 2 épocas (de 2 lotes cada una), guardamos el modelo "entrenado" con tf.saved_model.simple_save. Si ejecuta el código como está, el modelo se guardará en una carpeta llamada simple/en su directorio de trabajo actual.
  4. En un nuevo gráfico, luego restauramos el modelo guardado con tf.saved_model.loader.load. Agarramos los marcadores de posición y logits con graph.get_tensor_by_namey la Iteratoroperación de inicialización con graph.get_operation_by_name.
  5. Por último, ejecutamos una inferencia para ambos lotes en el conjunto de datos y verificamos que tanto el modelo guardado como el restaurado produzcan los mismos valores. ¡Ellas hacen!

Código:

import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a stem of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                    name="mean-xent"
                    )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }

            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)

Esto imprimirá:

$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Inferences match:  True

1
Soy principiante y necesito más explicaciones ...: si tengo un modelo CNN, ¿debo almacenar solo 1. input_placeholder 2. labels_placeholder y 3. output_of_cnn? O todo el intermedio tf.contrib.layers?
Llueve el

2
El gráfico está completamente restaurado. Podrías comprobarlo corriendo [n.name for n in graph2.as_graph_def().node]. Como dice la documentación, guardar simple tiene como objetivo simplificar la interacción con el servicio de tensorflow, este es el punto de los argumentos; sin embargo, otras variables aún se restauran, de lo contrario no se produciría inferencia. Simplemente tome sus variables de interés como lo hice en el ejemplo. Consulte la documentación
ted

@ted ¿cuándo usaría tf.saved_model.simple_save vs tf.train.Saver ()? Desde mi intuición, usaría tf.train.Saver () durante el entrenamiento y para almacenar diferentes momentos en el tiempo. Usaría tf.saved_model.simple_save cuando termine el entrenamiento para usarlo en producción. (Pedí lo mismo también en un comentario aquí )
loco.loop

1
Bien, supongo, pero ¿también funciona con los modelos de modo Eager y tfe.Saver?
Geoffrey Anderson

1
sin global_stepargumento, si se detiene, intente retomar el entrenamiento nuevamente, pensará que es un paso uno. Al menos arruinará las visualizaciones de su tensorboard
Monica Heddneck

252

Estoy mejorando mi respuesta para agregar más detalles para guardar y restaurar modelos.

En (y después) Tensorflow versión 0.11 :

Guarda el modelo:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

Restaurar el modelo:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

Este y algunos casos de uso más avanzados se han explicado muy bien aquí.

Un tutorial rápido y completo para guardar y restaurar modelos de Tensorflow


3
+1 para este # Acceso a las variables guardadas directamente print (sess.run ('bias: 0')) # Esto imprimirá 2, que es el valor del sesgo que guardamos. Ayuda mucho a fines de depuración para ver si el modelo se carga correctamente. las variables se pueden obtener con "All_varaibles = tf.get_collection (tf.GraphKeys.GLOBAL_VARIABLES". Además, "sess.run (tf.global_variables_initializer ())" tiene que estar antes de restaurar.
LGG

1
¿Está seguro de que tenemos que ejecutar global_variables_initializer nuevamente? Restablecí mi gráfico con global_variable_initialization, y me da una salida diferente cada vez en los mismos datos. Así que comenté la inicialización y simplemente restauré el gráfico, la variable de entrada y las operaciones, y ahora funciona bien.
Aditya Shinde

@AdityaShinde No entiendo por qué siempre obtengo valores diferentes cada vez. Y no incluí el paso de inicialización variable para restaurar. Estoy usando mi propio código por cierto.
Chaine

@AdityaShinde: no necesita una operación inicial ya que los valores ya están inicializados por la función de restauración, así que elimínelos. Sin embargo, no estoy seguro de por qué obtuviste resultados diferentes al usar init op.
sankit

55
@sankit Cuando restaura los tensores, ¿por qué agrega :0a los nombres?
Sahar Rabinoviz

177

En (y después) TensorFlow versión 0.11.0RC1, puede guardar y restaurar su modelo directamente llamando tf.train.export_meta_graphy de tf.train.import_meta_graphacuerdo con https://www.tensorflow.org/programmers_guide/meta_graph .

Guardar el modelo

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

Restaurar el modelo

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

44
¿Cómo cargar variables del modelo guardado? ¿Cómo copiar valores en alguna otra variable?
Neel

99
No puedo hacer que este código funcione. El modelo se guarda pero no puedo restaurarlo. Me está dando este error. <built-in function TF_Run> returned a result with an error set
Saad Qureshi

2
Cuando después de restaurar accedo a las variables como se muestra arriba, funciona. Pero no puedo obtener las variables más directamente usando tf.get_variable_scope().reuse_variables()seguido de var = tf.get_variable("varname"). Esto me da el error: "ValueError: el variable varname no existe o no se creó con tf.get_variable ()". ¿Por qué? ¿No debería ser esto posible?
Johann Petrak el

44
Esto funciona bien solo para variables, pero ¿cómo puede obtener acceso a un marcador de posición y alimentar valores después de restaurar el gráfico?
kbrose

11
Esto solo muestra cómo restaurar las variables. ¿Cómo puede restaurar todo el modelo y probarlo en nuevos datos sin redefinir la red?
Chaine

127

Para la versión TensorFlow <0.11.0RC1:

Los puntos de control que se guardan contienen valores para los Variables en su modelo, no el modelo / gráfico en sí, lo que significa que el gráfico debe ser el mismo cuando restaure el punto de control.

Aquí hay un ejemplo para una regresión lineal donde hay un ciclo de entrenamiento que guarda puntos de control de variables y una sección de evaluación que restaurará las variables guardadas en una ejecución anterior y calculará predicciones. Por supuesto, también puede restaurar variables y continuar entrenando si lo desea.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Aquí están los documentos para Variables, que cubren el ahorro y la restauración. Y aquí están los documentos para el Saver.


1
Las BANDERAS son definidas por el usuario. Aquí hay un ejemplo de cómo definirlos: github.com/tensorflow/tensorflow/blob/master/tensorflow/…
Ryan Sepassi

¿en qué formato batch_xdebe ser? ¿Binario? Numpy array?
pepe

@pepe Numpy Arrary debería estar bien. Y el tipo de elemento debe corresponder al tipo del marcador de posición. [enlace] tensorflow.org/versions/r0.9/api_docs/python/…
Donny

BANDERAS da error undefined. ¿Me puede decir cuál es def de FLAGS para este código? @RyanSepassi
Muhammad Hannan

Para que sea más explícito: Las versiones recientes de Tensorflow no permiten almacenar el modelo / gráfico. [No estaba claro para mí, qué aspectos de la respuesta se aplican a la restricción <0.11. Dado el gran número de votos a favor, tuve la tentación de creer que esta declaración general sigue siendo cierta para las versiones recientes.]
bluenote10

78

Mi entorno: Python 3.6, Tensorflow 1.3.0

Aunque ha habido muchas soluciones, la mayoría de ellas se basan en tf.train.Saver. Cuando cargamos un .ckptsalvados por Saver, tenemos que redefinir la red, ya sea tensorflow o utilizar algún nombre raro y recordado duro, por ejemplo 'placehold_0:0', 'dense/Adam/Weight:0'. Aquí recomiendo usar tf.saved_model, un ejemplo más simple que se muestra a continuación, puede obtener más información al servir un modelo TensorFlow :

Guarda el modelo:

import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

Cargue el modelo:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

44
+1 para un gran ejemplo de la API SavedModel. Sin embargo, ¡me gustaría que tu sección Guardar el modelo mostrara un ciclo de entrenamiento como la respuesta de Ryan Sepassi! Me doy cuenta de que esta es una vieja pregunta, pero esta respuesta es uno de los pocos (y valiosos) ejemplos de SavedModel que encontré en Google.
Dylan F

@ Tom, esta es una gran respuesta: solo una dirigida al nuevo SavedModel. ¿Podría echar un vistazo a esta pregunta de SavedModel? stackoverflow.com/questions/48540744/…
bluesummers

Ahora haga que todo funcione correctamente con los modelos TF Eager. Google aconsejó en su presentación de 2018 que todos se alejen del código gráfico TF.
Geoffrey Anderson el

55

Hay dos partes en el modelo, la definición del modelo, guardada Supervisorcomo graph.pbtxten el directorio del modelo y los valores numéricos de los tensores, guardados en archivos de puntos de control como model.ckpt-1003418.

La definición del modelo se puede restaurar usando tf.import_graph_def, y los pesos se restauran usando Saver.

Sin embargo, Saverusa una lista de variables de retención de colección especial que se adjunta al modelo Graph, y esta colección no se inicializa usando import_graph_def, por lo que no puede usar las dos juntas en este momento (está en nuestra hoja de ruta para solucionarlo). Por ahora, debe usar el enfoque de Ryan Sepassi: construir manualmente un gráfico con nombres de nodo idénticos y usarlo Saverpara cargar los pesos en él.

(Alternativamente, podría piratearlo usando import_graph_def, usando , creando variables manualmente, y usando tf.add_to_collection(tf.GraphKeys.VARIABLES, variable)para cada variable, luego usando Saver)


En el ejemplo classify_image.py que usa inceptionv3, solo se carga el graphdef. ¿Significa que ahora GraphDef también contiene la variable?
jrabary

1
@jrabary El modelo probablemente ha sido congelado .
Eric Platon

1
Hola, soy nuevo en Tensorflow y tengo problemas para guardar mi modelo. Realmente lo agradecería si me pudieran ayudar stackoverflow.com/questions/48083474/…
Ruchir Baronia

39

También puedes tomar este camino más fácil.

Paso 1: inicializa todas tus variables

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

Paso 2: guarde la sesión dentro del modelo Savery guárdela

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Paso 3: restaurar el modelo

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

Paso 4: verifica tu variable

W1 = session.run(W1)
print(W1)

Mientras se ejecuta en una instancia de Python diferente, use

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

Hola, ¿cómo puedo guardar el modelo después de suponer 3000 iteraciones, similares a Caffe? Descubrí que tensorflow guarda solo los últimos modelos a pesar de que concateno el número de iteración con el modelo para diferenciarlo entre todas las iteraciones. Me refiero a model_3000.ckpt, model_6000.ckpt, --- model_100000.ckpt. ¿Puede explicar amablemente por qué no guarda todo, sino que solo guarda las últimas 3 iteraciones?
khan


3
¿Hay algún método para guardar todas las variables / nombres de operación guardados en el gráfico?
Moondra

21

En la mayoría de los casos, guardar y restaurar desde el disco usando a tf.train.Saveres su mejor opción:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

También puede guardar / restaurar la estructura del gráfico en sí (consulte la documentación de MetaGraph para más detalles). Por defecto, Saverguarda la estructura del gráfico en un .metaarchivo. Puedes llamar import_meta_graph()para restaurarlo. Restaura la estructura del gráfico y devuelve un Saverque puede usar para restaurar el estado del modelo:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

Sin embargo, hay casos en los que necesita algo mucho más rápido. Por ejemplo, si implementa una detención temprana, desea guardar puntos de control cada vez que el modelo mejora durante el entrenamiento (según lo medido en el conjunto de validación), luego, si no hay progreso durante algún tiempo, desea volver al mejor modelo. Si guarda el modelo en el disco cada vez que mejora, ralentizará enormemente el entrenamiento. El truco es guardar los estados variables en la memoria , luego restaurarlos más tarde:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

Una explicación rápida: cuando crea una variable X, TensorFlow crea automáticamente una operación de asignación X/Assignpara establecer el valor inicial de la variable. En lugar de crear marcadores de posición y operaciones de asignación adicionales (lo que haría que el gráfico fuera desordenado), solo usamos estas operaciones de asignación existentes. La primera entrada de cada asignación op es una referencia a la variable que se supone que debe inicializar, y la segunda entrada ( assign_op.inputs[1]) es el valor inicial. Entonces, para establecer cualquier valor que queramos (en lugar del valor inicial), necesitamos usar feed_dictay reemplazar el valor inicial. Sí, TensorFlow le permite alimentar un valor para cualquier operación, no solo para marcadores de posición, por lo que funciona bien.


Gracias por la respuesta. Tengo una pregunta similar sobre cómo convertir un solo archivo .ckpt a dos archivos .index y .data (por ejemplo, para modelos de inicio pre-capacitados disponibles en tf.slim). Mi pregunta está aquí: stackoverflow.com/questions/47762114/…
Amir

Hola, soy nuevo en Tensorflow y tengo problemas para guardar mi modelo. Realmente lo agradecería si me pudieran ayudar stackoverflow.com/questions/48083474/…
Ruchir Baronia

17

Como dijo Yaroslav, puede piratear la restauración desde un gráfico_def y un punto de control importando el gráfico, creando variables manualmente y luego utilizando un protector.

Implementé esto para mi uso personal, así que pensé en compartir el código aquí.

Enlace: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(Esto es, por supuesto, un truco, y no hay garantía de que los modelos guardados de esta manera sigan siendo legibles en futuras versiones de TensorFlow).


14

Si es un modelo guardado internamente, solo debe especificar un restaurador para todas las variables como

restorer = tf.train.Saver(tf.all_variables())

y úselo para restaurar variables en una sesión actual:

restorer.restore(self._sess, model_file)

Para el modelo externo, debe especificar la asignación de los nombres de sus variables a sus nombres de variables. Puede ver los nombres de las variables del modelo con el comando

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

El script inspect_checkpoint.py se puede encontrar en la carpeta './tensorflow/python/tools' de la fuente de Tensorflow.

Para especificar el mapeo, puede usar mi Tensorflow-Worklab , que contiene un conjunto de clases y scripts para entrenar y reentrenar diferentes modelos. Incluye un ejemplo de reciclaje de modelos de ResNet, que se encuentra aquí.


all_variables()ahora está en desuso
MiniQuark

Hola, soy nuevo en Tensorflow y tengo problemas para guardar mi modelo. Realmente lo agradecería si me pudieran ayudar stackoverflow.com/questions/48083474/…
Ruchir Baronia

12

Aquí está mi solución simple para los dos casos básicos que difieren en si desea cargar el gráfico del archivo o compilarlo durante el tiempo de ejecución.

Esta respuesta es válida para Tensorflow 0.12+ (incluido 1.0).

Reconstruyendo el gráfico en código

Ahorro

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Cargando

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

Cargando también el gráfico desde un archivo

Cuando utilice esta técnica, asegúrese de que todas sus capas / variables hayan establecido explícitamente nombres únicos.De lo contrario, Tensorflow hará que los nombres sean únicos y, por lo tanto, serán diferentes de los nombres almacenados en el archivo. No es un problema en la técnica anterior, porque los nombres están "destrozados" de la misma manera tanto en la carga como en el almacenamiento.

Ahorro

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Cargando

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

-1 Comenzar tu respuesta descartando "todas las demás respuestas aquí" es un poco duro. Dicho esto, rechacé por otras razones: definitivamente debe guardar todas las variables globales, no solo las variables entrenables. Por ejemplo, la global_stepvariable y los promedios móviles de la normalización de lotes son variables no entrenables, pero definitivamente vale la pena guardar ambas. Además, debe distinguir más claramente la construcción del gráfico de la ejecución de la sesión, por ejemplo Saver(...).save(), creará nuevos nodos cada vez que lo ejecute. Probablemente no sea lo que quieres. Y hay más ...: /
MiniQuark

@MiniQuark ok, gracias por tus comentarios, editaré la respuesta de acuerdo a tus sugerencias;)
Martin Pecka

10

También puede consultar ejemplos en TensorFlow / skflow , que ofrece savey restoremétodos que pueden ayudarlo a administrar fácilmente sus modelos. Tiene parámetros que también puede controlar con qué frecuencia desea hacer una copia de seguridad de su modelo.


9

Si usa tf.train.MonitoredTrainingSession como sesión predeterminada, no necesita agregar código adicional para guardar / restaurar cosas. Simplemente pase un nombre de directorio de punto de control al constructor de MonitoredTrainingSession, usará ganchos de sesión para manejarlos.


usando tf.train.Supervisor se encargará de crear una sesión de este tipo para usted y le proporcionará una solución más completa.
Mark

1
@Mark tf.train.Supervisor está en desuso
Changming Sun

¿Tiene algún enlace que respalde la afirmación de que el Supervisor está en desuso? No vi nada que indique que este sea el caso.
Marcar el


Gracias por la URL: verifiqué con la fuente original de la información y me dijeron que probablemente existirá hasta el final de la serie TF 1.x, pero no hay garantías después de eso.
Mark

8

Todas las respuestas aquí son geniales, pero quiero agregar dos cosas.

Primero, para explicar la respuesta de @ user7505159, puede ser importante agregar "./" al principio del nombre del archivo que está restaurando.

Por ejemplo, puede guardar un gráfico sin "./" en el nombre del archivo de esta manera:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

Pero para restaurar el gráfico, es posible que deba anteponer un "./" al nombre_archivo:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

No siempre necesitará el "./", pero puede causar problemas dependiendo de su entorno y versión de TensorFlow.

También quiero mencionar que sess.run(tf.global_variables_initializer())puede ser importante antes de restaurar la sesión.

Si recibe un error con respecto a las variables no inicializadas al intentar restaurar una sesión guardada, asegúrese de incluir sess.run(tf.global_variables_initializer())antes de la saver.restore(sess, save_file)línea. Puede ahorrarte un dolor de cabeza.


7

Como se describe en el número 6255 :

use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')

en vez de

saver.restore('my_model_final.ckpt')

7

Según la nueva versión de Tensorflow, tf.train.Checkpointes la forma preferible de guardar y restaurar un modelo:

Checkpoint.savey Checkpoint.restoreescribir y leer puntos de control basados ​​en objetos, en contraste con tf.train.Saver que escribe y lee puntos de control basados ​​en variables.name. Los puntos de verificación basados ​​en objetos guardan un gráfico de dependencias entre los objetos de Python (Capas, Optimizadores, Variables, etc.) con bordes con nombre, y este gráfico se usa para unir variables al restaurar un punto de verificación. Puede ser más robusto a los cambios en el programa Python, y ayuda a admitir restauración en crear para variables cuando se ejecuta con entusiasmo. Prefiero tf.train.Checkpointsobre tf.train.Savernuevo código .

Aquí hay un ejemplo:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Más información y ejemplo aquí.


7

Para tensorflow 2.0 , es tan simple como

# Save the model
model.save('path_to_my_model.h5')

Restaurar:

new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')

¿Qué pasa con todas las operaciones y variables tf personalizadas que no son parte del objeto modelo? ¿Se salvarán de alguna manera cuando llames save () en el modelo? Tengo varias expresiones personalizadas de pérdida y probabilidad de flujo de tensor que se utilizan en la red de inferencia y generación, pero que no forman parte de mi modelo. Mi objeto modelo keras solo contiene las capas densas y conv. En TF 1 acabo de llamar al método de guardar y podría estar seguro de que todas las operaciones y tensores utilizados en mi gráfico se guardarían. En TF2 no veo cómo se guardarán las operaciones que de alguna manera no se agregan al modelo keras.
Kristof

¿Hay más información sobre la restauración de modelos en TF 2.0? No puedo restaurar los pesos de los archivos de puntos de control generados a través de la API C, consulte: stackoverflow.com/questions/57944786/…
jregalad


5

tf.keras Ahorro de modelo con TF2.0

Veo excelentes respuestas para guardar modelos usando TF1.x. Quiero proporcionar un par de punteros más para guardar tensorflow.kerasmodelos, lo cual es un poco complicado ya que hay muchas maneras de guardar un modelo.

Aquí estoy proporcionando un ejemplo de guardar un tensorflow.kerasmodelo en la model_pathcarpeta en el directorio actual. Esto funciona bien con el tensorflow más reciente (TF2.0). Actualizaré esta descripción si hay algún cambio en el futuro cercano.

Guardar y cargar todo el modelo

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

#import data
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# create a model
def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
# compile the model
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  return model

# Create a basic model instance
model=create_model()

model.fit(x_train, y_train, epochs=1)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save entire model to a HDF5 file
model.save('./model_path/my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('./model_path/my_model.h5')
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Guardar y cargar pesas modelo solamente

Si está interesado en guardar solo los pesos del modelo y luego cargar los pesos para restaurar el modelo, entonces

model.fit(x_train, y_train, epochs=5)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Guardar y restaurar usando la devolución de llamada de punto de control Keras

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(train_images, train_labels,
          epochs = 50, callbacks = [cp_callback],
          validation_data = (test_images,test_labels),
          verbose=0)

latest = tf.train.latest_checkpoint(checkpoint_dir)

new_model = create_model()
new_model.load_weights(latest)
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

modelo de guardado con métricas personalizadas

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Custom Loss1 (for example) 
@tf.function() 
def customLoss1(yTrue,yPred):
  return tf.reduce_mean(yTrue-yPred) 

# Custom Loss2 (for example) 
@tf.function() 
def customLoss2(yTrue, yPred):
  return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])
  return model

# Create a basic model instance
model=create_model()

# Fit and evaluate model 
model.fit(x_train, y_train, epochs=1)
loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

model.save("./model.h5")

new_model=tf.keras.models.load_model("./model.h5",custom_objects={'customLoss1':customLoss1,'customLoss2':customLoss2})

Guardar el modelo de Keras con operaciones personalizadas

Cuando tenemos operaciones personalizadas como en el siguiente caso ( tf.tile), necesitamos crear una función y envolver con una capa Lambda. De lo contrario, el modelo no se puede guardar.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model

def my_fun(a):
  out = tf.tile(a, (1, tf.shape(a)[0]))
  return out

a = Input(shape=(10,))
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(lambda x : my_fun(x))(a)
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('my_model.h5')

#load the model
new_model=tf.keras.models.load_model("my_model.h5")

Creo que he cubierto algunas de las muchas formas de guardar el modelo tf.keras. Sin embargo, hay muchas otras formas. Comente a continuación si ve que su caso de uso no está cubierto anteriormente. ¡Gracias!


3

Use tf.train.Saver para guardar un modelo, remerber, necesita especificar var_list, si desea reducir el tamaño del modelo. Val_list puede ser tf.trainable_variables o tf.global_variables.


3

Puede guardar las variables en la red usando

saver = tf.train.Saver() 
saver.save(sess, 'path of save/fileName.ckpt')

Para restaurar la red para su reutilización posterior o en otro script, use:

saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/')
sess.run(....) 

Puntos importantes:

  1. sess debe ser igual entre la primera y las últimas ejecuciones (estructura coherente).
  2. saver.restore necesita la ruta de la carpeta de los archivos guardados, no una ruta de archivo individual.

2

Donde quiera guardar el modelo,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

Asegúrese de que todos tf.Variabletengan nombres, porque es posible que desee restaurarlos más tarde utilizando sus nombres. Y donde quieres predecir,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

Asegúrese de que el protector se ejecute dentro de la sesión correspondiente. Recuerda eso, si usas eltf.train.latest_checkpoint('./') , solo el último punto de control.


2

Estoy en la versión:

tensorflow (1.13.1)
tensorflow-gpu (1.13.1)

Manera simple es

Salvar:

model.save("model.h5")

Restaurar:

model = tf.keras.models.load_model("model.h5")

2

Para tensorflow-2.0

es muy simple.

import tensorflow as tf

SALVAR

model.save("model_name")

RESTAURAR

model = tf.keras.models.load_model('model_name')

1

Siguiendo la respuesta de @Vishnuvardhan Janapati, aquí hay otra forma de guardar y recargar el modelo con capa / métrica / pérdida personalizada bajo TensorFlow 2.0.0

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects

# custom loss (for example)  
def custom_loss(y_true,y_pred):
  return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss}) 

# custom loss (for example) 
class CustomLayer(Layer):
  def __init__(self, ...):
      ...
  # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})  

De esta manera, una vez que se ha ejecutado este tipo de códigos, y salvó su modelo con tf.keras.models.save_modelo model.saveo ModelCheckpointde devolución de llamada, puede volver a cargar el modelo sin la necesidad de objetos personalizados precisos, tan simple como

new_model = tf.keras.models.load_model("./model.h5"})

0

En la nueva versión de tensorflow 2.0, el proceso de guardar / cargar un modelo es mucho más fácil. Debido a la implementación de la API de Keras, una API de alto nivel para TensorFlow.

Para guardar un modelo: Consulte la documentación para referencia: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model

tf.keras.models.save_model(model_name, filepath, save_format)

Para cargar un modelo:

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

model = tf.keras.models.load_model(filepath)

0

Aquí hay un ejemplo simple que usa el formato Tensorflow 2.0 SavedModel (que es el formato recomendado, según los documentos ) para un clasificador de conjunto de datos MNIST simple, usando la API funcional Keras sin demasiada fantasía:

# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixels [0,255] -> [0,1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)

# Create model
input = Input(shape=(28,28), dtype='float64', name='graph_input')
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=input, outputs=output)

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

# Train
model.fit(x_train, y_train, epochs=3)

# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)

# ... possibly another python program 

# Reload model
loaded_model = tf.keras.models.load_model(export_path) 

# Get image sample for testing
index = 0
img = x_test[index] # I normalized the image on a previous step

# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))

# Show results
print(np.argmax(prediction['graph_output']))  # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary)  # prints the image

¿Qué es serving_default?

Es el nombre de la definición de firma de la etiqueta que seleccionó (en este caso, servese seleccionó la etiqueta predeterminada ). Además, aquí se explica cómo encontrar las etiquetas y firmas de un modelo usando saved_model_cli.

Renuncias

Este es solo un ejemplo básico si solo desea ponerlo en funcionamiento, pero de ninguna manera es una respuesta completa, tal vez pueda actualizarlo en el futuro. Solo quería dar un ejemplo simple usando elSavedModel TF 2.0 porque no he visto uno, ni siquiera este simple, en ningún lado.

La respuesta de @ Tom es un ejemplo de SavedModel, pero no funcionará en Tensorflow 2.0, porque desafortunadamente hay algunos cambios importantes.

La respuesta de @ Vishnuvardhan Janapati dice TF 2.0, pero no es para el formato SavedModel.

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.