Skip to content

Commit b51010b

Browse files
bradfitzzx2c4
authored andcommitted
all: use Go 1.19 and its atomic types
Signed-off-by: Brad Fitzpatrick <[email protected]> Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent d1d0842 commit b51010b

20 files changed

+156
-288
lines changed

Diff for: conn/bind_windows.go

+16-16
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ type afWinRingBind struct {
7474
type WinRingBind struct {
7575
v4, v6 afWinRingBind
7676
mu sync.RWMutex
77-
isOpen uint32
77+
isOpen atomic.Uint32 // 0, 1, or 2
7878
}
7979

8080
func NewDefaultBind() Bind { return NewWinRingBind() }
@@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() {
212212
}
213213

214214
func (bind *WinRingBind) closeAndZero() {
215-
atomic.StoreUint32(&bind.isOpen, 0)
215+
bind.isOpen.Store(0)
216216
bind.v4.CloseAndZero()
217217
bind.v6.CloseAndZero()
218218
}
@@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
276276
bind.closeAndZero()
277277
}
278278
}()
279-
if atomic.LoadUint32(&bind.isOpen) != 0 {
279+
if bind.isOpen.Load() != 0 {
280280
return nil, 0, ErrBindAlreadyOpen
281281
}
282282
var sa windows.Sockaddr
@@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort
299299
return nil, 0, err
300300
}
301301
}
302-
atomic.StoreUint32(&bind.isOpen, 1)
302+
bind.isOpen.Store(1)
303303
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
304304
}
305305

306306
func (bind *WinRingBind) Close() error {
307307
bind.mu.RLock()
308-
if atomic.LoadUint32(&bind.isOpen) != 1 {
308+
if bind.isOpen.Load() != 1 {
309309
bind.mu.RUnlock()
310310
return nil
311311
}
312-
atomic.StoreUint32(&bind.isOpen, 2)
312+
bind.isOpen.Store(2)
313313
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
314314
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
315315
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
@@ -345,8 +345,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error {
345345
//go:linkname procyield runtime.procyield
346346
func procyield(cycles uint32)
347347

348-
func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) {
349-
if atomic.LoadUint32(isOpen) != 1 {
348+
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
349+
if isOpen.Load() != 1 {
350350
return 0, nil, net.ErrClosed
351351
}
352352
bind.rx.mu.Lock()
@@ -359,7 +359,7 @@ retry:
359359
count = 0
360360
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
361361
if tries > 0 {
362-
if atomic.LoadUint32(isOpen) != 1 {
362+
if isOpen.Load() != 1 {
363363
return 0, nil, net.ErrClosed
364364
}
365365
procyield(1)
@@ -378,7 +378,7 @@ retry:
378378
if err != nil {
379379
return 0, nil, err
380380
}
381-
if atomic.LoadUint32(isOpen) != 1 {
381+
if isOpen.Load() != 1 {
382382
return 0, nil, net.ErrClosed
383383
}
384384
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
@@ -395,7 +395,7 @@ retry:
395395
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
396396
// attacker bandwidth, just like the rest of the receive path.
397397
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
398-
if atomic.LoadUint32(isOpen) != 1 {
398+
if isOpen.Load() != 1 {
399399
return 0, nil, net.ErrClosed
400400
}
401401
goto retry
@@ -421,8 +421,8 @@ func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
421421
return bind.v6.Receive(buf, &bind.isOpen)
422422
}
423423

424-
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error {
425-
if atomic.LoadUint32(isOpen) != 1 {
424+
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
425+
if isOpen.Load() != 1 {
426426
return net.ErrClosed
427427
}
428428
if len(buf) > bytesPerPacket {
@@ -444,7 +444,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3
444444
if err != nil {
445445
return err
446446
}
447-
if atomic.LoadUint32(isOpen) != 1 {
447+
if isOpen.Load() != 1 {
448448
return net.ErrClosed
449449
}
450450
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
@@ -538,7 +538,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
538538
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
539539
bind.mu.RLock()
540540
defer bind.mu.RUnlock()
541-
if atomic.LoadUint32(&bind.isOpen) != 1 {
541+
if bind.isOpen.Load() != 1 {
542542
return net.ErrClosed
543543
}
544544
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
@@ -552,7 +552,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
552552
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
553553
bind.mu.RLock()
554554
defer bind.mu.RUnlock()
555-
if atomic.LoadUint32(&bind.isOpen) != 1 {
555+
if bind.isOpen.Load() != 1 {
556556
return net.ErrClosed
557557
}
558558
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)

Diff for: device/alignment_test.go

-65
This file was deleted.

Diff for: device/device.go

+15-17
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type Device struct {
3030
// will become the actual state; Up can fail.
3131
// The device can also change state multiple times between time of check and time of use.
3232
// Unsynchronized uses of state must therefore be advisory/best-effort only.
33-
state uint32 // actually a deviceState, but typed uint32 for convenience
33+
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
3434
// stopping blocks until all inputs to Device have been closed.
3535
stopping sync.WaitGroup
3636
// mu protects state changes.
@@ -58,9 +58,8 @@ type Device struct {
5858
keyMap map[NoisePublicKey]*Peer
5959
}
6060

61-
// Keep this 8-byte aligned
6261
rate struct {
63-
underLoadUntil int64
62+
underLoadUntil atomic.Int64
6463
limiter ratelimiter.Ratelimiter
6564
}
6665

@@ -82,7 +81,7 @@ type Device struct {
8281

8382
tun struct {
8483
device tun.Device
85-
mtu int32
84+
mtu atomic.Int32
8685
}
8786

8887
ipcMutex sync.RWMutex
@@ -94,10 +93,9 @@ type Device struct {
9493
// There are three states: down, up, closed.
9594
// Transitions:
9695
//
97-
// down -----+
98-
// ↑↓ ↓
99-
// up -> closed
100-
//
96+
// down -----+
97+
// ↑↓ ↓
98+
// up -> closed
10199
type deviceState uint32
102100

103101
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@@ -110,7 +108,7 @@ const (
110108
// deviceState returns device.state.state as a deviceState
111109
// See those docs for how to interpret this value.
112110
func (device *Device) deviceState() deviceState {
113-
return deviceState(atomic.LoadUint32(&device.state.state))
111+
return deviceState(device.state.state.Load())
114112
}
115113

116114
// isClosed reports whether the device is closed (or is closing).
@@ -149,14 +147,14 @@ func (device *Device) changeState(want deviceState) (err error) {
149147
case old:
150148
return nil
151149
case deviceStateUp:
152-
atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
150+
device.state.state.Store(uint32(deviceStateUp))
153151
err = device.upLocked()
154152
if err == nil {
155153
break
156154
}
157155
fallthrough // up failed; bring the device all the way back down
158156
case deviceStateDown:
159-
atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
157+
device.state.state.Store(uint32(deviceStateDown))
160158
errDown := device.downLocked()
161159
if err == nil {
162160
err = errDown
@@ -182,7 +180,7 @@ func (device *Device) upLocked() error {
182180
device.peers.RLock()
183181
for _, peer := range device.peers.keyMap {
184182
peer.Start()
185-
if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
183+
if peer.persistentKeepaliveInterval.Load() > 0 {
186184
peer.SendKeepalive()
187185
}
188186
}
@@ -219,11 +217,11 @@ func (device *Device) IsUnderLoad() bool {
219217
now := time.Now()
220218
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
221219
if underLoad {
222-
atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
220+
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
223221
return true
224222
}
225223
// check if recently under load
226-
return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
224+
return device.rate.underLoadUntil.Load() > now.UnixNano()
227225
}
228226

229227
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -283,7 +281,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
283281

284282
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
285283
device := new(Device)
286-
device.state.state = uint32(deviceStateDown)
284+
device.state.state.Store(uint32(deviceStateDown))
287285
device.closed = make(chan struct{})
288286
device.log = logger
289287
device.net.bind = bind
@@ -293,7 +291,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
293291
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
294292
mtu = DefaultMTU
295293
}
296-
device.tun.mtu = int32(mtu)
294+
device.tun.mtu.Store(int32(mtu))
297295
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
298296
device.rate.limiter.Init()
299297
device.indexTable.Init()
@@ -359,7 +357,7 @@ func (device *Device) Close() {
359357
if device.isClosed() {
360358
return
361359
}
362-
atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
360+
device.state.state.Store(uint32(deviceStateClosed))
363361
device.log.Verbosef("Device closing")
364362

365363
device.tun.device.Close()

Diff for: device/device_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ func BenchmarkThroughput(b *testing.B) {
333333

334334
// Measure how long it takes to receive b.N packets,
335335
// starting when we receive the first packet.
336-
var recv uint64
336+
var recv atomic.Uint64
337337
var elapsed time.Duration
338338
var wg sync.WaitGroup
339339
wg.Add(1)
@@ -342,7 +342,7 @@ func BenchmarkThroughput(b *testing.B) {
342342
var start time.Time
343343
for {
344344
<-pair[0].tun.Inbound
345-
new := atomic.AddUint64(&recv, 1)
345+
new := recv.Add(1)
346346
if new == 1 {
347347
start = time.Now()
348348
}
@@ -358,7 +358,7 @@ func BenchmarkThroughput(b *testing.B) {
358358
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
359359
pingc := pair[1].tun.Outbound
360360
var sent uint64
361-
for atomic.LoadUint64(&recv) != uint64(b.N) {
361+
for recv.Load() != uint64(b.N) {
362362
sent++
363363
pingc <- ping
364364
}

Diff for: device/keypair.go

+2-11
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"sync"
1111
"sync/atomic"
1212
"time"
13-
"unsafe"
1413

1514
"golang.zx2c4.com/wireguard/replay"
1615
)
@@ -23,7 +22,7 @@ import (
2322
*/
2423

2524
type Keypair struct {
26-
sendNonce uint64 // accessed atomically
25+
sendNonce atomic.Uint64
2726
send cipher.AEAD
2827
receive cipher.AEAD
2928
replayFilter replay.Filter
@@ -37,15 +36,7 @@ type Keypairs struct {
3736
sync.RWMutex
3837
current *Keypair
3938
previous *Keypair
40-
next *Keypair
41-
}
42-
43-
func (kp *Keypairs) storeNext(next *Keypair) {
44-
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
45-
}
46-
47-
func (kp *Keypairs) loadNext() *Keypair {
48-
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
39+
next atomic.Pointer[Keypair]
4940
}
5041

5142
func (kp *Keypairs) Current() *Keypair {

0 commit comments

Comments
 (0)