El pepinillo biblioteca Python implementa protocolos binarios para serializar y deserializar un objeto Python.
Cuando usted import torch
(o cuando usa PyTorch) lo hará import pickle
por usted y no necesita llamar pickle.dump()
y pickle.load()
directamente, cuáles son los métodos para guardar y cargar el objeto.
De hecho, torch.save()
y torch.load()
lo envolverá pickle.dump()
y pickle.load()
para ti.
La state_dict
otra respuesta mencionada merece unas pocas notas más.
¿ state_dict
Qué tenemos dentro de PyTorch? En realidad hay dos state_dict
s.
El modelo PyTorch torch.nn.Module
tiene model.parameters()
llamada para obtener parámetros que se pueden aprender (w y b). Estos parámetros que se pueden aprender, una vez establecidos al azar, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primeros state_dict
.
El segundo state_dict
es el optimizador de estado dict. Recuerda que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizadorstate_dict
es fijo. Nada que aprender allí.
Debido a que los state_dict
objetos son diccionarios de Python, se pueden guardar, actualizar, alterar y restaurar fácilmente, agregando una gran modularidad a los modelos y optimizadores de PyTorch.
Creemos un modelo súper simple para explicar esto:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Este código generará lo siguiente:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencial
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Tenga en cuenta que solo las capas con parámetros que se pueden aprender (capas convolucionales, capas lineales, etc.) y memorias intermedias registradas (capas de batchnorm) tienen entradas en el modelo state_dict
.
Cosas que no se pueden aprender, pertenecen al objeto optimizador state_dict
, que contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.
El resto de la historia es igual; en la fase de inferencia (esta es una fase cuando usamos el modelo después del entrenamiento) para predecir; predecimos en función de los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros model.state_dict()
.
torch.save(model.state_dict(), filepath)
Y para usar luego model.load_state_dict (torch.load (filepath)) model.eval ()
Nota: No olvide la última línea, model.eval()
esto es crucial después de cargar el modelo.
Tampoco intentes guardar torch.save(model.parameters(), filepath)
. El model.parameters()
es solo el objeto generador.
Por otro lado, torch.save(model, filepath)
guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizador state_dict
. Verifique la otra excelente respuesta de @Jadiel de Armas para guardar la sentencia de estado del optimizador.