Pytorch中参数和模型的保存与读取
生活随笔
收集整理的這篇文章主要介紹了
Pytorch中参数和模型的保存与读取
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
Tensor變量的存取(包括parameter)
對于普通Tensor變量的存取,如下代碼所示:
import torch import torch.nn as nn x = torch.ones(3) torch.save(x,'x.pt') x2 = torch.load('x.pt') print(x2)讀寫模型參數
保存模型參數
torch.save(net.state_dict(),'model_param.pth')載入模型參數
mynet = MLP() mynet.load_state_dict(torch.load('model_param.pth')) mynet.state_dict()保存和讀取整個模型
模型的保存
print(net(x)) torch.save(net,'model.pth')模型載入
mynets = torch.load('model.pth') mynets(x)總結
以上是生活随笔為你收集整理的Pytorch中参数和模型的保存与读取的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: docker 其他电脑访问权限_dock
- 下一篇: 二十万字C/C++、嵌入式软开面试题全集