¿Para qué sirve torch.no_grad en pytorch?


21

Soy nuevo en pytorch y comencé con este código github. No entiendo el comentario en la línea 60-61 en el código "because weights have requires_grad=True, but we don't need to track this in autograd". Comprendí que mencionamos requires_grad=Truelas variables que necesitamos para calcular los gradientes para usar el autogrado, pero ¿qué significa ser "tracked by autograd"?

Respuestas:


24

El contenedor "con torch.no_grad ()" establece temporalmente todos los distintivos require_grad en falso. Un ejemplo del tutorial oficial de PyTorch ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

Fuera:

True
True
False

Le recomiendo que lea todos los tutoriales del sitio web anterior.

En su ejemplo: supongo que el autor no quiere que PyTorch calcule los gradientes de las nuevas variables definidas w1 y w2 ya que solo desea actualizar sus valores.


6
with torch.no_grad()

hará que todas las operaciones en el bloque no tengan gradientes.

En pytorch, no se puede cambiar la ubicación de w1 y w2, que son dos variables con require_grad = True. Creo que evitar el cambio de ubicación de w1 y w2 se debe a que provocará un error en el cálculo de la propagación inversa. Dado que el cambio de ubicación cambiará totalmente w1 y w2.

Sin embargo, si usa esto no_grad(), puede controlar que el nuevo w1 y el nuevo w2 no tengan gradientes ya que son generados por operaciones, lo que significa que solo cambia el valor de w1 y w2, no parte del gradiente, todavía tienen información de gradiente variable definida previamente y la propagación hacia atrás puede continuar.

Al usar nuestro sitio, usted reconoce que ha leído y comprende nuestra Política de Cookies y Política de Privacidad.
Licensed under cc by-sa 3.0 with attribution required.