卸载pytorch_Pytorch中的hook的使用详解
首先明確一點,hook是什么呢?翻譯出來是"鉤子",顧名思義,就是掛在某個東西上的一種即插即用的結構。
有哪些hook呢?
常用的主要有3個:
1. torch.autograd.Variable.register_hook (in Automatic differentiation package)
2. torch.nn.Module.register_backward_hook (in torch.nn.Module)
3. torch.nn.Module.register_forward_hook (in torch.nn.Module)
第一個是register_hook,是針對Variable對象的(可以理解為掛在Variable用來提取某些信息的一種結構); 后面的兩個register_backward_hook和register_forward_hook是針對nn.Module(掛在Module上的)這個對象的。
其次,我們?yōu)楹我胔ook呢或者說hook有什么作用呢?
舉個例子,比如有這么一個函數(shù),
,你想通過梯度下降法來求得函數(shù)的極小值 (或者最小值)。這在pytorch里很容易實現(xiàn):import在pytorch的計算圖中,只有葉子節(jié)點 (leaf node)是可以被追蹤梯度的(即requires_grad=True),中間節(jié)點的梯度為了節(jié)省內(nèi)存的原因而沒有被保留 (對于中間變量,一旦它們完成了自身反傳的使命,梯度就會被釋放掉 ),因此當我們輸出中間變量y的梯度的時候:
y.grad系統(tǒng)會返回None。那怎么辦呢?
因此,hook就派上用場了。簡而言之,register_hook( )的作用是,當反傳時,除了完成原有的反傳,額外多完成一些任務。你可以定義一個中間變量的hook,將它的grad值打印出來,當然你也可以定義一個列表,將每次的grad值添加到里面去保留起來。
import torch from torch.autograd import Variablegrad_list = []def print_grad(grad):grad_list.append(grad)x = Variable(torch.randn(2, 1), requires_grad=True) y = x+2 z = torch.mean(torch.pow(y, 2)) lr = 1e-3 y.register_hook(print_grad) z.backward() x.data -= lr*x.grad.data需要注意的是,register_hook函數(shù)接收的是一個函數(shù)(函數(shù)名),這個函數(shù)有如下的形式:
hook(grad) -> Variable or NonePS:這個函數(shù)是可以改變被執(zhí)行變量梯度的!
v = torch.tensor([1, 1, 1], dtype = torch.float32, requires_grad=True) u = torch.pow(v, 2) z = torch.mean(u) # register hook for u h = u.register_hook(lambda grad: print(2 * grad)) # double the gradient z.backward() h.remove() # removes the hook系統(tǒng)輸出:tensor([0.6667, 0.6667, 0.6667])
這個函數(shù)返回一個句柄h, 它有一個方法h.remove( ), 可以使用這個方法將hook從變量u上"卸載"下來。
總結
以上是生活随笔為你收集整理的卸载pytorch_Pytorch中的hook的使用详解的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 4999元 红魔8 Pro系列银翼版上架
- 下一篇: 酷狗音乐驾车模式怎么设置