Skip to content

Commit 18a76dc

Browse files
committed
Calling SyscallN directly when dealing with pointer-pointers
1 parent 1af1852 commit 18a76dc

File tree

4 files changed

+20
-167
lines changed

4 files changed

+20
-167
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ go 1.13
44

55
require (
66
github.com/stretchr/testify v1.8.1
7-
golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c
7+
golang.org/x/sys v0.8.0
88
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
1111
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
1212
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
1313
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
14-
golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c h1:Lyn7+CqXIiC+LOR9aHD6jDK+hPcmAuCfuXztd1v4w1Q=
15-
golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
14+
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
15+
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
1616
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
1717
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
1818
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

sys.go

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1+
//go:build windows
12
// +build windows
23

34
package wincred
45

56
import (
67
"reflect"
8+
"syscall"
79
"unsafe"
810

9-
syscall "golang.org/x/sys/windows"
11+
"golang.org/x/sys/windows"
1012
)
1113

1214
var (
13-
modadvapi32 = syscall.NewLazyDLL("advapi32.dll")
14-
procCredRead proc = modadvapi32.NewProc("CredReadW")
15+
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
16+
procCredRead = modadvapi32.NewProc("CredReadW")
1517
procCredWrite proc = modadvapi32.NewProc("CredWriteW")
1618
procCredDelete proc = modadvapi32.NewProc("CredDeleteW")
1719
procCredFree proc = modadvapi32.NewProc("CredFree")
18-
procCredEnumerate proc = modadvapi32.NewProc("CredEnumerateW")
20+
procCredEnumerate = modadvapi32.NewProc("CredEnumerateW")
1921
)
2022

2123
// Interface for syscall.Proc: helps testing
@@ -29,7 +31,7 @@ type sysCREDENTIAL struct {
2931
Type uint32
3032
TargetName *uint16
3133
Comment *uint16
32-
LastWritten syscall.Filetime
34+
LastWritten windows.Filetime
3335
CredentialBlobSize uint32
3436
CredentialBlob uintptr
3537
Persist uint32
@@ -59,15 +61,16 @@ const (
5961
sysCRED_TYPE_DOMAIN_EXTENDED sysCRED_TYPE = 0x6
6062

6163
// https://docs.microsoft.com/en-us/windows/desktop/Debug/system-error-codes
62-
sysERROR_NOT_FOUND = syscall.Errno(1168)
63-
sysERROR_INVALID_PARAMETER = syscall.Errno(87)
64+
sysERROR_NOT_FOUND = windows.Errno(1168)
65+
sysERROR_INVALID_PARAMETER = windows.Errno(87)
6466
)
6567

6668
// https://docs.microsoft.com/en-us/windows/desktop/api/wincred/nf-wincred-credreadw
6769
func sysCredRead(targetName string, typ sysCRED_TYPE) (*Credential, error) {
6870
var pcred *sysCREDENTIAL
69-
targetNamePtr, _ := syscall.UTF16PtrFromString(targetName)
70-
ret, _, err := procCredRead.Call(
71+
targetNamePtr, _ := windows.UTF16PtrFromString(targetName)
72+
ret, _, err := syscall.SyscallN(
73+
procCredRead.Addr(),
7174
uintptr(unsafe.Pointer(targetNamePtr)),
7275
uintptr(typ),
7376
0,
@@ -98,7 +101,7 @@ func sysCredWrite(cred *Credential, typ sysCRED_TYPE) error {
98101

99102
// https://docs.microsoft.com/en-us/windows/desktop/api/wincred/nf-wincred-creddeletew
100103
func sysCredDelete(cred *Credential, typ sysCRED_TYPE) error {
101-
targetNamePtr, _ := syscall.UTF16PtrFromString(cred.TargetName)
104+
targetNamePtr, _ := windows.UTF16PtrFromString(cred.TargetName)
102105
ret, _, err := procCredDelete.Call(
103106
uintptr(unsafe.Pointer(targetNamePtr)),
104107
uintptr(typ),
@@ -117,9 +120,10 @@ func sysCredEnumerate(filter string, all bool) ([]*Credential, error) {
117120
var pcreds uintptr
118121
var filterPtr *uint16
119122
if !all {
120-
filterPtr, _ = syscall.UTF16PtrFromString(filter)
123+
filterPtr, _ = windows.UTF16PtrFromString(filter)
121124
}
122-
ret, _, err := procCredEnumerate.Call(
125+
ret, _, err := syscall.SyscallN(
126+
procCredEnumerate.Addr(),
123127
uintptr(unsafe.Pointer(filterPtr)),
124128
0,
125129
uintptr(unsafe.Pointer(&count)),

sys_test.go

Lines changed: 1 addition & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
//go:build windows
12
// +build windows
23

34
package wincred
45

56
import (
67
"errors"
78
"testing"
8-
"unsafe"
99

1010
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/mock"
@@ -32,80 +32,6 @@ func (t *mockProc) Call(a ...uintptr) (r1, r2 uintptr, lastErr error) {
3232
return uintptr(args.Int(0)), uintptr(args.Int(1)), args.Error(2)
3333
}
3434

35-
func TestSysCredRead_MockFailure(t *testing.T) {
36-
// The test error
37-
testError := errors.New("test error")
38-
// Mock `CreadRead`: returns failure state and the error
39-
mockCredRead := new(mockProc)
40-
mockCredRead.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, testError)
41-
mockCredRead.Setup(&procCredRead)
42-
defer mockCredRead.TearDown()
43-
// Mock `CredFree`: Must not be called
44-
mockCredFree := new(mockProc)
45-
mockCredFree.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, nil)
46-
mockCredFree.Setup(&procCredFree)
47-
defer mockCredFree.TearDown()
48-
49-
// Test it:
50-
var res *Credential
51-
var err error
52-
assert.NotPanics(t, func() { res, err = sysCredRead("foo", sysCRED_TYPE_GENERIC) })
53-
assert.Nil(t, res)
54-
assert.NotNil(t, err)
55-
assert.Equal(t, "test error", err.Error())
56-
mockCredRead.AssertNumberOfCalls(t, "Call", 1)
57-
mockCredFree.AssertNumberOfCalls(t, "Call", 0)
58-
}
59-
60-
func TestSysCredRead_Mock(t *testing.T) {
61-
// prepare some test data
62-
cred := new(Credential)
63-
cred.TargetName = "Foo"
64-
cred.Comment = "Bar"
65-
cred.CredentialBlob = []byte{1, 2, 3}
66-
credSys := sysFromCredential(cred)
67-
t.Log(credSys) // Workaround to keep the object alive
68-
69-
// Mock `CreadRead`: returns success and sets the pointer to the prepared sysCred struct
70-
mockCredRead := new(mockProc)
71-
mockCredRead.
72-
On("Call", mock.AnythingOfType("[]uintptr")).
73-
Return(1, 0, nil).
74-
Run(func(args mock.Arguments) {
75-
arg := args.Get(0).([]uintptr)
76-
assert.Equal(t, 4, len(arg))
77-
*(**sysCREDENTIAL)(unsafe.Pointer(arg[3])) = credSys
78-
})
79-
mockCredRead.Setup(&procCredRead)
80-
defer mockCredRead.TearDown()
81-
82-
// Mock `CredFree`: Must be called as well with the correct pointer
83-
mockCredFree := new(mockProc)
84-
mockCredFree.
85-
On("Call", mock.AnythingOfType("[]uintptr")).
86-
Return(0, 0, nil).
87-
Run(func(args mock.Arguments) {
88-
arg := args.Get(0).([]uintptr)
89-
assert.Equal(t, 1, len(arg))
90-
assert.Equal(t, uintptr(unsafe.Pointer(credSys)), arg[0])
91-
})
92-
mockCredFree.Setup(&procCredFree)
93-
defer mockCredFree.TearDown()
94-
95-
// Test it:
96-
var res *Credential
97-
var err error
98-
assert.NotPanics(t, func() { res, err = sysCredRead("Foo", sysCRED_TYPE_GENERIC) })
99-
mockCredRead.AssertNumberOfCalls(t, "Call", 1)
100-
mockCredFree.AssertNumberOfCalls(t, "Call", 1)
101-
assert.NotNil(t, res)
102-
assert.Nil(t, err)
103-
assert.Equal(t, "Foo", res.TargetName)
104-
assert.Equal(t, "Bar", res.Comment)
105-
assert.Equal(t, []byte{1, 2, 3}, res.CredentialBlob)
106-
assert.NotEqual(t, &cred, &res)
107-
}
108-
10935
func TestSysCredWrite_MockFailure(t *testing.T) {
11036
// Mock `CreadWrite`: returns failure state and the error
11137
mockCredWrite := new(mockProc)
@@ -163,80 +89,3 @@ func TestSysCredDelete_Mock(t *testing.T) {
16389
assert.Nil(t, err)
16490
mockCredDelete.AssertNumberOfCalls(t, "Call", 1)
16591
}
166-
167-
func TestSysCredEnumerate_MockFailure(t *testing.T) {
168-
// The test error
169-
testError := errors.New("test error")
170-
// Mock `CreadEnumerate`: returns failure state and the error
171-
mockCredEnumerate := new(mockProc)
172-
mockCredEnumerate.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, testError)
173-
mockCredEnumerate.Setup(&procCredEnumerate)
174-
defer mockCredEnumerate.TearDown()
175-
// Mock `CredFree`: Must not be called
176-
mockCredFree := new(mockProc)
177-
mockCredFree.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, nil)
178-
mockCredFree.Setup(&procCredFree)
179-
defer mockCredFree.TearDown()
180-
181-
// Test it:
182-
var res []*Credential
183-
var err error
184-
assert.NotPanics(t, func() { res, err = sysCredEnumerate("", true) })
185-
assert.Nil(t, res)
186-
assert.NotNil(t, err)
187-
assert.Equal(t, "test error", err.Error())
188-
mockCredEnumerate.AssertNumberOfCalls(t, "Call", 1)
189-
mockCredFree.AssertNumberOfCalls(t, "Call", 0)
190-
}
191-
192-
func TestSysCredEnumerate_Mock(t *testing.T) {
193-
// prepare some test data
194-
creds := []*Credential{new(Credential), new(Credential)}
195-
creds[0].TargetName = "Foo"
196-
creds[1].TargetName = "Bar"
197-
credsSys := [](*sysCREDENTIAL){
198-
sysFromCredential(creds[0]),
199-
sysFromCredential(creds[1]),
200-
}
201-
t.Log(credsSys[0]) // Workaround to keep the object alive
202-
t.Log(credsSys[1]) // Workaround to keep the object alive
203-
204-
// Mock `CreadEnumerate`: returns success and sets the pointer to the prepared sysCreds array
205-
mockCredEnumerate := new(mockProc)
206-
mockCredEnumerate.
207-
On("Call", mock.AnythingOfType("[]uintptr")).
208-
Return(1, 0, nil).
209-
Run(func(args mock.Arguments) {
210-
arg := args.Get(0).([]uintptr)
211-
assert.Equal(t, 4, len(arg))
212-
*(*int)(unsafe.Pointer(arg[2])) = len(credsSys)
213-
*(*[]*sysCREDENTIAL)(unsafe.Pointer(arg[3])) = credsSys
214-
})
215-
mockCredEnumerate.Setup(&procCredEnumerate)
216-
defer mockCredEnumerate.TearDown()
217-
218-
// Mock `CredFree`: Must be called as well with the correct pointer
219-
mockCredFree := new(mockProc)
220-
mockCredFree.
221-
On("Call", mock.AnythingOfType("[]uintptr")).
222-
Return(0, 0, nil).
223-
Run(func(args mock.Arguments) {
224-
arg := args.Get(0).([]uintptr)
225-
assert.Equal(t, 1, len(arg))
226-
assert.Equal(t, uintptr(unsafe.Pointer(&credsSys[0])), arg[0])
227-
})
228-
mockCredFree.Setup(&procCredFree)
229-
defer mockCredFree.TearDown()
230-
231-
// Test it:
232-
var res []*Credential
233-
var err error
234-
assert.NotPanics(t, func() { res, err = sysCredEnumerate("", true) })
235-
mockCredEnumerate.AssertNumberOfCalls(t, "Call", 1)
236-
mockCredFree.AssertNumberOfCalls(t, "Call", 1)
237-
assert.NotNil(t, res)
238-
assert.Nil(t, err)
239-
assert.Equal(t, 2, len(res))
240-
assert.Equal(t, "Foo", res[0].TargetName)
241-
assert.Equal(t, "Bar", res[1].TargetName)
242-
}

0 commit comments

Comments
 (0)