(pytorch-深度学习系列)读取和存储数据-学习笔记
生活随笔
收集整理的這篇文章主要介紹了
(pytorch-深度学习系列)读取和存储数据-学习笔记
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
讀取和存儲數據
我們可以使用pt文件存儲Tensor數據:
import torch from torch import nnx = torch.ones(3) torch.save(x, 'x.pt')這樣我們就將數據存儲在名為x.pt的文件中了
我們可以從文件中將該數據讀入內存:
還可以存儲Tensor列表到文件中,并讀取:
y = torch.zeros(4) torch.save([x, y], "xy.pt") xy_list = torch.load("xy.pt") print(xy_list)不僅如此,還可以存儲一個鍵值為Tensor變量的字典:
torch.save({'x':x, 'y':y}, "xy_dict") xy_dict = torch.load("xy_dict") print(xy_dict)對模型參數進行讀寫:
對于Module類的對象,我們可以使用model.parameters()函數來訪問模型的參數。而state_dict函數將會返回一個模型的參數名稱到參數Tensor對象的一個字典對象。
class my_module(mm.Module):def __init__(self):super(my_module, self)self.hidden = nn.Linear(3, 2)self.action = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):middle = self.action(self.hidden(x))return self.output(middle) net = my_module() net.state_dict()輸出:
OrderedDict([('hidden.weight', tensor([[ 0.2448, 0.1856, -0.5678],[ 0.2030, -0.2073, -0.0104]])),('hidden.bias', tensor([-0.3117, -0.4232])),('output.weight', tensor([[-0.4556, 0.4084]])),('output.bias', tensor([-0.3573]))])但是,只有具有可變參數(可學習參數)的網絡層才會在state_dict中,
同樣的,優化器(optim)也有一個state_dict,這個函數返回一個字典,該字典包含優化器的狀態以及其超參數信息:
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) optimizer.state_dict()輸出:
{'param_groups': [{'dampening': 0,'lr': 0.001,'momentum': 0.9,'nesterov': False,'params': [4736167728, 4736166648, 4736167368, 4736165352],'weight_decay': 0}],'state': {}}那么就可以通過保存模型的state_dict來保存模型:
torch.save(net.state_dict(), PATH)model = my_module(*args, **kwargs) model.load_state_dict(torch.load(PATH))還可以直接保存整個模型:
torch.save(model, PATH) model = torch.load(PATH)總結
以上是生活随笔為你收集整理的(pytorch-深度学习系列)读取和存储数据-学习笔记的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 她是法国数学界的“花木兰”,高斯的“救命
- 下一篇: 比尔•盖茨当选中国工程院外籍院士!(附名