Skip to content

Commit 0243978

Browse files
committed
global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent 851efb1 commit 0243978

22 files changed

+239
-280
lines changed

Diff for: conn/bind_linux.go

+19-31
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"unsafe"
1515

1616
"golang.org/x/sys/unix"
17+
"golang.zx2c4.com/go118/netip"
1718
)
1819

1920
type ipv4Source struct {
@@ -70,32 +71,30 @@ var _ Bind = (*LinuxSocketBind)(nil)
7071

7172
func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
7273
var end LinuxSocketEndpoint
73-
addr, err := parseEndpoint(s)
74+
e, err := netip.ParseAddrPort(s)
7475
if err != nil {
7576
return nil, err
7677
}
7778

78-
ipv4 := addr.IP.To4()
79-
if ipv4 != nil {
79+
if e.Addr().Is4() {
8080
dst := end.dst4()
8181
end.isV6 = false
82-
dst.Port = addr.Port
83-
copy(dst.Addr[:], ipv4)
82+
dst.Port = int(e.Port())
83+
dst.Addr = e.Addr().As4()
8484
end.ClearSrc()
8585
return &end, nil
8686
}
8787

88-
ipv6 := addr.IP.To16()
89-
if ipv6 != nil {
90-
zone, err := zoneToUint32(addr.Zone)
88+
if e.Addr().Is6() {
89+
zone, err := zoneToUint32(e.Addr().Zone())
9190
if err != nil {
9291
return nil, err
9392
}
9493
dst := end.dst6()
9594
end.isV6 = true
96-
dst.Port = addr.Port
95+
dst.Port = int(e.Port())
9796
dst.ZoneId = zone
98-
copy(dst.Addr[:], ipv6[:])
97+
dst.Addr = e.Addr().As16()
9998
end.ClearSrc()
10099
return &end, nil
101100
}
@@ -266,29 +265,19 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
266265
}
267266
}
268267

269-
func (end *LinuxSocketEndpoint) SrcIP() net.IP {
268+
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
270269
if !end.isV6 {
271-
return net.IPv4(
272-
end.src4().Src[0],
273-
end.src4().Src[1],
274-
end.src4().Src[2],
275-
end.src4().Src[3],
276-
)
270+
return netip.AddrFrom4(end.src4().Src)
277271
} else {
278-
return end.src6().src[:]
272+
return netip.AddrFrom16(end.src6().src)
279273
}
280274
}
281275

282-
func (end *LinuxSocketEndpoint) DstIP() net.IP {
276+
func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
283277
if !end.isV6 {
284-
return net.IPv4(
285-
end.dst4().Addr[0],
286-
end.dst4().Addr[1],
287-
end.dst4().Addr[2],
288-
end.dst4().Addr[3],
289-
)
278+
return netip.AddrFrom4(end.src4().Src)
290279
} else {
291-
return end.dst6().Addr[:]
280+
return netip.AddrFrom16(end.dst6().Addr)
292281
}
293282
}
294283

@@ -305,14 +294,13 @@ func (end *LinuxSocketEndpoint) SrcToString() string {
305294
}
306295

307296
func (end *LinuxSocketEndpoint) DstToString() string {
308-
var udpAddr net.UDPAddr
309-
udpAddr.IP = end.DstIP()
297+
var port int
310298
if !end.isV6 {
311-
udpAddr.Port = end.dst4().Port
299+
port = end.dst4().Port
312300
} else {
313-
udpAddr.Port = end.dst6().Port
301+
port = end.dst6().Port
314302
}
315-
return udpAddr.String()
303+
return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
316304
}
317305

318306
func (end *LinuxSocketEndpoint) ClearDst() {

Diff for: conn/bind_std.go

+12-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"net"
1111
"sync"
1212
"syscall"
13+
14+
"golang.zx2c4.com/go118/netip"
1315
)
1416

1517
// StdNetBind is meant to be a temporary solution on platforms for which
@@ -32,18 +34,22 @@ var _ Bind = (*StdNetBind)(nil)
3234
var _ Endpoint = (*StdNetEndpoint)(nil)
3335

3436
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
35-
addr, err := parseEndpoint(s)
36-
return (*StdNetEndpoint)(addr), err
37+
e, err := netip.ParseAddrPort(s)
38+
return (*StdNetEndpoint)(&net.UDPAddr{
39+
IP: e.Addr().AsSlice(),
40+
Port: int(e.Port()),
41+
Zone: e.Addr().Zone(),
42+
}), err
3743
}
3844

3945
func (*StdNetEndpoint) ClearSrc() {}
4046

41-
func (e *StdNetEndpoint) DstIP() net.IP {
42-
return (*net.UDPAddr)(e).IP
47+
func (e *StdNetEndpoint) DstIP() netip.Addr {
48+
return netip.AddrFromSlice((*net.UDPAddr)(e).IP)
4349
}
4450

45-
func (e *StdNetEndpoint) SrcIP() net.IP {
46-
return nil // not supported
51+
func (e *StdNetEndpoint) SrcIP() netip.Addr {
52+
return netip.Addr{} // not supported
4753
}
4854

4955
func (e *StdNetEndpoint) DstToBytes() []byte {

Diff for: conn/bind_windows.go

+9-10
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"unsafe"
1616

1717
"golang.org/x/sys/windows"
18+
"golang.zx2c4.com/go118/netip"
1819

1920
"golang.zx2c4.com/wireguard/conn/winrio"
2021
)
@@ -128,18 +129,18 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
128129

129130
func (*WinRingEndpoint) ClearSrc() {}
130131

131-
func (e *WinRingEndpoint) DstIP() net.IP {
132+
func (e *WinRingEndpoint) DstIP() netip.Addr {
132133
switch e.family {
133134
case windows.AF_INET:
134-
return append([]byte{}, e.data[2:6]...)
135+
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
135136
case windows.AF_INET6:
136-
return append([]byte{}, e.data[6:22]...)
137+
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
137138
}
138-
return nil
139+
return netip.Addr{}
139140
}
140141

141-
func (e *WinRingEndpoint) SrcIP() net.IP {
142-
return nil // not supported
142+
func (e *WinRingEndpoint) SrcIP() netip.Addr {
143+
return netip.Addr{} // not supported
143144
}
144145

145146
func (e *WinRingEndpoint) DstToBytes() []byte {
@@ -161,15 +162,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte {
161162
func (e *WinRingEndpoint) DstToString() string {
162163
switch e.family {
163164
case windows.AF_INET:
164-
addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
165-
return addr.String()
165+
netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
166166
case windows.AF_INET6:
167167
var zone string
168168
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
169169
zone = strconv.FormatUint(uint64(scope), 10)
170170
}
171-
addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))}
172-
return addr.String()
171+
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
173172
}
174173
return ""
175174
}

Diff for: conn/bindtest/bindtest.go

+5-9
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import (
1010
"math/rand"
1111
"net"
1212
"os"
13-
"strconv"
1413

14+
"golang.zx2c4.com/go118/netip"
1515
"golang.zx2c4.com/wireguard/conn"
1616
)
1717

@@ -61,9 +61,9 @@ func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d
6161

6262
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
6363

64-
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
64+
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
6565

66-
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
66+
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
6767

6868
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
6969
c.closeSignal = make(chan bool)
@@ -119,13 +119,9 @@ func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
119119
}
120120

121121
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
122-
_, port, err := net.SplitHostPort(s)
122+
addr, err := netip.ParseAddrPort(s)
123123
if err != nil {
124124
return nil, err
125125
}
126-
i, err := strconv.ParseUint(port, 10, 16)
127-
if err != nil {
128-
return nil, err
129-
}
130-
return ChannelEndpoint(i), nil
126+
return ChannelEndpoint(addr.Port()), nil
131127
}

Diff for: conn/conn.go

+4-33
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ package conn
99
import (
1010
"errors"
1111
"fmt"
12-
"net"
1312
"reflect"
1413
"runtime"
1514
"strings"
15+
16+
"golang.zx2c4.com/go118/netip"
1617
)
1718

1819
// A ReceiveFunc receives a single inbound packet from the network.
@@ -68,8 +69,8 @@ type Endpoint interface {
6869
SrcToString() string // returns the local source address (ip:port)
6970
DstToString() string // returns the destination address (ip:port)
7071
DstToBytes() []byte // used for mac2 cookie calculations
71-
DstIP() net.IP
72-
SrcIP() net.IP
72+
DstIP() netip.Addr
73+
SrcIP() netip.Addr
7374
}
7475

7576
var (
@@ -119,33 +120,3 @@ func (fn ReceiveFunc) PrettyName() string {
119120
}
120121
return name
121122
}
122-
123-
func parseEndpoint(s string) (*net.UDPAddr, error) {
124-
// ensure that the host is an IP address
125-
126-
host, _, err := net.SplitHostPort(s)
127-
if err != nil {
128-
return nil, err
129-
}
130-
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
131-
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
132-
// trying to make sure with a small sanity test that this is a real IP address and
133-
// not something that's likely to incur DNS lookups.
134-
host = host[:i]
135-
}
136-
if ip := net.ParseIP(host); ip == nil {
137-
return nil, errors.New("Failed to parse IP address: " + host)
138-
}
139-
140-
// parse address and port
141-
142-
addr, err := net.ResolveUDPAddr("udp", s)
143-
if err != nil {
144-
return nil, err
145-
}
146-
ip4 := addr.IP.To4()
147-
if ip4 != nil {
148-
addr.IP = ip4
149-
}
150-
return addr, err
151-
}

0 commit comments

Comments
 (0)