pytorch 笔记: 扩展torch.autograd
1 擴展torch.autograd
????????向 autograd 添加操作需要為每個操作實現(xiàn)一個新的 Function 子類。
???????? 回想一下,函數(shù)是 autograd 用來編碼操作歷史和計算梯度的東西。
2 何時使用
????????通常,如果您想在模型中執(zhí)行不可微分或依賴非 Pytorch 庫中有的函數(shù)(例如 NumPy)的計算,但仍希望您的操作與其他操作鏈接并使用 autograd 引擎,可以通過擴展torch.autograd實現(xiàn)自定義函數(shù) .
????????在某些情況下,還可以使用自定義函數(shù)來提高性能和內存使用率:如果您使用擴展函數(shù)實現(xiàn)了前向和反向傳遞,則可以將它們包裝在 Function 中以與 autograd 引擎交互。
3 何時不建議使用
????????如果您已經可以根據 PyTorch 的內置操作編寫函數(shù),那么它的后向圖(很可能)已經能夠被 autograd 記錄。 在這種情況下,您不需要自己實現(xiàn)反向傳播功能。考慮使用一個普通的pytorch 函數(shù)即可。
4 如何使用
采取以下步驟:
1. 創(chuàng)建子類Function, 并實現(xiàn) forward() 和 backward() 方法。
2. 在 ctx 參數(shù)上調用正確的方法。
3. 聲明你的函數(shù)是否支持雙反向傳播(double backward)。
4. 使用 gradcheck 驗證您的梯度(反向傳播的實現(xiàn))是否正確。
4.1 第一步
????????在創(chuàng)建類?Function 之后,您需要定義 2 個方法:
- forward() 是執(zhí)行操作的代碼(前向傳播)
????????它可以采用任意數(shù)量的參數(shù),其中一些是optimal的(如果設定默認值的話)。
????????這里接受各種 Python 對象。
????????跟蹤歷史的張量參數(shù)(即 requires_grad=True 的tensor)將在調用之前轉換為不跟蹤歷史的張量參數(shù)【但它們如何被使用將在計算圖中注冊】。請注意,此邏輯不會遍歷列表/字典/任何其他數(shù)據結構,只會考慮作為調用的直接參數(shù)的張量。
????????如果有多個輸出,您可以返回單個張量輸出或張量元組。
- backward() 定義梯度公式。
? ? ? 它將被賦予與和forward的輸出一樣多的張量參數(shù),其中每個參數(shù)都代表對應輸出的梯度。重要的是永遠不要就地修改這些梯度(即不要有inplace操作)。
????????它應該返回與forward輸入一樣多的張量,每個張量都包含對應輸入的梯度。
????????如果您的輸入不需要梯度(needs_input_grad 是一個布爾元組,指示每個輸入是否需要梯度計算),或者是非張量對象,您可以返回 None。
????????此外,如果你有 forward() 的可選參數(shù),你可以返回比輸入更多的梯度,只要它們都是 None。
4.2 第二步
????????需要正確使用 forward 的 ctx 相關函數(shù),以確保新函數(shù)與 autograd 引擎一起正常工作。
- save_for_backward() 可以 保存稍后在反向傳播中需要使用的、前向的輸入張量 。? ?
????????????????任何東西,即非張量? ?和既不是輸入也不是輸出的張量,都應該直接存儲在 ctx 上。
- mark_dirty() 必須用于標記任何由 forward 函數(shù)就地修改的輸入。
- 使用 mark_non_differentiable() 來告訴autograd 引擎某一個輸出是否不可微。默認情況下,所有可微分類型的輸出張量都將設置為需要梯度。不可微分類型(即整數(shù)類型)的張量永遠不會被標記為需要梯度。
- set_materialize_grads() 可用于告訴 autograd 引擎在輸出不依賴于輸入的情況下優(yōu)化梯度計算,方法是不物化給予后向函數(shù)的梯度張量。也就是說,如果設置為 False,python 中的 None 對象或 C++ 中的“未定義張量”(x.defined() 為 False 的張量 x)將不會在向后調用之前轉換為填充零的張量,因此您的代碼將需要處理這些對象,就好像它們是用零填充的張量一樣。此設置的默認值為 True。 ?
4.3 第三步
????????如果你的函數(shù)不支持兩次反向傳播double backward,你應該通過使用 once_differentiable() 向后修飾來明確聲明它。 使用此裝飾器,嘗試通過您的函數(shù)執(zhí)行雙重反向傳播double backward將產生錯誤。
4.4 第四步
????????建議您使用 torch.autograd.gradcheck() 來檢查您的后向函數(shù)是否正確計算前向梯度,方法是使用后向函數(shù)計算雅可比矩陣,并將值元素與使用有限差分數(shù)值計算的雅可比進行比較
5 示例
from torch.autograd import Function # Inherit from Function class LinearFunction(Function):@staticmethod# 注意這里forward和backward都是靜態(tài)函數(shù)def forward(ctx, input, weight, bias=None):# bias 是一個可選變量,所以可以沒有梯度ctx.save_for_backward(input, weight, bias)#這里就使用了step2中的save_for_backward#也就是保存稍后在反向傳播中需要使用的、前向的輸入或輸張量output = input.mm(weight.t())if bias is not None:output += bias.unsqueeze(0).expand_as(output)return output#類似的前向傳播輸出定義@staticmethoddef backward(ctx, grad_output):#由于這個方法只有一個輸出(output),因而在backward中,只需要有一個輸入即可(ctx不算的話)input, weight, bias = ctx.saved_tensors#這呼應的就是前面的ctx.save_for_backwardgrad_input = grad_weight = grad_bias = Noneif ctx.needs_input_grad[0]:grad_input = grad_output.mm(weight)if ctx.needs_input_grad[1]:grad_weight = grad_output.t().mm(input)if bias is not None and ctx.needs_input_grad[2]:grad_bias = grad_output.sum(0)'''這些 needs_input_grad 檢查是可選的,只是為了提高效率。 如果你想讓你的代碼更簡單,你可以跳過它們。 為不需要的輸入返回梯度不會返回錯誤。 '''return grad_input, grad_weight, grad_bias現(xiàn)在,為了更容易使用這些自定義操作,我們建議為它們的 apply 方法設置別名:
linear = LinearFunction.apply這樣之后,linear的效果就和我們正常的比如'loss=torch.nn.MSELoss'的loss差不多了?
5.1 用非tensor 參數(shù)化的方法
class MulConstant(Function):@staticmethoddef forward(ctx, tensor, constant):ctx.constant = constant#非tensor的變量就不用save_for_backward了,直接存在ctx里面即可return tensor * constant@staticmethoddef backward(ctx, grad_output):return grad_output * ctx.constant, None#非tensor 變量的梯度為06 檢查效果
????????您可能想檢查您實現(xiàn)的反向傳播方法是否實際計算了函數(shù)的導數(shù)。 可以通過與使用小的有限差分的數(shù)值近似進行比較:
from torch.autograd import gradcheckinput = (torch.randn(20,20,dtype=torch.double,requires_grad=True), torch.randn(30,20,dtype=torch.double,requires_grad=True)) test = gradcheck(linear, input, eps=1e-6, atol=1e-4) ''' gradcheck 將張量的元組作為輸入,檢查用這些張量評估的梯度是否足夠接近數(shù)值近似值,如果它們都驗證了這個條件,則返回 True。 ''' print(test) #True總結
以上是生活随笔為你收集整理的pytorch 笔记: 扩展torch.autograd的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python笔记: staticmeth
- 下一篇: NTU 课程笔记:CV6422 置信区间