En pocas palabras, torch.Tensor.view()
que está inspirado en numpy.ndarray.reshape()
o numpy.reshape()
, crea una nueva vista del tensor, siempre que la nueva forma sea compatible con la forma del tensor original.
Comprendamos esto en detalle utilizando un ejemplo concreto.
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
Con este tensor t
de la forma (18,)
, nuevos puntos de vista pueden solamente ser creados por las siguientes formas:
(1, 18)
o equivalente (1, -1)
o o equivalente o o equivalente o o equivalente o o equivalente o o equivalente o(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
Como ya podemos observar en las tuplas de forma anteriores, la multiplicación de los elementos de la tupla de forma (por ejemplo 2*9
, 3*6
etc.) siempre debe ser igual al número total de elementos en el tensor original ( 18
en nuestro ejemplo).
Otra cosa a observar es que usamos un -1
en uno de los lugares en cada una de las tuplas de forma. Al usar a -1
, somos perezosos al hacer el cálculo nosotros mismos y delegamos la tarea a PyTorch para hacer el cálculo de ese valor para la forma cuando crea la nueva vista . Una cosa importante a tener en cuenta es que solo podemos usar una sola -1
en la tupla de forma. Los valores restantes deben ser suministrados explícitamente por nosotros. Else PyTorch se quejará lanzando un RuntimeError
:
RuntimeError: solo se puede inferir una dimensión
Entonces, con todas las formas mencionadas anteriormente, PyTorch siempre devolverá una nueva vista del tensor original t
. Esto básicamente significa que solo cambia la información de paso del tensor para cada una de las nuevas vistas que se solicitan.
A continuación se muestran algunos ejemplos que ilustran cómo se cambian los pasos de los tensores con cada nueva vista .
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
Ahora, veremos los avances de las nuevas vistas :
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
Esa es la magia de la view()
función. Simplemente cambia los pasos del tensor (original) para cada una de las nuevas vistas , siempre que la forma de la nueva vista sea compatible con la forma original.
Otra cosa interesante uno podría observar desde las tuplas zancadas es que el valor del elemento en el 0 º posición es igual al valor del elemento en el 1 st posición de la tupla forma.
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
Esto es porque:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
la zancada (6, 1)
dice que para ir de un elemento al siguiente elemento a lo largo de la 0 ª dimensión, tenemos que saltar o tomar 6 pasos. (es decir, para ir de 0
a 6
, uno tiene que tomar 6 pasos). Pero para ir de un elemento al siguiente elemento en la 1ª dimensión, solo necesitamos un paso (por ejemplo, ir de 2
a 3
).
Por lo tanto, la información de pasos está en el corazón de cómo se accede a los elementos desde la memoria para realizar el cálculo.
Esta función devolvería una vista y es exactamente lo mismo que usar torch.Tensor.view()
siempre que la nueva forma sea compatible con la forma del tensor original. De lo contrario, devolverá una copia.
Sin embargo, las notas de torch.reshape()
advierte que:
Las entradas contiguas y las entradas con zancadas compatibles se pueden reformar sin copiar, pero no se debe depender del comportamiento de copia vs. visualización.
reshape
en PyTorch?