写算子单元测试Writing Unit Tests
寫算子單元測試Writing Unit Tests!
一些單元測試示例,可在tests/python/relay/test_op_level3.py中找到,用于累積總和與乘積算子。
梯度算子
梯度算子對于編寫Relay中的可微程序非常重要。雖然Relay的autodiff算法可區(qū)分一流的語言結(jié)構(gòu),但算子是不透明的。Relay無法查看實現(xiàn),必須提供明確的差異化規(guī)則。
Python和C++都可編寫梯度算子,但是,示例集中在Python上,更常用。
在Python中添加梯度
Python梯度算子的集合可在Python/tvm/relay/op/_tensor_grad.py中找到。將介紹兩個具有代表性的示例:sigmoid和multiply。
@register_gradient(“sigmoid”)
def sigmoid_grad(orig,grad):
“”“Returns [grad * sigmoid(x) * (1 - sigmoid(x))].”""
return [grad * orig * (ones_like(orig) - orig)]
這里的輸入是原始算子orig和要累加的梯度。返回的是一個列表,第i個索引處的元素是算子相對于算子第i個輸入的導(dǎo)數(shù)。通常,梯度將返回一個列表,包含的元素數(shù)量與基本算子的輸入數(shù)量相同。
進(jìn)一步分析這個定義前面,首先回顧一下sigmoid函數(shù)的導(dǎo)數(shù):
上面的定義類似于數(shù)學(xué)定義,但有一個重要的補充,將在下面描述。
術(shù)語orig*(類似于(orig)-orig)直接匹配導(dǎo)數(shù),這里的orig是sigmoid函數(shù),但不只是對如何計算這個函數(shù)的梯度感興趣。將這個梯度與其它梯度組合起來,這樣就可在整個程序中累積梯度。
這就是梯度術(shù)語的意義所在。在表達(dá)式gradorig(one_like(orig)-orig)中,乘以grad,表示如何使用到目前為止的梯度合成導(dǎo)數(shù)。
現(xiàn)在,考慮乘法,一個稍微有趣的示例:
@register_gradient(“multiply”)
def multiply_grad(orig,grad):
“”“Returns [grad * y,grad * x]”""
x,y=orig.args
return [collapse_sum_like(grad * y,x),
collapse_sum_like(grad * x,y)]
在本例中,返回的列表中有兩個元素,multiply是一個二進(jìn)制算子。回想一下,如果
,偏導(dǎo)數(shù)是
有一個乘法所需的步驟,對于sigmoid不是必需的,乘法具有廣播語義。梯度的shape可能與輸入的shape不匹配,使用collapse_sum_like來獲取梯度grad * 項的內(nèi)容,使shape與要區(qū)分的輸入的shape匹配。
Adding a Gradient in C++
在C++中添加一個梯度,類似于在Python中添加,但是,用于注冊的接口略有不同。
首先,確保包含src/relay/transforms/pattern_utils.h。提供了用于在RelayAST中創(chuàng)建節(jié)點的 helper函數(shù)。然后,類似于Python示例的方式,定義梯度:
tvm::Array MultiplyGrad(const Expr& orig_call,const Expr& output_grad) {
const Call& call=orig_call.Downcast();
return { CollapseSumLike(Multiply(output_grad,call.args[1]),call.args[0]),
CollapseSumLike(Multiply(output_grad,call.args[0]),call.args[1]) };
}
在C++中,不能使用Python中的算子重載,需要進(jìn)行downcast,實現(xiàn)更加冗長。即使如此,可容易地驗證這個定義,是否反映了Python中的早期示例。
現(xiàn)在,不需要使用Python裝飾器,需要在基礎(chǔ)算子的注冊末尾,添加一個對“FPrimalGradient”的set_attr調(diào)用,注冊梯度。
RELAY_REGISTER_OP(“multiply”)
// …
// Set other attributes
// …
.set_attr(“FPrimalGradient”,MultiplyGrad);
參考鏈接:
https://tvm.apache.org/docs/dev/relay_add_op.html
TVM源碼研習(xí) — TVM中的IR設(shè)計與技術(shù)實現(xiàn)
總結(jié)
以上是生活随笔為你收集整理的写算子单元测试Writing Unit Tests的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。