Pytorch迁移学习加载部分预训练权重
生活随笔
收集整理的這篇文章主要介紹了
Pytorch迁移学习加载部分预训练权重
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
遷移學習在圖像分類領域非常常見,利用在超大數據集上訓練得到的網絡權重,遷移到自己的數據上進行訓練可以節約大量的訓練時間,降低欠擬合/過擬合的風險。
如果用原生網絡進行遷移學習非常簡單,其核心是
model.load_state_dict()以Pytorch中官方提供的Resnet加載預訓練權重的代碼為例:
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)if pretrained:model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))但很多時候,我們可能需要對原生網絡做一些修改,比如自定義地增加一些網絡層,改變某些網絡層的結構等等,這時候如果直接像上面那樣直接加載就會報錯。
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
因為網絡結構發生了變化,預訓練權重是以字典的形式存儲的,它會和當前網絡結構的字典對應不上。
因此,我們需要通過加載部分預訓練權重的方式來進行初始化。
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)if pretrained:pretrained_dict = model_zoo.load_url(model_urls['resnet34'])model_dict = model.state_dict()# 篩除不加載的層結構pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}# 更新當前網絡的結構字典model_dict.update(pretrained_dict)model.load_state_dict(model_dict)通過以上簡單的代碼即可實現。
?
參考:https://blog.csdn.net/weixin_37978645/article/details/80788955
總結
以上是生活随笔為你收集整理的Pytorch迁移学习加载部分预训练权重的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 《流浪地球》海报丨见证小破球24亿票房逆
- 下一篇: VUE初级试炼1