Diferencia entre Variable y get_variable en TensorFlow


125

Hasta donde yo sé, Variablees la operación predeterminada para hacer una variable, y get_variablese usa principalmente para compartir peso.

Por un lado, hay algunas personas que sugieren usar en get_variablelugar de la Variableoperación primitiva siempre que necesite una variable. Por otro lado, simplemente veo algún uso get_variableen los documentos y demostraciones oficiales de TensorFlow.

Por lo tanto, quiero conocer algunas reglas generales sobre cómo usar correctamente estos dos mecanismos. ¿Hay algún principio "estándar"?


66
get_variable es una forma nueva, la variable es antigua (lo que podría ser compatible para siempre), como dice Lukasz (PD: escribió gran parte del alcance del nombre de la variable en TF)
Yaroslav Bulatov

Respuestas:


90

Recomendaría usar siempre tf.get_variable(...): facilitará la refactorización de su código si necesita compartir variables en cualquier momento, por ejemplo, en una configuración multi-gpu (vea el ejemplo CIFAR multi-gpu). No hay inconveniente en ello.

Puro tf.Variablees de nivel inferior; en algún momento tf.get_variable()no existía, por lo que algunos códigos todavía usan la forma de bajo nivel.


55
Muchas gracias por tu respuesta. Pero todavía tengo una pregunta sobre cómo reemplazarla tf.Variableen tf.get_variabletodas partes. Es entonces cuando quiero inicializar una variable con una matriz numpy, no puedo encontrar una manera limpia y eficiente de hacerlo como lo hago con tf.Variable. ¿Cómo lo resuelves? Gracias.
Lifu Huang

68

tf.Variable es una clase, y hay varias formas de crear tf.Variable que incluyen tf.Variable.__init__y tf.get_variable.

tf.Variable.__init__: Crea una nueva variable con initial_value .

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: Obtiene una variable existente con estos parámetros o crea una nueva. También puedes usar initializer.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

Es muy útil usar inicializadores como xavier_initializer :

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

Más información aquí .


Sí, en Variablerealidad me refiero a usar su __init__. Como get_variablees tan conveniente, me pregunto por qué la mayoría del código TensorFlow que vi usar en Variablelugar de get_variable. ¿Hay algunas convenciones o factores a considerar al elegir entre ellos? ¡Gracias!
Lifu Huang

Si desea tener un cierto valor, usar Variable es simple: x = tf.Variable (3).
Sung Kim

@SungKim normalmente cuando usamos tf.Variable()podemos inicializarlo como un valor aleatorio de una distribución normal truncada. Aquí está mi ejemplo w1 = tf.Variable(tf.truncated_normal([5, 50], stddev = 0.01), name = 'w1'). ¿Cuál sería el equivalente de esto? ¿Cómo le digo que quiero una normal truncada? ¿Debo hacer w1 = tf.get_variable(name = 'w1', shape = [5,50], initializer = tf.truncated_normal, regularizer = tf.nn.l2_loss)?
Euler_Salter

@Euler_Salter: puede usar tf.truncated_normal_initializer()para obtener el resultado deseado.
Beta

46

Puedo encontrar dos diferencias principales entre una y otra:

  1. Primero es que tf.Variablesiempre creará una nueva variable, mientras que tf.get_variableobtiene una variable existente con parámetros específicos del gráfico, y si no existe, crea una nueva.

  2. tf.Variable requiere que se especifique un valor inicial.

Es importante aclarar que la función tf.get_variableantepone el nombre con el alcance de la variable actual para realizar verificaciones de reutilización. Por ejemplo:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

El último error de aserción es interesante: se supone que dos variables con el mismo nombre bajo el mismo alcance son la misma variable. Pero si prueba los nombres de las variables dy ese dará cuenta de que Tensorflow cambió el nombre de la variable e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"

Gran ejemplo! Con respecto d.namey e.name, acabo de venir a través de un documento en este TensorFlow tensor de operación de gráfico de denominación que lo explica:If the default graph already contained an operation named "answer", the TensorFlow would append "_1", "_2", and so on to the name, in order to make it unique.
Atlas7

2

Otra diferencia radica en que uno está en la ('variable_store',)colección pero el otro no.

Por favor vea el código fuente :

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Déjame ilustrarte eso:

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

La salida:

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.