GELU函数的近似
一、背景
GELU(Gaussian Error Linear Unit)函數(shù)的定義為
其中
考慮高斯誤差函數(shù)
通過令得
由于高斯誤差函數(shù)里面涉及了指數(shù)運算和積分運算,如何利用初等函數(shù)進(jìn)行擬合,對于提高運算效率就顯得比較有意義了。
二、方法
高斯誤差函數(shù)的圖像為
from scipy.special import erf import numpy as np import matplotlib.pyplot as plt fig, ax = plt.subplots()ax.spines['right'].set_visible(False) ax.spines['top'].set_color('none') ax.xaxis.set_ticks_position('bottom') ax.spines['bottom'].set_position(('data',0)) ax.yaxis.set_ticks_position('left') ax.spines['left'].set_position(('data',0))x=np.linspace(-5,5,100) y=erf(x)plt.plot(x,y) plt.title('graph of erf function') plt.show()可以驗證erf(x)為奇函數(shù),首先想到的是利用tanh(x)做近似??梢钥紤]兩種擬合方式,局部擬合和全局?jǐn)M合。全局?jǐn)M合也就是數(shù)學(xué)里面的一致性問題。這里只考慮利用對erf(x)進(jìn)行擬合。
局部擬合
局部擬合考慮利用泰勒展開,擬合前幾項的系數(shù)
則
令
求解得到
此時,擬合的函數(shù)為
import numpy as np import matplotlib.pyplot as plt a=np.sqrt(2/np.pi) b=(4-np.pi)/(3*np.sqrt(2)*np.pi**(3/2)) b1=(4-np.pi)/(6*np.pi)x=np.linspace(-5,5,100) #y=erf(x) def gelu(x):return 1/2*x*(1+erf(x/np.sqrt(2))) def gelu_pro1(x):return 1/2*x*(1+np.tanh(np.sqrt(2/np.pi)*(x+b1*x**3)))fig, ax = plt.subplots() ax.spines['right'].set_visible(False) ax.spines['top'].set_color('none') ax.xaxis.set_ticks_position('bottom') ax.spines['bottom'].set_position(('data',0)) ax.yaxis.set_ticks_position('left') ax.spines['left'].set_position(('data',0))plt.plot(x,gelu(x)) plt.plot(x,gelu_pro1(x),'r-') plt.legend(['gelu','gelu_pro']) plt.title('performance of the approximate') plt.show()全局?jǐn)M合
考慮到中的求解能夠保證對GELU的一階近似,這里我們先固定a,然后對b進(jìn)行求解。即
import numpy as np from scipy.special import erf from scipy.optimize import minimizea=np.sqrt(2/np.pi)def f(x,b):return np.abs(erf(x/np.sqrt(2))-np.tanh(a*x+b*x**3))def g(b):return np.max([f(x,b) for x in np.arange(0,4,0.001)])options={'xtol':1e-10,'ftol':1e-10,'maxiter':100000} result=minimize(g,0,method='Powell',options=options) print(result.x)
???????
總結(jié)
- 上一篇: 新浪微博终于完成多数ui
- 下一篇: Unity3D优化