Skip to content

Commit

Permalink
Update wireguard-go
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Jun 7, 2023
1 parent 9c9affa commit b6068ce
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 163 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ require (
github.com/sagernet/tfo-go v0.0.0-20230303015439-ffcfd8c41cf9
github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2
github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e
github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c
github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.3
go.etcd.io/bbolt v1.3.7
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2 h1:kDUqhc9Vsk5HJuhfI
github.com/sagernet/utls v0.0.0-20230309024959-6732c2ab36f2/go.mod h1:JKQMZq/O2qnZjdrt+B57olmfgEmLtY9iiSIEYtWvoSM=
github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e h1:7uw2njHFGE+VpWamge6o56j2RWk4omF6uLKKxMmcWvs=
github.com/sagernet/websocket v0.0.0-20220913015213-615516348b4e/go.mod h1:45TUl8+gH4SIKr4ykREbxKWTxkDlSzFENzctB1dVRRY=
github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c h1:vK2wyt9aWYHHvNLWniwijBu/n4pySypiKRhN32u/JGo=
github.com/sagernet/wireguard-go v0.0.0-20221116151939-c99467f53f2c/go.mod h1:euOmN6O5kk9dQmgSS8Df4psAl3TCjxOz0NW60EWkSaI=
github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77 h1:g6QtRWQ2dKX7EQP++1JLNtw4C2TNxd4/ov8YUpOPOSo=
github.com/sagernet/wireguard-go v0.0.0-20230420044414-a7bac1754e77/go.mod h1:pJDdXzZIwJ+2vmnT0TKzmf8meeum+e2mTDSehw79eE0=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
Expand Down
43 changes: 27 additions & 16 deletions transport/wireguard/client_bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint1
return []conn.ReceiveFunc{c.receive}, 0, nil
}

func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
udpConn, err := c.connect()
if err != nil {
select {
Expand All @@ -113,22 +113,26 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) {
err = nil
return
}
n, addr, err := udpConn.ReadFrom(b)
n, addr, err := udpConn.ReadFrom(packets[0])
if err != nil {
udpConn.Close()
select {
case <-c.done:
default:
c.errorHandler.NewError(context.Background(), E.Cause(err, "read packet"))
err = nil
}
return
}
sizes[0] = n
if n > 3 {
b := packets[0]
b[1] = 0
b[2] = 0
b[3] = 0
}
ep = Endpoint(M.SocksaddrFromNet(addr))
eps[0] = Endpoint(M.SocksaddrFromNet(addr))
count = 1
return
}

Expand All @@ -155,32 +159,39 @@ func (c *ClientBind) SetMark(mark uint32) error {
return nil
}

func (c *ClientBind) Send(b []byte, ep conn.Endpoint) error {
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
udpConn, err := c.connect()
if err != nil {
return err
}
destination := M.Socksaddr(ep.(Endpoint))
if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination]
if !loaded {
reserved = c.reserved
for _, b := range bufs {
if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination]
if !loaded {
reserved = c.reserved
}
b[1] = reserved[0]
b[2] = reserved[1]
b[3] = reserved[2]
}
_, err = udpConn.WriteTo(b, destination)
if err != nil {
udpConn.Close()
return err
}
b[1] = reserved[0]
b[2] = reserved[1]
b[3] = reserved[2]
}
_, err = udpConn.WriteTo(b, destination)
if err != nil {
udpConn.Close()
}
return err
return nil
}

func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
return Endpoint(M.ParseSocksaddr(s)), nil
}

func (c *ClientBind) BatchSize() int {
return 1
}

type wireConn struct {
net.PacketConn
access sync.Mutex
Expand Down
94 changes: 60 additions & 34 deletions transport/wireguard/device_stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (
"net/netip"
"os"

"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/wireguard-go/tun"
wgTun "github.com/sagernet/wireguard-go/tun"

"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
Expand All @@ -30,14 +31,15 @@ var _ Device = (*StackDevice)(nil)
const defaultNIC tcpip.NICID = 1

type StackDevice struct {
stack *stack.Stack
mtu uint32
events chan tun.Event
outbound chan *stack.PacketBuffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
stack *stack.Stack
mtu uint32
events chan wgTun.Event
outbound chan *stack.PacketBuffer
packetOutbound chan *buf.Buffer
done chan struct{}
dispatcher stack.NetworkDispatcher
addr4 tcpip.Address
addr6 tcpip.Address
}

func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
Expand All @@ -47,11 +49,12 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er
HandleLocal: true,
})
tunDevice := &StackDevice{
stack: ipStack,
mtu: mtu,
events: make(chan tun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
stack: ipStack,
mtu: mtu,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
packetOutbound: make(chan *buf.Buffer, 256),
done: make(chan struct{}),
}
err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
if err != nil {
Expand Down Expand Up @@ -144,50 +147,69 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr)
return udpConn, nil
}

func (w *StackDevice) Inet4Address() netip.Addr {
return M.AddrFromIP(net.IP(w.addr4))
}

func (w *StackDevice) Inet6Address() netip.Addr {
return M.AddrFromIP(net.IP(w.addr6))
}

func (w *StackDevice) Start() error {
w.events <- tun.EventUp
w.events <- wgTun.EventUp
return nil
}

func (w *StackDevice) File() *os.File {
return nil
}

func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
select {
case packetBuffer, ok := <-w.outbound:
if !ok {
return 0, os.ErrClosed
}
defer packetBuffer.DecRef()
p := bufs[0]
p = p[offset:]
n := 0
for _, slice := range packetBuffer.AsSlices() {
n += copy(p[n:], slice)
}
sizes[0] = n
count = 1
return
case packet := <-w.packetOutbound:
defer packet.Release()
sizes[0] = copy(bufs[0][offset:], packet.Bytes())
count = 1
return
case <-w.done:
return 0, os.ErrClosed
}
}

func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
p = p[offset:]
if len(p) == 0 {
return
}
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(p) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
for _, b := range bufs {
b = b[offset:]
if len(b) == 0 {
continue
}
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(b) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData(b),
})
w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
packetBuffer.DecRef()
count++
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData(p),
})
defer packetBuffer.DecRef()
w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
n = len(p)
return
}

Expand All @@ -203,7 +225,7 @@ func (w *StackDevice) Name() (string, error) {
return "sing-box", nil
}

func (w *StackDevice) Events() chan tun.Event {
func (w *StackDevice) Events() <-chan wgTun.Event {
return w.events
}

Expand All @@ -222,6 +244,10 @@ func (w *StackDevice) Close() error {
return nil
}

func (w *StackDevice) BatchSize() int {
return 1
}

var _ stack.LinkEndpoint = (*wireEndpoint)(nil)

type wireEndpoint StackDevice
Expand Down
60 changes: 45 additions & 15 deletions transport/wireguard/device_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,10 @@ type SystemDevice struct {
name string
mtu int
events chan wgTun.Event
addr4 netip.Addr
addr6 netip.Addr
}

/*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) {
gTun, isGTun := w.device.(tun.GVisorTun)
if !isGTun {
return nil, tun.ErrGVisorUnsupported
}
return gTun.NewEndpoint()
}*/

func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32) (*SystemDevice, error) {
var inet4Addresses []netip.Prefix
var inet6Addresses []netip.Prefix
Expand All @@ -55,11 +49,24 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes
if err != nil {
return nil, err
}
var inet4Address netip.Addr
var inet6Address netip.Addr
if len(inet4Addresses) > 0 {
inet4Address = inet4Addresses[0].Addr()
}
if len(inet6Addresses) > 0 {
inet6Address = inet6Addresses[0].Addr()
}
return &SystemDevice{
dialer.NewDefault(router, option.DialerOptions{
dialer: dialer.NewDefault(router, option.DialerOptions{
BindInterface: interfaceName,
}),
tunInterface, interfaceName, int(mtu), make(chan wgTun.Event),
device: tunInterface,
name: interfaceName,
mtu: int(mtu),
events: make(chan wgTun.Event),
addr4: inet4Address,
addr6: inet6Address,
}, nil
}

Expand All @@ -71,6 +78,14 @@ func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr
return w.dialer.ListenPacket(ctx, destination)
}

func (w *SystemDevice) Inet4Address() netip.Addr {
return w.addr4
}

func (w *SystemDevice) Inet6Address() netip.Addr {
return w.addr6
}

func (w *SystemDevice) Start() error {
w.events <- wgTun.EventUp
return nil
Expand All @@ -80,12 +95,23 @@ func (w *SystemDevice) File() *os.File {
return nil
}

func (w *SystemDevice) Read(bytes []byte, index int) (int, error) {
return w.device.Read(bytes[index-tun.PacketOffset:])
func (w *SystemDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
sizes[0], err = w.device.Read(bufs[0][offset-tun.PacketOffset:])
if err == nil {
count = 1
}
return
}

func (w *SystemDevice) Write(bytes []byte, index int) (int, error) {
return w.device.Write(bytes[index:])
func (w *SystemDevice) Write(bufs [][]byte, offset int) (count int, err error) {
for _, b := range bufs {
_, err = w.device.Write(b[offset:])
if err != nil {
return
}
count++
}
return
}

func (w *SystemDevice) Flush() error {
Expand All @@ -100,10 +126,14 @@ func (w *SystemDevice) Name() (string, error) {
return w.name, nil
}

func (w *SystemDevice) Events() chan wgTun.Event {
func (w *SystemDevice) Events() <-chan wgTun.Event {
return w.events
}

func (w *SystemDevice) Close() error {
return w.device.Close()
}

func (w *SystemDevice) BatchSize() int {
return 1
}
Loading

0 comments on commit b6068ce

Please sign in to comment.