TensorFlow guardando / cargando un gráfico desde un archivo


98

De lo que he recopilado hasta ahora, hay varias formas diferentes de volcar un gráfico de TensorFlow en un archivo y luego cargarlo en otro programa, pero no he podido encontrar ejemplos / información claros sobre cómo funcionan. Lo que ya sé es esto:

  1. Guarde las variables del modelo en un archivo de punto de control (.ckpt) usando a tf.train.Saver()y restáurelas más tarde ( fuente )
  2. Guarde un modelo en un archivo .pb y cárguelo de nuevo usando tf.train.write_graph()y tf.import_graph_def()( fuente )
  3. Cargue un modelo desde un archivo .pb, reentrenarlo y volcarlo en un nuevo archivo .pb usando Bazel ( fuente )
  4. Congele el gráfico para guardar el gráfico y los pesos juntos ( fuente )
  5. Úselo as_graph_def()para guardar el modelo, y para pesos / variables, mapearlos en constantes ( fuente )

Sin embargo, no he podido aclarar varias preguntas sobre estos diferentes métodos:

  1. Con respecto a los archivos de puntos de control, ¿solo guardan los pesos entrenados de un modelo? ¿Los archivos de puntos de control se pueden cargar en un nuevo programa y usarse para ejecutar el modelo, o simplemente sirven como formas de guardar los pesos en un modelo en un momento / etapa determinada?
  2. Respecto tf.train.write_graph(), ¿también se guardan los pesos / variables?
  3. Con respecto a Bazel, ¿solo puede guardar / cargar archivos .pb para volver a capacitarse? ¿Existe un comando simple de Bazel solo para volcar un gráfico en un .pb?
  4. Con respecto a la congelación, ¿se puede cargar un gráfico congelado en el uso tf.import_graph_def()?
  5. La demostración de Android para TensorFlow se carga en el modelo Inception de Google desde un archivo .pb. Si quisiera sustituir mi propio archivo .pb, ¿cómo lo haría? ¿Necesitaría cambiar algún código / método nativo?
  6. En general, ¿cuál es exactamente la diferencia entre todos estos métodos? O más ampliamente, ¿cuál es la diferencia entre as_graph_def()/.ckpt/.pb?

En resumen, lo que estoy buscando es un método para guardar un gráfico (como en, las diversas operaciones y demás) y sus pesos / variables en un archivo, que luego se puede usar para cargar el gráfico y los pesos en otro programa. , para uso (no necesariamente para continuar / reentrenamiento)

La documentación sobre este tema no es muy sencilla, por lo que cualquier respuesta / información sería muy apreciada.


2
La API más nueva / más completa es el meta gráfico, que le brinda una manera de guardar los tres a la vez: 1) gráfico 2) valores de parámetros 3) colecciones: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Respuestas:


80

Hay muchas formas de abordar el problema de guardar un modelo en TensorFlow, lo que puede hacer que sea un poco confuso. Tomando cada una de sus sub-preguntas por turno:

  1. Los archivos de puntos de control (producidos, por ejemplo, al llamar saver.save()a un tf.train.Saverobjeto) contienen solo los pesos y cualquier otra variable definida en el mismo programa. Para usarlos en otro programa, debes volver a crear la estructura del gráfico asociado (p. Ej., Ejecutando código para compilarlo nuevamente o llamando tf.import_graph_def()), lo que le dice a TensorFlow qué hacer con esos pesos. Tenga en cuenta que la llamada saver.save()también produce un archivo que contiene a MetaGraphDef, que contiene un gráfico y detalles de cómo asociar los pesos de un punto de control con ese gráfico. Consulte el tutorial para obtener más detalles.

  2. tf.train.write_graph()solo escribe la estructura del gráfico; no los pesos.

  3. Bazel no está relacionado con la lectura o escritura de gráficos de TensorFlow. (Quizás no entiendo bien tu pregunta: no dudes en aclararla en un comentario).

  4. Un gráfico congelado se puede cargar usando tf.import_graph_def(). En este caso, los pesos están (normalmente) incrustados en el gráfico, por lo que no necesita cargar un punto de control por separado.

  5. El cambio principal sería actualizar los nombres de los tensores que se introducen en el modelo y los nombres de los tensores que se obtienen del modelo. En la demostración de TensorFlow para Android, esto correspondería a las cadenas inputNamey outputNameque se pasan a TensorFlowClassifier.initializeTensorFlow().

  6. El GraphDefes la estructura del programa, que por lo general no cambia a través del proceso de formación. El punto de control es una instantánea del estado de un proceso de formación, que normalmente cambia en cada paso del proceso de formación. Como resultado, TensorFlow usa diferentes formatos de almacenamiento para estos tipos de datos, y la API de bajo nivel proporciona diferentes formas de guardarlos y cargarlos. Las bibliotecas de nivel superior, como las MetaGraphDefbibliotecas, Keras y skflow, se basan en estos mecanismos para proporcionar formas más convenientes de guardar y restaurar un modelo completo.


¿Significa esto que la documentación de la API de C ++ miente, cuando dice que puede cargar el gráfico guardado con tf.train.write_graph()y luego ejecutarlo?
mnicky

2
La documentación de la API de C ++ no miente, pero faltan algunos detalles. El detalle más importante es que, además del GraphDefguardado por tf.train.write_graph(), también es necesario recordar los nombres de los tensores que desea alimentar y recuperar al ejecutar el gráfico (elemento 5 anterior).
mrry

@mrry: intenté usar el ejemplo de DeepDream de tensorflows. ¡pero parece que necesita modelos previamente entrenados en formato pb! Ejecuté el ejemplo de Cifar10, ¡pero solo crea puntos de control! ¡No pude encontrar ningún archivo pb ni nada! ¿Cómo puedo convertir mis puntos de control al formato pb que usa el ejemplo de deepdream?
Rika

2
@ Coderx7 Realmente creo que no se puede convertir un .ckpt en un .pb ya que el punto de control solo contiene los pesos y las variables y no sabe nada sobre la estructura del gráfico
davidivad

1
¿Existe un código simple para cargar un archivo .pb y luego ejecutarlo?
Kong

1

Puedes probar el siguiente código:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.