PyTorch: Saving and Loading Trained Networks (Part 80)
Saving and loading models
In this notebook, I'll show you how to save and load models with PyTorch. This is important because you'll often want to load a previously trained model to make predictions, or to continue training on new data.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms
import helper
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Download and load the training data
trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the test data
testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
Here we can see one of the images.
image, label = next(iter(trainloader))
helper.imshow(image[0,:]);
Building the network
Here I'll use the same model as in Part 5.
class Network(nn.Module):
    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        ''' Builds a feedforward network with arbitrary hidden layers.

            Arguments
            ---------
            input_size: integer, size of the input layer
            output_size: integer, size of the output layer
            hidden_layers: list of integers, the sizes of the hidden layers
            drop_p: float between 0 and 1, dropout probability
        '''
        super().__init__()
        # Input to a hidden layer
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])

        # Add a variable number of more hidden layers
        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])
        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])

        self.output = nn.Linear(hidden_layers[-1], output_size)
        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, x):
        ''' Forward pass through the network, returns the output logits '''
        for each in self.hidden_layers:
            x = F.relu(each(x))
            x = self.dropout(x)
        x = self.output(x)
        return F.log_softmax(x, dim=1)
Training the network
And we train the network using the same procedure as before.
# Create the network, define the criterion and optimizer
model = Network(784, 10, [500, 100])
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 2
steps = 0
running_loss = 0
print_every = 100
for e in range(epochs):
    for images, labels in iter(trainloader):
        steps += 1
        # Flatten images into a 784 long vector
        images.resize_(images.size()[0], 784)

        # Wrap images and labels in Variables so we can calculate gradients
        inputs = Variable(images)
        targets = Variable(labels)
        optimizer.zero_grad()

        output = model.forward(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.data[0]

        if steps % print_every == 0:
            # Model in inference mode, dropout is off
            model.eval()

            accuracy = 0
            test_loss = 0
            for ii, (images, labels) in enumerate(testloader):
                images = images.resize_(images.size()[0], 784)
                # Set volatile to True so we don't save the history
                inputs = Variable(images, volatile=True)
                labels = Variable(labels, volatile=True)

                output = model.forward(inputs)
                test_loss += criterion(output, labels).data[0]

                ## Calculating the accuracy
                # Model's output is log-softmax, take exponential to get the probabilities
                ps = torch.exp(output).data
                # Class with highest probability is our predicted class, compare with true label
                equality = (labels.data == ps.max(1)[1])
                # Accuracy is number of correct predictions divided by all predictions, just take the mean
                accuracy += equality.type_as(torch.FloatTensor()).mean()

            print("Epoch: {}/{}.. ".format(e+1, epochs),
                  "Training Loss: {:.3f}.. ".format(running_loss/print_every),
                  "Test Loss: {:.3f}.. ".format(test_loss/len(testloader)),
                  "Test Accuracy: {:.3f}".format(accuracy/len(testloader)))

            running_loss = 0

            # Make sure dropout is on for training
            model.train()
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:21: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:33: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:34: UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py:37: UserWarning: invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
Epoch: 1/2.. Training Loss: 1.114.. Test Loss: 0.655.. Test Accuracy: 0.752
Epoch: 1/2.. Training Loss: 0.749.. Test Loss: 0.594.. Test Accuracy: 0.781
Epoch: 1/2.. Training Loss: 0.654.. Test Loss: 0.567.. Test Accuracy: 0.784
Epoch: 1/2.. Training Loss: 0.621.. Test Loss: 0.498.. Test Accuracy: 0.811
Epoch: 1/2.. Training Loss: 0.600.. Test Loss: 0.518.. Test Accuracy: 0.807
Epoch: 1/2.. Training Loss: 0.551.. Test Loss: 0.494.. Test Accuracy: 0.816
Epoch: 1/2.. Training Loss: 0.565.. Test Loss: 0.476.. Test Accuracy: 0.824
Epoch: 1/2.. Training Loss: 0.561.. Test Loss: 0.479.. Test Accuracy: 0.821
Epoch: 1/2.. Training Loss: 0.522.. Test Loss: 0.476.. Test Accuracy: 0.827
Epoch: 2/2.. Training Loss: 0.539.. Test Loss: 0.461.. Test Accuracy: 0.831
Epoch: 2/2.. Training Loss: 0.523.. Test Loss: 0.450.. Test Accuracy: 0.832
Epoch: 2/2.. Training Loss: 0.511.. Test Loss: 0.454.. Test Accuracy: 0.833
Epoch: 2/2.. Training Loss: 0.511.. Test Loss: 0.451.. Test Accuracy: 0.831
Epoch: 2/2.. Training Loss: 0.508.. Test Loss: 0.447.. Test Accuracy: 0.834
Epoch: 2/2.. Training Loss: 0.492.. Test Loss: 0.448.. Test Accuracy: 0.838
Epoch: 2/2.. Training Loss: 0.486.. Test Loss: 0.440.. Test Accuracy: 0.833
Epoch: 2/2.. Training Loss: 0.505.. Test Loss: 0.427.. Test Accuracy: 0.845
Epoch: 2/2.. Training Loss: 0.488.. Test Loss: 0.441.. Test Accuracy: 0.837
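The warnings above come from APIs that were deprecated after this notebook was written. On PyTorch 0.4 and later, the same validation pass can be written without Variable, volatile, or .data[0]; here is a minimal sketch, assuming the model, criterion, and testloader defined above:

# Minimal sketch of the validation pass for PyTorch >= 0.4
# (no Variable/volatile needed; .item() replaces .data[0])
model.eval()                       # inference mode, dropout off
test_loss = 0
accuracy = 0
with torch.no_grad():              # don't track history during evaluation
    for images, labels in testloader:
        images = images.view(images.shape[0], 784)    # flatten to (batch, 784)
        output = model(images)
        test_loss += criterion(output, labels).item()
        ps = torch.exp(output)                         # log-softmax -> probabilities
        equality = (labels == ps.max(dim=1)[1])        # compare predictions with labels
        accuracy += equality.float().mean().item()
print("Test Loss: {:.3f}.. ".format(test_loss/len(testloader)),
      "Test Accuracy: {:.3f}".format(accuracy/len(testloader)))
model.train()                      # dropout back on for training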
Saving and loading networks
As you can imagine, it's impractical to retrain a network every time you want to use it. Instead, we can save trained networks and load them later, either to continue training or to make predictions.
The parameters of a PyTorch network are stored in the model's state_dict. We can see that this state dict contains the weight and bias matrices for each layer.
print("Our model: \n\n", model, '\n')
print("The state dict keys: \n\n", model.state_dict().keys())
Our model:
Network(
(hidden_layers): ModuleList(
(0): Linear(in_features=784, out_features=500, bias=True)
(1): Linear(in_features=500, out_features=100, bias=True)
)
(output): Linear(in_features=100, out_features=10, bias=True)
(dropout): Dropout(p=0.5)
)
The state dict keys:
odict_keys(['hidden_layers.0.weight', 'hidden_layers.0.bias', 'hidden_layers.1.weight', 'hidden_layers.1.bias', 'output.weight', 'output.bias'])
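Each value in the state dict is an ordinary tensor, so you can inspect individual layers directly. For example, to check the shape of the first hidden layer's weight matrix:

# The state dict maps parameter names to tensors
print(model.state_dict()['hidden_layers.0.weight'].size())   # torch.Size([500, 784])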
The simplest way is to save the state dict with torch.save. For example, we can save it to the file 'checkpoint.pth'.
torch.save(model.state_dict(), 'checkpoint.pth')
Then we can load the state dict with torch.load.
state_dict = torch.load('checkpoint.pth')
print(state_dict.keys())
odict_keys(['hidden_layers.0.weight', 'hidden_layers.0.bias', 'hidden_layers.1.weight', 'hidden_layers.1.bias', 'output.weight', 'output.bias'])
To load the state dict into the network, you use model.load_state_dict(state_dict).
model.load_state_dict(state_dict)
This seems pretty straightforward, but in practice it's a bit more involved. Loading the state dict only works if the model architecture is exactly the same as the checkpoint's architecture. If I create a model with a different architecture, the load fails.
# Try this
net = Network(784, 10, [400, 200, 100])
# This will throw an error because the tensor sizes are wrong!
net.load_state_dict(state_dict)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
in ()
2 net = Network(784, 10, [400, 200, 100])
3 # This will throw an error because the tensor sizes are wrong!
----> 4 net.load_state_dict(state_dict)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
719 if len(error_msgs) > 0:
720 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 721 self.__class__.__name__, "\n\t".join(error_msgs)))
722
723 def parameters(self):
RuntimeError: Error(s) in loading state_dict for Network:
Missing key(s) in state_dict: "hidden_layers.2.weight", "hidden_layers.2.bias".
While copying the parameter named "hidden_layers.0.weight", whose dimensions in the model are torch.Size([400, 784]) and whose dimensions in the checkpoint are torch.Size([500, 784]).
While copying the parameter named "hidden_layers.0.bias", whose dimensions in the model are torch.Size([400]) and whose dimensions in the checkpoint are torch.Size([500]).
While copying the parameter named "hidden_layers.1.weight", whose dimensions in the model are torch.Size([200, 400]) and whose dimensions in the checkpoint are torch.Size([100, 500]).
While copying the parameter named "hidden_layers.1.bias", whose dimensions in the model are torch.Size([200]) and whose dimensions in the checkpoint are torch.Size([100]).
This means we need to rebuild the model exactly as it was when it was trained, so information about the model architecture needs to be saved in the checkpoint along with the state dict. To do this, you build a dictionary with all the information you need to completely rebuild the model.
checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict()}

torch.save(checkpoint, 'checkpoint.pth')
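If you also plan to resume training later, it's common to store the optimizer state and epoch count in the same dictionary. A minimal sketch; the extra keys and the file name below are just one possible convention, not part of the original notebook:

# Optional: also save the optimizer state so training can resume where it stopped
checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict(),
              'optimizer_state': optimizer.state_dict(),   # hypothetical extra key
              'epoch': epochs}                             # hypothetical extra key
torch.save(checkpoint, 'checkpoint_with_optimizer.pth')

A matching load function would rebuild the optimizer and then call optimizer.load_state_dict(checkpoint['optimizer_state']).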
The checkpoint now contains all the information needed to rebuild the trained model. You can easily turn the saving step into a function if you want. Similarly, we can write a function to load checkpoints.
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = Network(checkpoint['input_size'],
                    checkpoint['output_size'],
                    checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])
    return model
model = load_checkpoint('checkpoint.pth')
print(model)
Network(
(hidden_layers): ModuleList(
(0): Linear(in_features=784, out_features=500, bias=True)
(1): Linear(in_features=500, out_features=100, bias=True)
)
(output): Linear(in_features=100, out_features=10, bias=True)
(dropout): Dropout(p=0.5)
)
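Once the model is rebuilt from the checkpoint, it can be used right away for predictions or for further training. A minimal sketch, assuming PyTorch 0.4+ and the testloader defined earlier:

# Predict classes for one batch with the loaded model
model.eval()
images, labels = next(iter(testloader))
images = images.view(images.shape[0], 784)       # flatten to (batch, 784)
with torch.no_grad():
    ps = torch.exp(model(images))                # probabilities from log-softmax output
print("Predicted classes:", ps.max(dim=1)[1][:10])

# Or attach a fresh optimizer to the loaded parameters and keep training
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()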
"""
Network(
(hidden_layers): ModuleList(
(0): Linear(in_features=784, out_features=500)
(1): Linear(in_features=500, out_features=100)
)
(output): Linear(in_features=100, out_features=10)
)
"""
Doers often succeed; walkers often arrive.