¿Cómo funciona el método de "vista" en PyTorch?


206

Estoy confundido sobre el método view()en el siguiente fragmento de código.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

Mi confusión es con respecto a la siguiente línea.

x = x.view(-1, 16*5*5)

¿Qué hace la tensor.view()función? He visto su uso en muchos lugares, pero no puedo entender cómo interpreta sus parámetros.

¿Qué sucede si le doy valores negativos como parámetros a la view()función? Por ejemplo, ¿qué pasa si llamo tensor_variable.view(1, 1, -1),?

¿Alguien puede explicar el principio principal de la view()función con algunos ejemplos?

Respuestas:


284

La función de vista está destinada a remodelar el tensor.

Digamos que tienes un tensor

import torch
a = torch.range(1, 16)

aes un tensor que tiene 16 elementos del 1 al 16 (incluidos). Si desea remodelar este tensor para convertirlo en 4 x 4tensor, puede usar

a = a.view(4, 4)

Ahora aserá un 4 x 4tensor. Tenga en cuenta que después de la remodelación, el número total de elementos debe permanecer igual. Cambiar la forma del tensor aa un 3 x 5tensor no sería apropiado.

¿Cuál es el significado del parámetro -1?

Si hay alguna situación en la que no sabe cuántas filas desea pero está seguro del número de columnas, puede especificar esto con un -1. ( Tenga en cuenta que puede extender esto a tensores con más dimensiones. Solo uno de los valores del eje puede ser -1 ). Esta es una forma de decirle a la biblioteca: "dame un tensor que tenga tantas columnas y calcules el número apropiado de filas que es necesario para que esto suceda".

Esto se puede ver en el código de red neuronal que ha proporcionado anteriormente. Después de la línea x = self.pool(F.relu(self.conv2(x)))en la función de avance, tendrá un mapa de características de 16 profundidades. Tienes que aplanar esto para darle a la capa completamente conectada. Entonces le dice a pytorch que cambie la forma del tensor que obtuvo para tener un número específico de columnas y le dice que decida el número de filas por sí mismo.

Dibujar una similitud entre numpy y pytorch viewes similar a la función de remodelación de numpy .


93
"la vista es similar a la remodelación de numpy" - ¿por qué no lo llamaron reshapeen PyTorch?
MaxB

54
@MaxB A diferencia de la remodelación, el nuevo tensor devuelto por "view" comparte los datos subyacentes con el tensor original, por lo que es realmente una vista del tensor antiguo en lugar de crear uno nuevo.
qihqi

37
@blckbird "la remodelación siempre copia la memoria. la vista nunca copia la memoria". github.com/torch/cutorch/issues/98
devinbost

3
@devinbost La remodelación de la antorcha siempre copia la memoria. NumPy remodelar no lo hace.
Tavian Barnes

32

Hagamos algunos ejemplos, de más simple a más difícil.

  1. El viewmétodo devuelve un tensor con los mismos datos que el selftensor (lo que significa que el tensor devuelto tiene el mismo número de elementos), pero con una forma diferente. Por ejemplo:

    a = torch.arange(1, 17)  # a's shape is (16,)
    
    a.view(4, 4) # output below
      1   2   3   4
      5   6   7   8
      9  10  11  12
     13  14  15  16
    [torch.FloatTensor of size 4x4]
    
    a.view(2, 2, 4) # output below
    (0 ,.,.) = 
    1   2   3   4
    5   6   7   8
    
    (1 ,.,.) = 
     9  10  11  12
    13  14  15  16
    [torch.FloatTensor of size 2x2x4]
  2. Suponiendo que ese -1no es uno de los parámetros, cuando los multiplica, el resultado debe ser igual al número de elementos en el tensor. Si lo hace: a.view(3, 3)generará una RuntimeErrorforma porque (3 x 3) no es válida para la entrada con 16 elementos. En otras palabras: 3 x 3 no es igual a 16 sino a 9.

  3. Puede usar -1uno de los parámetros que pasa a la función, pero solo una vez. Todo lo que sucede es que el método hará los cálculos matemáticos sobre cómo llenar esa dimensión. Por ejemplo a.view(2, -1, 4)es equivalente a a.view(2, 2, 4). [16 / (2 x 4) = 2]

  4. Observe que el tensor devuelto comparte los mismos datos . Si realiza un cambio en la "vista", está cambiando los datos del tensor original:

    b = a.view(4, 4)
    b[0, 2] = 2
    a[2] == 3.0
    False
  5. Ahora, para un caso de uso más complejo. La documentación dice que cada nueva dimensión de vista debe ser un subespacio de una dimensión original, o solo abarcar d, d + 1, ..., d + k que satisfagan la siguiente condición similar a la contigüidad que para todo i = 0,. .., k - 1, zancada [i] = zancada [i + 1] x tamaño [i + 1] . De lo contrario, contiguous()debe llamarse antes de que se pueda ver el tensor. Por ejemplo:

    a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2)
    a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4)
    
    # The commented line below will raise a RuntimeError, because one dimension
    # spans across two contiguous subspaces
    # a_t.view(-1, 4)
    
    # instead do:
    a_t.contiguous().view(-1, 4)
    
    # To see why the first one does not work and the second does,
    # compare a.stride() and a_t.stride()
    a.stride() # (24, 6, 2, 1)
    a_t.stride() # (24, 2, 1, 6)

    Tenga en cuenta que for a_t, stride [0]! = Stride [1] x size [1] since 24! = 2 x 3


7

torch.Tensor.view()

En pocas palabras, torch.Tensor.view()que está inspirado en numpy.ndarray.reshape()o numpy.reshape(), crea una nueva vista del tensor, siempre que la nueva forma sea compatible con la forma del tensor original.

Comprendamos esto en detalle utilizando un ejemplo concreto.

In [43]: t = torch.arange(18) 

In [44]: t 
Out[44]: 
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])

Con este tensor tde la forma (18,), nuevos puntos de vista pueden solamente ser creados por las siguientes formas:

(1, 18)o equivalente (1, -1)o o equivalente o o equivalente o o equivalente o o equivalente o o equivalente o(-1, 18)
(2, 9)(2, -1)(-1, 9)
(3, 6)(3, -1)(-1, 6)
(6, 3)(6, -1)(-1, 3)
(9, 2)(9, -1)(-1, 2)
(18, 1)(18, -1)(-1, 1)

Como ya podemos observar en las tuplas de forma anteriores, la multiplicación de los elementos de la tupla de forma (por ejemplo 2*9, 3*6etc.) siempre debe ser igual al número total de elementos en el tensor original ( 18en nuestro ejemplo).

Otra cosa a observar es que usamos un -1en uno de los lugares en cada una de las tuplas de forma. Al usar a -1, somos perezosos al hacer el cálculo nosotros mismos y delegamos la tarea a PyTorch para hacer el cálculo de ese valor para la forma cuando crea la nueva vista . Una cosa importante a tener en cuenta es que solo podemos usar una sola -1en la tupla de forma. Los valores restantes deben ser suministrados explícitamente por nosotros. Else PyTorch se quejará lanzando un RuntimeError:

RuntimeError: solo se puede inferir una dimensión

Entonces, con todas las formas mencionadas anteriormente, PyTorch siempre devolverá una nueva vista del tensor original t. Esto básicamente significa que solo cambia la información de paso del tensor para cada una de las nuevas vistas que se solicitan.

A continuación se muestran algunos ejemplos que ilustran cómo se cambian los pasos de los tensores con cada nueva vista .

# stride of our original tensor `t`
In [53]: t.stride() 
Out[53]: (1,)

Ahora, veremos los avances de las nuevas vistas :

# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride() 
Out[55]: (18, 1)

# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()       
Out[57]: (9, 1)

# shape (3, 6)
In [59]: t3 = t.view(3, -1) 
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride() 
Out[60]: (6, 1)

# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride() 
Out[63]: (3, 1)

# shape (9, 2)
In [65]: t5 = t.view(9, -1) 
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)

# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)

Esa es la magia de la view()función. Simplemente cambia los pasos del tensor (original) para cada una de las nuevas vistas , siempre que la forma de la nueva vista sea ​​compatible con la forma original.

Otra cosa interesante uno podría observar desde las tuplas zancadas es que el valor del elemento en el 0 º posición es igual al valor del elemento en el 1 st posición de la tupla forma.

In [74]: t3.shape 
Out[74]: torch.Size([3, 6])
                        |
In [75]: t3.stride()    |
Out[75]: (6, 1)         |
          |_____________|

Esto es porque:

In [76]: t3 
Out[76]: 
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [12, 13, 14, 15, 16, 17]])

la zancada (6, 1)dice que para ir de un elemento al siguiente elemento a lo largo de la 0 ª dimensión, tenemos que saltar o tomar 6 pasos. (es decir, para ir de 0a 6, uno tiene que tomar 6 pasos). Pero para ir de un elemento al siguiente elemento en la dimensión, solo necesitamos un paso (por ejemplo, ir de 2a 3).

Por lo tanto, la información de pasos está en el corazón de cómo se accede a los elementos desde la memoria para realizar el cálculo.


torch.reshape ()

Esta función devolvería una vista y es exactamente lo mismo que usar torch.Tensor.view()siempre que la nueva forma sea compatible con la forma del tensor original. De lo contrario, devolverá una copia.

Sin embargo, las notas de torch.reshape()advierte que:

Las entradas contiguas y las entradas con zancadas compatibles se pueden reformar sin copiar, pero no se debe depender del comportamiento de copia vs. visualización.


1

Me di cuenta de que x.view(-1, 16 * 5 * 5)es equivalente a x.flatten(1), donde el parámetro 1 indica que el proceso de aplanar comienza desde la primera dimensión (sin aplanar la dimensión de 'muestra') Como puede ver, el último uso es semánticamente más claro y más fácil de usar, por lo que prefieren flatten().


1

¿Cuál es el significado del parámetro -1?

Puede leer -1como número dinámico de parámetros o "cualquier cosa". Por eso solo puede haber un parámetro -1en view().

Si pregunta x.view(-1,1)esto, generará una forma de tensor [anything, 1]dependiendo del número de elementos en x. Por ejemplo:

import torch
x = torch.tensor([1, 2, 3, 4])
print(x,x.shape)
print("...")
print(x.view(-1,1), x.view(-1,1).shape)
print(x.view(1,-1), x.view(1,-1).shape)

Saldrá:

tensor([1, 2, 3, 4]) torch.Size([4])
...
tensor([[1],
        [2],
        [3],
        [4]]) torch.Size([4, 1])
tensor([[1, 2, 3, 4]]) torch.Size([1, 4])

1

weights.reshape(a, b) devolverá un nuevo tensor con los mismos datos que los pesos con tamaño (a, b) ya que copia los datos en otra parte de la memoria.

weights.resize_(a, b)devuelve el mismo tensor con una forma diferente. Sin embargo, si la nueva forma da como resultado menos elementos que el tensor original, algunos elementos se eliminarán del tensor (pero no de la memoria). Si la nueva forma da como resultado más elementos que el tensor original, los nuevos elementos no se inicializarán en la memoria.

weights.view(a, b) devolverá un nuevo tensor con los mismos datos que los pesos con tamaño (a, b)


0

Realmente me gustaron los ejemplos de @Jadiel de Armas.

Me gustaría agregar una pequeña idea de cómo se ordenan los elementos para .view (...)

  • Para un tensor con forma (a, b, c) , el orden de la misma de los elementos se determina por un sistema de numeración: donde el primer dígito tiene un número, segundo dígito tiene b números y tercer dígito tiene c números.
  • La asignación de los elementos en el nuevo Tensor devuelto por .view (...) conserva este orden del Tensor original.

0

Tratemos de entender la vista con los siguientes ejemplos:

    a=torch.range(1,16)

print(a)

    tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
            15., 16.])

print(a.view(-1,2))

    tensor([[ 1.,  2.],
            [ 3.,  4.],
            [ 5.,  6.],
            [ 7.,  8.],
            [ 9., 10.],
            [11., 12.],
            [13., 14.],
            [15., 16.]])

print(a.view(2,-1,4))   #3d tensor

    tensor([[[ 1.,  2.,  3.,  4.],
             [ 5.,  6.,  7.,  8.]],

            [[ 9., 10., 11., 12.],
             [13., 14., 15., 16.]]])
print(a.view(2,-1,2))

    tensor([[[ 1.,  2.],
             [ 3.,  4.],
             [ 5.,  6.],
             [ 7.,  8.]],

            [[ 9., 10.],
             [11., 12.],
             [13., 14.],
             [15., 16.]]])

print(a.view(4,-1,2))

    tensor([[[ 1.,  2.],
             [ 3.,  4.]],

            [[ 5.,  6.],
             [ 7.,  8.]],

            [[ 9., 10.],
             [11., 12.]],

            [[13., 14.],
             [15., 16.]]])

-1 como valor de argumento es una manera fácil de calcular el valor de decir x siempre que conozcamos los valores de y, z o al revés en el caso de 3d y para 2d nuevamente, una manera fácil de calcular el valor de decir x siempre que saber valores de y o viceversa.

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.