Estoy tratando de actualizar / cambiar los parámetros de un modelo de red neuronal y luego hacer que el paso directo de la red neuronal actualizada esté en el gráfico de cálculo (no importa cuántos cambios / actualizaciones hagamos).
Intenté esta idea, pero cada vez que lo hago, pytorch configura mis tensores actualizados (dentro del modelo) para que sean hojas, lo que mata el flujo de gradientes a las redes que quiero recibir. Mata el flujo de gradientes porque los nodos de hoja no son parte del gráfico de cálculo de la forma en que quiero que sean (ya que no son realmente hojas).
He intentado varias cosas pero nada parece funcionar. Creé un código ficticio autónomo que imprime los gradientes de las redes que deseo tener gradientes:
import torch
import torch.nn as nn
import copy
from collections import OrderedDict
# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2
criterion = nn.CrossEntropyLoss()
#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))
hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
nb_updates = 2
for i in range(nb_updates):
print(f'i = {i}')
new_params = copy.deepcopy( loss_net.state_dict() )
## w^<t> := f(w^<t-1>,delta^<t-1>)
for (name, w) in loss_net.named_parameters():
print(f'name = {name}')
print(w.size())
hidden = updater_net(hidden).view(1)
print(hidden.size())
#delta = ((hidden**2)*w/2)
delta = w + hidden
wt = w + delta
print(wt.size())
new_params[name] = wt
#del loss_net.fc0.weight
#setattr(loss_net.fc0, 'weight', nn.Parameter( wt ))
#setattr(loss_net.fc0, 'weight', wt)
#loss_net.fc0.weight = wt
#loss_net.fc0.weight = nn.Parameter( wt )
##
loss_net.load_state_dict(new_params)
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
si alguien sabe cómo hacer esto, por favor deme un ping ... configuré el número de veces para actualizar en 2 porque la operación de actualización debería estar en el gráfico de cálculo un número arbitrario de veces ... por lo que DEBE funcionar para 2)
Publicación fuertemente relacionada:
- SO: ¿Cómo se pueden tener parámetros en un modelo de pytorch que no sean hojas y estar en el gráfico de cálculo?
- foro de pytorch: https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076
Publicación cruzada:
backward
? A saberretain_graph=True
y / ocreate_graph=True
?