首先很多網上的博客,講的都不對,自己跟著他們踩了很多坑
1.單卡訓練,單卡加載
這里我為了把三個模塊save到同一個文件里,我選擇對所有的模型先封裝成一個checkpoint字典,然后保存到同一個文件里,這樣就可以在加載時只需要加載一個參數文件。
保存:
states
= {'state_dict_encoder': encoder
.state_dict
(),'state_dict_decoder': decoder
.state_dict
(),}
torch
.save
(states
, fname
)
加載:
encoder
= Encoder
()
decoder
= Decoder
()
checkpoint
= torch
.load
(model_path
)
encoder_state_dict
=checkpoint
['state_dict_encoder']
decoder_state_dict
=checkpoint
['state_dict_decoder']
encoder
.load_state_dict
(encoder_state_dict
)
decoder
.load_state_dict
(decoder_state_dict
)
2.單卡訓練,多卡加載
保存:
保存過程一樣,不做任何改變
states
= {'state_dict_encoder': encoder
.state_dict
(),'state_dict_decoder': decoder
.state_dict
(),}
torch
.save
(states
, fname
)
加載:
加載過程也沒有任何改變,但是要注意,先加載模型參數,再對模型做并行化處理
encoder
= Encoder
()
decoder
= Decoder
()
checkpoint
= torch
.load
(model_path
)
encoder_state_dict
=checkpoint
['state_dict_encoder']
decoder_state_dict
=checkpoint
['state_dict_decoder']
encoder
.load_state_dict
(encoder_state_dict
)
decoder
.load_state_dict
(decoder_state_dict
)
encoder
= nn
.DataParallel
(encoder
)
decoder
= nn
.DataParallel
(decoder
)
3.多卡訓練,單卡加載
注意,如果你考慮到以后可能需要單卡加載你多卡訓練的模型,建議在保存模型時,去除模型參數字典里面的module,如何去除呢,使用model.module.state_dict()代替model.state_dict()
保存:
states
= {'state_dict_encoder': encoder
.module
.state_dict
(), 'state_dict_decoder': decoder
.module
.state_dict
(),}
torch
.save
(states
, fname
)
加載:
要注意由于我們保存的方式是以單卡的方式保存的,所以還是要先加載模型參數,再對模型做并行化處理
encoder
= Encoder
()
decoder
= Decoder
()
checkpoint
= torch
.load
(model_path
)
encoder_state_dict
=checkpoint
['state_dict_encoder']
decoder_state_dict
=checkpoint
['state_dict_decoder']
encoder
.load_state_dict
(encoder_state_dict
)
decoder
.load_state_dict
(decoder_state_dict
)
encoder
= nn
.DataParallel
(encoder
)
decoder
= nn
.DataParallel
(decoder
)
同時,你也可以用第二種方式去保存和加載:
3.多卡訓練,單卡加載,方法二
使用model.state_dict()保存,但是單卡加載的時候,要把模型做并行化(在單卡上并行)
保存:
states
= {'state_dict_encoder': encoder
.state_dict
(), 'state_dict_decoder': decoder
.state_dict
(),}
torch
.save
(states
, fname
)
加載:
要注意由于我們保存的方式是以多卡的方式保存的,所以無論你加載之后的模型是在單卡運行還是在多卡運行,都先把模型并行化再去加載
encoder
= Encoder
()
decoder
= Decoder
()
encoder
= nn
.DataParallel
(encoder
)
decoder
= nn
.DataParallel
(decoder
)
checkpoint
= torch
.load
(model_path
)
encoder_state_dict
=checkpoint
['state_dict_encoder']
decoder_state_dict
=checkpoint
['state_dict_decoder']
encoder
.load_state_dict
(encoder_state_dict
)
decoder
.load_state_dict
(decoder_state_dict
)
4.多卡保存,多卡加載
這就和多卡保存,單卡加載第二中方式一樣了
**使用model.state_dict()**保存,加載的時候,要先把模型做并行化(在多卡上并行)
保存:
states
= {'state_dict_encoder': encoder
.state_dict
(), 'state_dict_decoder': decoder
.state_dict
(),}
torch
.save
(states
, fname
)
加載:
要注意由于我們保存的方式是以多卡的方式保存的,所以無論你加載之后的模型是在單卡運行還是在多卡運行,都先把模型并行化再去加載
encoder
= Encoder
()
decoder
= Decoder
()
encoder
= nn
.DataParallel
(encoder
)
decoder
= nn
.DataParallel
(decoder
)
checkpoint
= torch
.load
(model_path
)
encoder_state_dict
=checkpoint
['state_dict_encoder']
decoder_state_dict
=checkpoint
['state_dict_decoder']
encoder
.load_state_dict
(encoder_state_dict
)
decoder
.load_state_dict
(decoder_state_dict
)
創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎
總結
以上是生活随笔為你收集整理的pytorch 使用DataParallel 单机多卡和单卡保存和加载模型时遇到的问题的全部內容,希望文章能夠幫你解決所遇到的問題。
如果覺得生活随笔網站內容還不錯,歡迎將生活随笔推薦給好友。