【pytorch学习】四种钩子方法(register_forward_hook等)的用法和注意点
??為了節(jié)省顯存(內(nèi)存),pytorch在計(jì)算過程中不保存中間變量,包括中間層的特征圖和非葉子張量的梯度等。有時(shí)對(duì)網(wǎng)絡(luò)進(jìn)行分析時(shí)需要查看或修改這些中間變量,此時(shí)就需要注冊(cè)一個(gè)鉤子(hook)來(lái)導(dǎo)出需要的中間變量。網(wǎng)上介紹這個(gè)的有不少,但我看了一圈,多少都有不準(zhǔn)確或不易懂的地方,我這里再總結(jié)一下,給出實(shí)際用法和注意點(diǎn)。
hook方法有四種:
torch.Tensor.register_hook()
torch.nn.Module.register_forward_hook()
torch.nn.Module.register_backward_hook()
torch.nn.Module.register_forward_pre_hook().
1, torch.Tensor.register_hook(hook)
??用來(lái)導(dǎo)出指定張量的梯度,或修改這個(gè)梯度值。
import torch def grad_hook(grad):grad *= 2 x = torch.tensor([2., 2., 2., 2.], requires_grad=True) y = torch.pow(x, 2) z = torch.mean(y) h = x.register_hook(grad_hook) z.backward() print(x.grad) h.remove() # removes the hook >>> tensor([2., 2., 2., 2.])注意:(1)上述代碼是有效的,但如果寫成 grad = grad * 2就失效了,因?yàn)榇藭r(shí)沒有對(duì)grad進(jìn)行本地操作,新的grad值沒有傳遞給指定的梯度。保險(xiǎn)起見,最好在def語(yǔ)句中寫明return grad。即:
def grad_hook(grad):grad = grad * 2return grad(2)可以用remove()方法取消hook。注意remove()必須在backward()之后,因?yàn)橹挥性趫?zhí)行backward()語(yǔ)句時(shí),pytorch才開始計(jì)算梯度,而在x.register_hook(grad_hook)時(shí)它僅僅是"注冊(cè)"了一個(gè)grad的鉤子,此時(shí)并沒有計(jì)算,而執(zhí)行remove就取消了這個(gè)鉤子,然后再backward()時(shí)鉤子就不起作用了。
(3)如果在類中定義鉤子函數(shù),輸入?yún)?shù)必須先加上self,即
2, torch.nn.Module.register_forward_hook(module, in, out)
??用來(lái)導(dǎo)出指定子模塊(可以是層、模塊等nn.Module類型)的輸入輸出張量,但只可修改輸出,常用來(lái)導(dǎo)出或修改卷積特征圖。
inps, outs = [],[] def layer_hook(module, inp, out):inps.append(inp[0].data.cpu().numpy())outs.append(out.data.cpu().numpy())hook = net.layer1.register_forward_hook(layer_hook) output = net(input) hook.remove()注意:(1)因?yàn)槟K可以是多輸入的,所以輸入是tuple型的,需要先提取其中的Tensor再操作;輸出是Tensor型的可直接用。
???(2)導(dǎo)出后不要放到顯存上,除非你有A100。
???(3)只能修改輸出out的值,不能修改輸入inp的值(不能返回,本地修改也無(wú)效),修改時(shí)最好用return形式返回,如:
??這段代碼用在manifold mixup中,用來(lái)對(duì)中間層特征進(jìn)行混合來(lái)實(shí)現(xiàn)數(shù)據(jù)增強(qiáng),其中self.lam是一個(gè)[0,1]概率值,self.indices是shuffle后的序號(hào)。
3, torch.nn.Module.register_forward_pre_hook(module, in)
??用來(lái)導(dǎo)出或修改指定子模塊的輸入張量。
def pre_hook(module, inp):inp0 = inp[0]inp0 = inp0 * 2inp = tuple([inp0])return inphook = net.layer1.register_forward_pre_hook(pre_hook) output = net(input) hook.remove()注意:(1)inp值是個(gè)tuple類型,所以需要先把其中的張量提取出來(lái),再做其他操作,然后還要再轉(zhuǎn)化為tuple返回。
(2)在執(zhí)行output = net(input)時(shí)才會(huì)調(diào)用此句,remove()可放在調(diào)用后用來(lái)取消鉤子。
4, torch.nn.Module.register_backward_hook(module, grad_in, grad_out)
??用來(lái)導(dǎo)出指定子模塊的輸入輸出張量的梯度,但只可修改輸入張量的梯度(即只能返回gin),輸出張量梯度不可修改。
gouts = [] def backward_hook(module, gin, gout):print(len(gin),len(gout))gouts.append(gout[0].data.cpu().numpy())gin0,gin1,gin2 = gingin1 = gin1*2gin2 = gin2*3gin = tuple([gin0,gin1,gin2])return ginhook = net.layer1.register_backward_hook(backward_hook) loss.backward() hook.remove()注意:
(1)其中的grad_in和grad_out都是tuple,必須要先解開,修改時(shí)執(zhí)行操作后再重新放回tuple返回。
(2)這個(gè)鉤子函數(shù)在backward()語(yǔ)句中被調(diào)用,所以remove()要放在backward()之后用來(lái)取消鉤子。
總結(jié)
以上是生活随笔為你收集整理的【pytorch学习】四种钩子方法(register_forward_hook等)的用法和注意点的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: [论文学习]Manifold Mixup
- 下一篇: 简明代码介绍类激活图CAM, GradC