Skip to content

Commit

Permalink
[WIP] refactor
Browse files Browse the repository at this point in the history
Signed-off-by: He Xian <[email protected]>
  • Loading branch information
hexian000 committed Oct 7, 2024
1 parent 3e15c8c commit 38fd54e
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 208 deletions.
1 change: 1 addition & 0 deletions make.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ case "$1" in
"d")
set -x
go fmt ./...
GCFLAGS="${GCFLAGS} -N -l"
CGO_ENABLED=0 \
nice go build ${GOFLAGS} -gcflags "${GCFLAGS}" -ldflags "${LDFLAGS}" \
-o "${OUT}" "${PACKAGE}"
Expand Down
102 changes: 63 additions & 39 deletions v3/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,36 @@ import (
"github.com/hexian000/gosnippets/slog"
)

// TunnelConfig represents a fixed tunnel between 2 peers
type TunnelConfig struct {
// local service identity
Service string `json:"service"`
// remote service
RemoteService string `json:"remoteservice"`
// (optional) mux listen address
MuxListen string `json:"muxlisten,omitempty"`
// (optional) is disabled
Disabled bool `json:"disabled,omitempty"`
// (optional) mux dial address
MuxDial string `json:"muxdial,omitempty"`
// (optional) remote service listen address
MuxDial string `json:"addr,omitempty"`
// (optional) local listener address
Listen string `json:"listen,omitempty"`
// (optional) service dial address
Dial string `json:"dial,omitempty"`
// remote service name
PeerService string `json:"peerservice"`
// (optional) keep tunnels connected
Redial bool `json:"redial"`
// (optional) client-side keep alive interval in seconds, default to 25 (every 25s)
KeepAlive int `json:"keepalive"`
// (optional) mux accept backlog, default to 256, you may not want to change this
AcceptBacklog int `json:"backlog"`
// (optional) stream window size in bytes, default to 256 KiB, increase this on long fat networks
StreamWindow uint32 `json:"window"`
}

// Config file
type Config struct {
// tunnel configs
Tunnels []TunnelConfig `json:"tunnel"`
// (optional) keep tunnels connected
Redial bool `json:"redial"`
// (optional) local peer name
PeerName string `json:"peername,omitempty"`
// (optional) mux listen address
MuxListen string `json:"muxlisten,omitempty"`
// service name to dial address
Services map[string]string `json:"services"`
// peer name to config
Peers map[string]TunnelConfig `json:"peers"`
// (optional) health check and metrics, default to "" (disabled)
HTTPListen string `json:"httplisten,omitempty"`
// TLS: (optional) SNI field in handshake, default to "example.com"
Expand All @@ -49,10 +58,6 @@ type Config struct {
AuthorizedCerts []string `json:"authcerts"`
// (optional) TCP no delay, default to true
NoDelay bool `json:"nodelay"`
// (optional) client-side keep alive interval in seconds, default to 25 (every 25s)
KeepAlive int `json:"keepalive"`
// (optional) server-side keep alive interval in seconds, default to 300 (every 5min)
ServerKeepAlive int `json:"serverkeepalive"`
// (optional) soft limit of concurrent unauthenticated connections, default to 10
StartupLimitStart int `json:"startuplimitstart"`
// (optional) probability of random disconnection when soft limit is exceeded, default to 30 (30%)
Expand All @@ -63,10 +68,8 @@ type Config struct {
MaxConn int `json:"maxconn"`
// (optional) max concurrent incoming sessions, default to 128
MaxSessions int `json:"maxsessions"`
// (optional) mux accept backlog, default to 256, you may not want to change this
AcceptBacklog int `json:"backlog"`
// (optional) stream window size in bytes, default to 256 KiB, increase this on long fat networks
StreamWindow uint32 `json:"window"`
// (optional) server-side keep alive interval in seconds, default to 300 (every 5min)
ServerKeepAlive int `json:"serverkeepalive"`
// (optional) tunnel connecting timeout in seconds, default to 15
ConnectTimeout int `json:"timeout"`
// (optional) stream open timeout in seconds, default to 30
Expand All @@ -84,16 +87,12 @@ type Config struct {
var DefaultConfig = Config{
ServerName: "example.com",
NoDelay: true,
Redial: true,
KeepAlive: 25, // every 25s
ServerKeepAlive: 300, // every 5min
StartupLimitStart: 10,
StartupLimitRate: 30,
StartupLimitFull: 60,
MaxConn: 16384,
MaxSessions: 128,
AcceptBacklog: 256,
StreamWindow: 256 * 1024, // 256 KiB
ServerKeepAlive: 300, // every 5min
ConnectTimeout: 15,
StreamOpenTimeout: 30,
StreamCloseTimeout: 120,
Expand All @@ -102,6 +101,13 @@ var DefaultConfig = Config{
LogLevel: slog.LevelNotice,
}

var DefaultTunnelConfig = TunnelConfig{
Redial: true,
KeepAlive: 25, // every 25s
AcceptBacklog: 256,
StreamWindow: 256 * 1024, // 256 KiB
}

func parseConfig(b []byte) (*Config, error) {
cfg := DefaultConfig
if err := json.Unmarshal(b, &cfg); err != nil {
Expand Down Expand Up @@ -133,12 +139,13 @@ func rangeCheckInt(key string, value int, min int, max int) error {
}

func (c *Config) Validate() error {
if err := rangeCheckInt("keepalive", c.KeepAlive, 0, 86400); err != nil {
return err
}
if err := rangeCheckInt("serverkeepalive", c.ServerKeepAlive, 0, 86400); err != nil {
return err
}
// TODO
// if err := rangeCheckInt("keepalive", c.KeepAlive, 0, 86400); err != nil {
// return err
// }
// if err := rangeCheckInt("serverkeepalive", c.ServerKeepAlive, 0, 86400); err != nil {
// return err
// }
if err := rangeCheckInt("startuplimitstart", c.StartupLimitStart, 1, math.MaxInt); err != nil {
return err
}
Expand Down Expand Up @@ -222,21 +229,38 @@ func (w *logWrapper) Write(p []byte) (n int, err error) {
}

// NewMuxConfig creates yamux.Config
func (c *Config) NewMuxConfig(isServer bool) *yamux.Config {
keepAliveInterval := time.Duration(c.KeepAlive) * time.Second
if isServer {
keepAliveInterval = time.Duration(c.ServerKeepAlive) * time.Second
func (c *Config) NewMuxConfig() *yamux.Config {
t := DefaultTunnelConfig
keepAliveInterval := time.Duration(c.ServerKeepAlive) * time.Second
enableKeepAlive := keepAliveInterval >= time.Second
if !enableKeepAlive {
keepAliveInterval = 15 * time.Second
}
return &yamux.Config{
AcceptBacklog: t.AcceptBacklog,
EnableKeepAlive: enableKeepAlive,
KeepAliveInterval: keepAliveInterval,
ConnectionWriteTimeout: time.Duration(c.WriteTimeout) * time.Second,
MaxStreamWindowSize: t.StreamWindow,
StreamOpenTimeout: time.Duration(c.StreamOpenTimeout) * time.Second,
StreamCloseTimeout: time.Duration(c.StreamCloseTimeout) * time.Second,
Logger: log.New(&logWrapper{slog.Default()}, "", 0),
}
}

// NewMuxConfig creates yamux.Config
func (t *TunnelConfig) NewMuxConfig(c *Config) *yamux.Config {
keepAliveInterval := time.Duration(t.KeepAlive) * time.Second
enableKeepAlive := keepAliveInterval >= time.Second
if !enableKeepAlive {
keepAliveInterval = 15 * time.Second
}
return &yamux.Config{
AcceptBacklog: c.AcceptBacklog,
AcceptBacklog: t.AcceptBacklog,
EnableKeepAlive: enableKeepAlive,
KeepAliveInterval: keepAliveInterval,
ConnectionWriteTimeout: time.Duration(c.WriteTimeout) * time.Second,
MaxStreamWindowSize: c.StreamWindow,
MaxStreamWindowSize: t.StreamWindow,
StreamOpenTimeout: time.Duration(c.StreamOpenTimeout) * time.Second,
StreamCloseTimeout: time.Duration(c.StreamCloseTimeout) * time.Second,
Logger: log.New(&logWrapper{slog.Default()}, "", 0),
Expand Down
84 changes: 53 additions & 31 deletions v3/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ type Handler interface {
// TLSHandler creates a tunnel
type TLSHandler struct {
s *Server
t *Tunnel

halfOpen atomic.Uint32
}
Expand All @@ -39,7 +38,7 @@ func (h *TLSHandler) Serve(ctx context.Context, conn net.Conn) {
start := time.Now()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
slog.Errorf("%q <= %v: %s", h.t.tag, conn.RemoteAddr(), formats.Error(err))
slog.Errorf("? <= %v: %s", conn.RemoteAddr(), formats.Error(err))
return
}
}
Expand All @@ -49,50 +48,73 @@ func (h *TLSHandler) Serve(ctx context.Context, conn net.Conn) {
if tlscfg := h.s.getTLSConfig(); tlscfg != nil {
conn = tls.Server(conn, tlscfg)
} else {
slog.Warningf("%q <= %v: connection is not encrypted", h.t.tag, conn.RemoteAddr())
slog.Warningf("? <= %v: connection is not encrypted", conn.RemoteAddr())
}
req, err := proto.RecvRequest(conn)
req, err := proto.RecvMessage(conn)
if err != nil {
slog.Errorf("%q <= %v: %s", h.t.tag, conn.RemoteAddr(), formats.Error(err))
slog.Errorf("? <= %v: %s", conn.RemoteAddr(), formats.Error(err))
return
}
if req.Msg != proto.MsgClientHello {
slog.Errorf("? <= %v: %s", conn.RemoteAddr(), "invalid message")
return
}
t := h.t
rsp := &proto.ServerMsg{
Type: proto.Type,
Msg: proto.MsgHello,
Service: c.RemoteService,
rsp := &proto.Message{
Type: proto.Type,
Msg: proto.MsgServerHello,
PeerName: c.PeerName,
}
if t.c.RemoteService != "" {
rsp.Service = t.c.RemoteService
if cfg, ok := c.Peers[req.PeerName]; ok {
rsp.Service = cfg.PeerService
}
if err := proto.SendResponse(conn, rsp); err != nil {
slog.Errorf("%q <= %v: %s", h.t.tag, conn.RemoteAddr(), formats.Error(err))
if err := proto.SendMessage(conn, rsp); err != nil {
slog.Errorf("%q <= %v: %s", req.PeerName, conn.RemoteAddr(), formats.Error(err))
return
}
_ = conn.SetDeadline(time.Time{})
mux, err := yamux.Server(conn, h.s.getMuxConfig(true))
h.s.stats.authorized.Add(1)
t := h.s.findPeer(req.PeerName)
var muxcfg *yamux.Config
if t != nil {
muxcfg = t.c.NewMuxConfig(c)
} else {
muxcfg = c.NewMuxConfig()
}
mux, err := yamux.Server(conn, muxcfg)
if err != nil {
slog.Errorf("%q <= %v: %s", h.t.tag, conn.RemoteAddr(), formats.Error(err))
slog.Errorf("%q <= %v: %s", req.PeerName, conn.RemoteAddr(), formats.Error(err))
return
}
h.s.stats.authorized.Add(1)
if req.Service != "" {
if tun := h.s.findTunnel(req.Service); tun != nil {
t = tun
} else {
slog.Infof("%q <= %v: unknown service %q", t.tag, conn.RemoteAddr(), rsp.Service)
if t != nil {
t.addMux(mux, false)
}
var muxHandler Handler
if dialAddr, ok := c.Services[req.Service]; ok {
muxHandler = &ForwardHandler{
s: h.s, tag: req.PeerName, dial: dialAddr,
}
}
t.addMux(mux, false)
if muxHandler == nil {
if req.Service != "" {
slog.Infof("%q <= %v: unknown service %q", req.PeerName, conn.RemoteAddr(), rsp.Service)
}
if err := mux.GoAway(); err != nil {
slog.Errorf("%q <= %v: %s", req.PeerName, conn.RemoteAddr(), formats.Error(err))
return
}
muxHandler = &EmptyHandler{}
}
if err := h.s.g.Go(func() {
defer t.delMux(mux)
t.Serve(mux)
if t != nil {
defer t.delMux(mux)
}
h.s.Serve(mux, muxHandler)
}); err != nil {
slog.Errorf("%q <= %v: %s", t.tag, conn.RemoteAddr(), formats.Error(err))
slog.Errorf("%q <= %v: %s", req.PeerName, conn.RemoteAddr(), formats.Error(err))
ioClose(mux)
return
}
slog.Infof("%q <= %v: setup %v", t.tag, conn.RemoteAddr(), formats.Duration(time.Since(start)))
slog.Infof("%q <= %v: setup %v", req.PeerName, conn.RemoteAddr(), formats.Duration(time.Since(start)))
}

// ForwardHandler forwards connections to another plain address
Expand Down Expand Up @@ -131,20 +153,20 @@ func (h *TunnelHandler) Serve(ctx context.Context, accepted net.Conn) {
dialed, err := h.t.MuxDial(ctx)
if err != nil {
if errors.Is(err, ErrDialInProgress) {
slog.Debugf("%v -> %q: %s", accepted.RemoteAddr(), h.t.tag, formats.Error(err))
slog.Debugf("%v -> %q: %s", accepted.RemoteAddr(), h.t.peerName, formats.Error(err))
} else {
slog.Errorf("%v -> %q: %s", accepted.RemoteAddr(), h.t.tag, formats.Error(err))
slog.Errorf("%v -> %q: %s", accepted.RemoteAddr(), h.t.peerName, formats.Error(err))
}
ioClose(accepted)
return
}
if err := h.s.f.Forward(accepted, dialed); err != nil {
slog.Errorf("%v -> %q: %s", accepted.RemoteAddr(), h.t.tag, formats.Error(err))
slog.Errorf("%v -> %q: %s", accepted.RemoteAddr(), h.t.peerName, formats.Error(err))
ioClose(accepted)
ioClose(dialed)
return
}
slog.Debugf("%v -> %q: forward established", h.l.Addr(), h.t.tag)
slog.Debugf("%v -> %q: forward established", h.l.Addr(), h.t.peerName)
}

// EmptyHandler rejects all connections
Expand Down
28 changes: 12 additions & 16 deletions v3/proto/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,15 @@ var (
)

const (
MsgHello = iota
MsgClientHello = iota
MsgServerHello
)

type ClientMsg struct {
Type string `json:"type"`
Msg int `json:"msgid"`
Service string `json:"service,omitempty"`
}

type ServerMsg struct {
Type string `json:"type"`
Msg int `json:"msgid"`
Service string `json:"service,omitempty"`
type Message struct {
Type string `json:"type"`
Msg int `json:"msgid"`
PeerName string `json:"peername,omitempty"`
Service string `json:"service,omitempty"`
}

var (
Expand Down Expand Up @@ -90,11 +86,11 @@ func checkType(s string) error {
return nil
}

func Roundtrip(conn net.Conn, req *ClientMsg) (*ServerMsg, error) {
func Roundtrip(conn net.Conn, req *Message) (*Message, error) {
if err := sendmsg(conn, req); err != nil {
return nil, err
}
rsp := &ServerMsg{}
rsp := &Message{}
if err := recvmsg(conn, rsp); err != nil {
return nil, err
}
Expand All @@ -104,8 +100,8 @@ func Roundtrip(conn net.Conn, req *ClientMsg) (*ServerMsg, error) {
return rsp, nil
}

func RecvRequest(conn net.Conn) (*ClientMsg, error) {
req := &ClientMsg{}
func RecvMessage(conn net.Conn) (*Message, error) {
req := &Message{}
if err := recvmsg(conn, req); err != nil {
return nil, err
}
Expand All @@ -115,6 +111,6 @@ func RecvRequest(conn net.Conn) (*ClientMsg, error) {
return req, nil
}

func SendResponse(conn net.Conn, rsp *ServerMsg) error {
func SendMessage(conn net.Conn, rsp *Message) error {
return sendmsg(conn, rsp)
}
Loading

0 comments on commit 38fd54e

Please sign in to comment.