El pepinillo biblioteca Python implementa protocolos binarios para serializar y deserializar un objeto Python.
Cuando usted import torch(o cuando usa PyTorch) lo hará import picklepor usted y no necesita llamar pickle.dump()y pickle.load()directamente, cuáles son los métodos para guardar y cargar el objeto.
De hecho, torch.save()y torch.load()lo envolverá pickle.dump()y pickle.load()para ti.
La state_dictotra respuesta mencionada merece unas pocas notas más.
¿ state_dictQué tenemos dentro de PyTorch? En realidad hay dos state_dicts.
El modelo PyTorch torch.nn.Moduletiene model.parameters()llamada para obtener parámetros que se pueden aprender (w y b). Estos parámetros que se pueden aprender, una vez establecidos al azar, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primeros state_dict.
El segundo state_dictes el optimizador de estado dict. Recuerda que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizadorstate_dict es fijo. Nada que aprender allí.
Debido a que los state_dictobjetos son diccionarios de Python, se pueden guardar, actualizar, alterar y restaurar fácilmente, agregando una gran modularidad a los modelos y optimizadores de PyTorch.
Creemos un modelo súper simple para explicar esto:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Este código generará lo siguiente:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencial
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Tenga en cuenta que solo las capas con parámetros que se pueden aprender (capas convolucionales, capas lineales, etc.) y memorias intermedias registradas (capas de batchnorm) tienen entradas en el modelo state_dict.
Cosas que no se pueden aprender, pertenecen al objeto optimizador state_dict , que contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.
El resto de la historia es igual; en la fase de inferencia (esta es una fase cuando usamos el modelo después del entrenamiento) para predecir; predecimos en función de los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros model.state_dict().
torch.save(model.state_dict(), filepath)
Y para usar luego model.load_state_dict (torch.load (filepath)) model.eval ()
Nota: No olvide la última línea, model.eval()esto es crucial después de cargar el modelo.
Tampoco intentes guardar torch.save(model.parameters(), filepath). El model.parameters()es solo el objeto generador.
Por otro lado, torch.save(model, filepath)guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizador state_dict. Verifique la otra excelente respuesta de @Jadiel de Armas para guardar la sentencia de estado del optimizador.