Skip to content

Commit 43ec336

Browse files
committed
Initial commit
1 parent 7d9c989 commit 43ec336

7 files changed

+331
-0
lines changed

go.mod

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module github.com/undefinedlabs/go-mpatch
2+
3+
go 1.13

patcher.go

+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
package mpatch // import "github.com/undefinedlabs/go-mpatch"
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"reflect"
7+
"sync"
8+
"syscall"
9+
"unsafe"
10+
)
11+
12+
type (
13+
Patch struct {
14+
targetBytes []byte
15+
target *reflect.Value
16+
redirection *reflect.Value
17+
}
18+
pointer struct {
19+
length uintptr
20+
ptr uintptr
21+
}
22+
)
23+
24+
var (
25+
patchLock = sync.Mutex{}
26+
patches = make(map[uintptr]*Patch)
27+
pageSize = syscall.Getpagesize()
28+
)
29+
30+
func PatchMethod(target, redirection interface{}) (*Patch, error) {
31+
tValue := getValueFrom(target)
32+
rValue := getValueFrom(redirection)
33+
err := isPatchable(&tValue, &rValue)
34+
if err != nil {
35+
return nil, err
36+
}
37+
patch := &Patch{target: &tValue, redirection: &rValue}
38+
err = applyPatch(patch)
39+
if err != nil {
40+
return nil, err
41+
}
42+
return patch, nil
43+
}
44+
func PatchInstanceMethodByName(target reflect.Type, methodName string, redirection interface{}) (*Patch, error) {
45+
if target.Kind() == reflect.Struct {
46+
target = reflect.PtrTo(target)
47+
}
48+
method, ok := target.MethodByName(methodName)
49+
if !ok {
50+
return nil, errors.New(fmt.Sprintf("Method '%v' not found", methodName))
51+
}
52+
return PatchMethodByReflect(method, redirection)
53+
}
54+
func PatchMethodByReflect(target reflect.Method, redirection interface{}) (*Patch, error) {
55+
tValue := &target.Func
56+
rValue := getValueFrom(redirection)
57+
err := isPatchable(tValue, &rValue)
58+
if err != nil {
59+
return nil, err
60+
}
61+
patch := &Patch{target: tValue, redirection: &rValue}
62+
err = applyPatch(patch)
63+
if err != nil {
64+
return nil, err
65+
}
66+
return patch, nil
67+
}
68+
69+
func (p *Patch) Patch() error {
70+
if p == nil {
71+
return errors.New("patch is nil")
72+
}
73+
err := isPatchable(p.target, p.redirection)
74+
if err != nil {
75+
return err
76+
}
77+
err = applyPatch(p)
78+
if err != nil {
79+
return err
80+
}
81+
return nil
82+
}
83+
func (p *Patch) Unpatch() error {
84+
if p == nil {
85+
return errors.New("patch is nil")
86+
}
87+
return applyUnpatch(p)
88+
}
89+
90+
func isPatchable(target, redirection *reflect.Value) error {
91+
if target.Kind() != reflect.Func || redirection.Kind() != reflect.Func {
92+
return errors.New("the target and/or redirection is not a Func")
93+
}
94+
if target.Type() != redirection.Type() {
95+
return errors.New(fmt.Sprintf("the target and/or redirection doesn't have the same type: %s != %s", target.Type(), redirection.Type()))
96+
}
97+
if _, ok := patches[target.Pointer()]; ok {
98+
return errors.New("the target is already patched")
99+
}
100+
return nil
101+
}
102+
103+
func applyPatch(patch *Patch) error {
104+
patchLock.Lock()
105+
defer patchLock.Unlock()
106+
tPointer := patch.target.Pointer()
107+
rPointer := getInternalPtrFromValue(*patch.redirection)
108+
rPointerJumpBytes := getJumpFuncBytes(rPointer)
109+
tPointerBytes := getMemorySliceFromPointer(tPointer, len(rPointerJumpBytes))
110+
targetBytes := make([]byte, len(tPointerBytes))
111+
copy(targetBytes, tPointerBytes)
112+
err := copyDataToPtr(tPointer, rPointerJumpBytes)
113+
if err != nil {
114+
return err
115+
}
116+
patch.targetBytes = targetBytes
117+
patches[tPointer] = patch
118+
return nil
119+
}
120+
121+
func applyUnpatch(patch *Patch) error {
122+
patchLock.Lock()
123+
defer patchLock.Unlock()
124+
if patch.targetBytes == nil || len(patch.targetBytes) == 0 {
125+
return errors.New("the target is not patched")
126+
}
127+
tPointer := patch.target.Pointer()
128+
if _, ok := patches[tPointer]; !ok {
129+
return errors.New("the target is not patched")
130+
}
131+
delete(patches, tPointer)
132+
err := copyDataToPtr(tPointer, patch.targetBytes)
133+
if err != nil {
134+
return err
135+
}
136+
return nil
137+
}
138+
139+
func getInternalPtrFromValue(value reflect.Value) uintptr {
140+
return (*pointer)(unsafe.Pointer(&value)).ptr
141+
}
142+
143+
func getValueFrom(data interface{}) reflect.Value {
144+
if cValue, ok := data.(reflect.Value); ok {
145+
return cValue
146+
} else {
147+
return reflect.ValueOf(data)
148+
}
149+
}
150+
151+
func getMemorySliceFromPointer(p uintptr, length int) []byte {
152+
return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
153+
Data: p,
154+
Len: length,
155+
Cap: length,
156+
}))
157+
}
158+
159+
func getPageStartPtr(ptr uintptr) uintptr {
160+
return ptr & ^(uintptr(pageSize - 1))
161+
}

patcher_test.go

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package mpatch
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
)
7+
8+
//go:noinline
9+
func methodA() int { return 1 }
10+
11+
//go:noinline
12+
func methodB() int { return 2 }
13+
14+
type myStruct struct {
15+
}
16+
17+
//go:noinline
18+
func (s *myStruct) Method() int {
19+
return 1
20+
}
21+
22+
func TestPatcher(t *testing.T) {
23+
patch, err := PatchMethod(methodA, methodB)
24+
if err != nil {
25+
t.Fatal(err)
26+
}
27+
if methodA() != 2 {
28+
t.Fatal("The patch did not work")
29+
}
30+
31+
err = patch.Unpatch()
32+
if err != nil {
33+
t.Fatal(err)
34+
}
35+
if methodA() != 1 {
36+
t.Fatal("The unpatch did not work")
37+
}
38+
}
39+
40+
func TestInstancePatcher(t *testing.T) {
41+
mStruct := myStruct{}
42+
43+
var patch *Patch
44+
var err error
45+
patch, err = PatchInstanceMethodByName(reflect.TypeOf(mStruct), "Method", func(m *myStruct) int {
46+
patch.Unpatch()
47+
defer patch.Patch()
48+
return 41 + m.Method()
49+
})
50+
if err != nil {
51+
t.Fatal(err)
52+
}
53+
54+
if mStruct.Method() != 42 {
55+
t.Fatal("The patch did not work")
56+
}
57+
err = patch.Unpatch()
58+
if err != nil {
59+
t.Fatal(err)
60+
}
61+
if mStruct.Method() != 1 {
62+
t.Fatal("The unpatch did not work")
63+
}
64+
}

patcher_unix.go

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// +build !windows
2+
3+
package mpatch
4+
5+
import "syscall"
6+
7+
var writeAccess = syscall.PROT_READ | syscall.PROT_WRITE | syscall.PROT_EXEC
8+
var readAccess = syscall.PROT_READ | syscall.PROT_EXEC
9+
10+
func callMProtect(addr uintptr, length int, prot int) error {
11+
for p := getPageStartPtr(addr); p < addr+uintptr(length); p += uintptr(pageSize) {
12+
page := getMemorySliceFromPointer(p, pageSize)
13+
err := syscall.Mprotect(page, prot)
14+
if err != nil {
15+
return err
16+
}
17+
}
18+
return nil
19+
}
20+
21+
func copyDataToPtr(ptr uintptr, data []byte) error {
22+
dataLength := len(data)
23+
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
24+
err := callMProtect(ptr, dataLength, writeAccess)
25+
if err != nil {
26+
return err
27+
}
28+
copy(ptrByteSlice, data[:])
29+
err = callMProtect(ptr, dataLength, readAccess)
30+
if err != nil {
31+
return err
32+
}
33+
return nil
34+
}

patcher_windows.go

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// +build windows
2+
3+
package mpatch
4+
5+
import (
6+
"syscall"
7+
"unsafe"
8+
)
9+
10+
const pageExecuteReadAndWrite = 0x40
11+
12+
var virtualProtectProc = syscall.NewLazyDLL("kernel32.dll").NewProc("VirtualProtect")
13+
14+
func callVirtualProtect(lpAddress uintptr, dwSize int, flNewProtect uint32, lpflOldProtect unsafe.Pointer) error {
15+
ret, _, _ := virtualProtectProc.Call(lpAddress, uintptr(dwSize), uintptr(flNewProtect), uintptr(lpflOldProtect))
16+
if ret == 0 {
17+
return syscall.GetLastError()
18+
}
19+
return nil
20+
}
21+
22+
func copyDataToPtr(ptr uintptr, data []byte) error {
23+
var oldPerms, tmp uint32
24+
dataLength := len(data)
25+
ptrByteSlice := getMemorySliceFromPointer(ptr, len(data))
26+
err := callVirtualProtect(ptr, dataLength, pageExecuteReadAndWrite, unsafe.Pointer(&oldPerms))
27+
if err != nil {
28+
return err
29+
}
30+
copy(ptrByteSlice, data[:])
31+
err = callVirtualProtect(ptr, dataLength, oldPerms, unsafe.Pointer(&tmp))
32+
if err != nil {
33+
return err
34+
}
35+
}

patcher_x32.go

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// +build 386
2+
3+
package mpatch
4+
5+
// Gets the jump function rewrite bytes
6+
func getJumpFuncBytes(to uintptr) []byte {
7+
return []byte{
8+
0xBA,
9+
byte(to),
10+
byte(to >> 8),
11+
byte(to >> 16),
12+
byte(to >> 24),
13+
0xFF, 0x22,
14+
}
15+
}

patcher_x64.go

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// +build amd64
2+
3+
package mpatch
4+
5+
// Gets the jump function rewrite bytes
6+
func getJumpFuncBytes(to uintptr) []byte {
7+
return []byte{
8+
0x48, 0xBA,
9+
byte(to),
10+
byte(to >> 8),
11+
byte(to >> 16),
12+
byte(to >> 24),
13+
byte(to >> 32),
14+
byte(to >> 40),
15+
byte(to >> 48),
16+
byte(to >> 56),
17+
0xFF, 0x22,
18+
}
19+
}

0 commit comments

Comments
 (0)