Python代码中的偏函数
技術背景
在數學中我們都學過偏導數\(\frac{\partial f(x,y)}{\partial x}\),而這里我們提到的偏函數,指的是\(f(y)(x)\)。也就是說,在代碼實現的過程中,雖然我們實現的一個函數可能帶有很多個變量,但是可以用偏函數的形式把其中一些不需要拆分和變化的變量轉變為固有變量。比較典型的兩個例子是計算偏導數和多進程優化。雖然大部分支持自動微分的框架都有相應的支持偏導數的接口,多進程操作中也可以指定額外的args,但是這些自帶的方法在形式上都是比較tricky的,感覺并不如使用偏函數優雅和簡潔。這里我們主要介紹python中可能會用到的偏函數功能--partial。
Partial簡單案例
我們先來一個最簡單的乘法函數\(f(x,y)=xy\)。假如說我們想得到該函數關于y的偏導數,注意,這里y是第二個輸入的變量,不是第一個位置,一般自動微分框架都默認都第一個位置的變量計算偏導數。相關代碼實現如下所示:
from functools import partial
def mul(x, y):
print (locals())
return x * y
x = 2
y = 3
res_0 = mul(x, y)
partial_mul = partial(mul, x=x)
res_1 = partial_mul(y=3)
print ('The result is: {}'.format(res_0))
print ('The result is: {}'.format(res_1))
這段代碼的運行結果為:
{'x': 2, 'y': 3}
{'x': 2, 'y': 3}
The result is: 6
The result is: 6
我們現在來分析一下上面這個案例中所體現的信息:
- 在使用partial函數時使用的是關鍵字參數,即時原本的變量不是一個關鍵字參數,而是一個位置參數。
- 雖然得到的偏函數partial_mul運行的方式跟函數一致,但其實它是一個partial的對象類型。
- 在生成partial_mul對象時已經執行過一遍函數,因此函數中的打印語句被打印了兩次。
- 偏函數的計算結果肯定是跟原函數保持一致的,但是在一些特殊場景下,我們可能會用到這種單變量的偏函數。
Concurrent多核并行場景
現在我們稍微修改一下上面的案例,我們要用concurrent這個并行工具去分別執行上述乘法任務,同時輸入的x也變成了一個多維的數組。然后為了驗證并行算法,這里每計算一次元素乘法,我們都用time.sleep方法讓進程休眠2秒鐘時間。由于此時的參數y還是一個標量,但是每次乘法計算我們都需要輸入這個標量,因此我們直接將其封裝到一個partial偏函數中,使得函數變成:\(f(x,y)=f(y)(x)=P(x)\),然后對x這個入參進行并行化操作:
import numpy as np
import concurrent.futures
from functools import partial
import time
# 定義休眠函數
def mul(x, y):
time.sleep(2)
return x * y
# 定義入參
x = np.array([1, 2, 3], np.float32)
y = 3.
# 有阻塞計算
time_0 = time.time()
res_0 = []
for _x in x:
res_0.append(mul(_x, y))
res_0 = np.array(res_0, np.float32)
time_1 = time.time()
# 并行計算
partial_mul = partial(mul, y=y)
time_2 = time.time()
with concurrent.futures.ProcessPoolExecutor(max_workers=x.shape[0]) as executor:
res = executor.map(partial_mul, x)
res_1 = np.array(list(res), np.float32)
time_3 = time.time()
print ('The result is: {}, and for loop time cost is: {}s'.format(res_0, time_1 - time_0))
print ('The result is: {}, and concurrent time cost is : {}s'.format(res_1, time_3 - time_2))
如果有感興趣的童鞋也可以去嘗試一下,在這種場景下的并行運算,如果參量y不是一個可迭代式的變量,是無法用zip壓縮傳到map函數中去的。上述代碼的運行結果如下:
The result is: [3. 6. 9.], and for loop time cost is: 6.005392789840698s
The result is: [3. 6. 9.], and concurrent time cost is : 2.0451698303222656s
這個計算時長其實就約等于休眠時長,因為這里我們開啟了3個進程來進行休眠,因此并行時長是2s。
Jax自動微分場景
這里我們用Jax的自動微分框架做一個示例,沒有安裝Jax和Jaxlib的想運行需要自行安裝相關軟件。雖然在Jax的grad函數中,支持argnums這樣的參數配置,但從代碼層面角度來說,總是顯得可讀性并不好。正常情況下我們算偏導數\(\frac{\partial f(x,y)}{\partial x}\)其實更合理的表述應該是\(\frac{\partial P(x)}{\partial x}\)。而如果按照Jax這種寫法,更像是從\([\frac{\partial f(x, y)}{\partial x}, \frac{\partial f(x, y)}{\partial x}]\)兩個元素中取了第一個元素。當然,這只是表述上的問題,也是我個人的理解,其實并不影響程序的正確性。這里使用partial偏函數的相關案例如下所示:
from functools import partial
from jax import grad
from jax import numpy as jnp
# Jax要求grad函數輸出結果為標量,所以要加一項求和
def mul(x, y):
f = x * y
return f.sum()
# 定義輸入變量
x = jnp.array([1, 2, 3], jnp.float32)
y = 3.
# 定義偏函數和對應偏導數
partial_mul = partial(mul, y=y)
grad_mul = grad(partial_mul)
print (grad_mul(x))
執行結果如下:
[3. 3. 3.]
總結概要
本文介紹了在Python中使用偏函數partial的方法,并且介紹了兩個使用partial函數的案例,分別是concurrent并行場景和基于jax的自動微分場景。在這些相關的場景下,我們用partial函數更多時候可以使得代碼的可讀性更好,在性能上其實并沒有什么提升。如果不想使用partial函數,類似的功能也可以使用參考鏈接中所介紹的方法,實現一個裝飾器,也可以做到一樣的功能。
版權聲明
本文首發鏈接為:https://www.cnblogs.com/dechinphy/p/partial.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
請博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
參考鏈接
- https://www.cnblogs.com/huyangblog/p/8999866.html
總結
以上是生活随笔為你收集整理的Python代码中的偏函数的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 铭凡 UM790 Pro 迷你主机价格公
- 下一篇: Link三频电竞路由器发布(三频电竞无线