在 Go 语言中 Patch 非导出函数
TLDR; 使用 supermonkey[1] 可以 patch 任意導出/非導出函數。
目前在 Go 語言里寫測試還是比較麻煩的。
除了傳統的 test double,也可以通過把一個現成的對象的成員方法 Patch 掉,以達成測試執行時的特殊目的。
舉個例子,我的業務邏輯是從遠端獲取一段數據,在測試環節沒有網絡,所以我需要把和網絡交互的環節 mock 掉:
func LoadConfig() string {jsonBytes, err := redis.Get("xxxx")return string(jsonBytes) }這里的 redis.Get 中有網絡操作,寫測試時,我們的目的是為了驗證 Get 之后的邏輯是否正常,所以我們可以把這個 Get 替換為直接返回內容,不走網絡,社區中有 monkey patch 來達成這個目的:
monkey.Patch(redis.Get, func(input string) ([]byte, error) {return []byte("{"key" : 12345}"), nil })Patch 之后,redis.Get 就會按照我們替換之后的函數來執行了,還是比較方便的。
monkey patch 的基本原理不復雜,就是把進程中 .text 段中的代碼(你可以理解成 byte 數組)替換為用戶提供的替換函數。
patchvalue讀取 target 的地址使用了 reflect.ValueOf(funcVal).Pointer() 獲取函數的虛擬地址,然后把替換函數的內容以 []byte 的形式覆蓋進去。
一方面是因為 reflect 本身沒有辦法讀取非導出函數,一方面是從 Go 的語法上來講,我們沒法在包外部以字面量對非導出函數進行引用。所以目前開源的 monkey patch 是沒有辦法 patch 那些非導出函數的。
如果我們想要 patch 那些非導出函數,理論上并不需要對這個函數進行引用,只要能找到這個函數的虛擬地址就可以了,在這里提供一個思路,可以使用 nm 來找到我們想要 patch 的函數地址:
NM(1) GNU Development Tools NM(1)NAMEnm - list symbols from object filesnm 可以查看一個二進制文件中的所有符號的名字、虛擬地址、大小。還是舉個例子:
$cat hello.go package mainfunc say() {println("yyyy") }func main() {say() }build 需要帶 -l 的 gcflags,防止內聯優化:
go build -gcflags="-l" hello.go
用 nm 找找這個 say 的地址:
$nm hello | grep main 000000000044e3f0 T main 0000000000401070 T main.init 00000000004d5620 B main.initdone. 0000000000401050 T main.main 0000000000401000 T main.say ------> 這里 0000000000423620 T runtime.main 0000000000488c78 R runtime.main.f 0000000000442740 T runtime.main.func1 0000000000488c60 R runtime.main.func1.f 0000000000442780 T runtime.main.func2 0000000000488c68 R runtime.main.func2.f 00000000004b1e70 B runtime.main_init_done 0000000000488c70 R runtime.mainPC有了虛擬地址,也就有了拷貝的 target。
在 monkey 代碼的基礎上,再結合 nm 命令得到的符號地址,組合一下就是下面這樣的 demo:
package mainimport ("os""os/exec""reflect""strconv""strings""syscall""unsafe" )//go:noinline func HeiHeiHei() {println("hei") }//go:noinline func heiheiPrivate() {println("oh no") }func Replace() {println("fake") }func generateFuncName2PtrDict() map[string]uintptr {fileFullPath := os.Args[0]cmd := exec.Command("nm", fileFullPath)contentBytes, err := cmd.Output()if err != nil {println(err)return nil}var result = map[string]uintptr{}content := string(contentBytes)lines := strings.Split(content, "\n")for _, line := range lines {arr := strings.Split(line, " ")if len(arr) < 3 {continue}funcSymbol, addr := arr[2], arr[0]addrUint, _ := strconv.ParseUint(addr, 16, 64)result[funcSymbol] = uintptr(addrUint)}return result }func main() {m := generateFuncName2PtrDict()heiheiPrivate()replaceFunction(m["_main.heiheiPrivate"], (uintptr)(getPtr(reflect.ValueOf(Replace))))heiheiPrivate() }type value struct {_ uintptrptr unsafe.Pointer }func getPtr(v reflect.Value) unsafe.Pointer {return (*value)(unsafe.Pointer(&v)).ptr }// from is a pointer to the actual function // to is a pointer to a go funcvalue func replaceFunction(from, to uintptr) (original []byte) {jumpData := jmpToFunctionValue(to)f := rawMemoryAccess(from, len(jumpData))original = make([]byte, len(f))copy(original, f)copyToLocation(from, jumpData)return }// Assembles a jump to a function value func jmpToFunctionValue(to uintptr) []byte {return []byte{0x48, 0xBA,byte(to),byte(to >> 8),byte(to >> 16),byte(to >> 24),byte(to >> 32),byte(to >> 40),byte(to >> 48),byte(to >> 56), // movabs rdx,to0xFF, 0x22, // jmp QWORD PTR [rdx]} }func rawMemoryAccess(p uintptr, length int) []byte {return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p,Len: length,Cap: length,})) }func mprotectCrossPage(addr uintptr, length int, prot int) {pageSize := syscall.Getpagesize()for p := pageStart(addr); p < addr+uintptr(length); p += uintptr(pageSize) {page := rawMemoryAccess(p, pageSize)err := syscall.Mprotect(page, prot)if err != nil {panic(err)}} }// this function is super unsafe // aww yeah // It copies a slice to a raw memory location, disabling all memory protection before doing so. func copyToLocation(location uintptr, data []byte) {f := rawMemoryAccess(location, len(data))mprotectCrossPage(location, len(data), syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)copy(f, data[:])mprotectCrossPage(location, len(data), syscall.PROT_READ|syscall.PROT_EXEC) }func pageStart(ptr uintptr) uintptr {return ptr & ^(uintptr(syscall.Getpagesize() - 1)) }go run -gcflags="-l" yourfile.go
上面的 demo 不跨平臺,建議還是直接試試開頭說的 lib 中的 example。
該思路已被封裝至 https://github.com/cch123/supermonkey 中。
參考資料
[1]
supermonkey: https://github.com/cch123/supermonkey
總結
以上是生活随笔為你收集整理的在 Go 语言中 Patch 非导出函数的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度阅读之《Concurrency in
- 下一篇: 极端情况下收缩 Go 进程的线程数