Skip to content

Commit

Permalink
[WIP] fixup 2
Browse files Browse the repository at this point in the history
Signed-off-by: He Xian <[email protected]>
  • Loading branch information
hexian000 committed Oct 11, 2024
1 parent ca6c546 commit de7c7e6
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 47 deletions.
14 changes: 2 additions & 12 deletions v3/config/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,6 @@ func (c *File) Timeout() time.Duration {
return time.Duration(c.ConnectTimeout) * time.Second
}

// FindService finds dial address by service name
func (c *File) FindService(service string) string {
if service == "" {
return ""
}
return c.Services[service]
}

// SetConnParams sets TCP params
func (c *File) SetConnParams(conn net.Conn) {
if tcpConn := conn.(*net.TCPConn); tcpConn != nil {
Expand All @@ -44,10 +36,8 @@ func (c *File) SetConnParams(conn net.Conn) {
}

// NewTLSConfig creates tls.Config
func (c *File) NewTLSConfig(sni string) (*tls.Config, error) {
if sni == "" {
sni = c.ServerName
}
func (c *File) NewTLSConfig() (*tls.Config, error) {
sni := c.ServerName
if c.Certificate == "" && c.PrivateKey == "" {
return nil, nil
}
Expand Down
6 changes: 3 additions & 3 deletions v3/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ func (h *TLSHandler) Serve(ctx context.Context, conn net.Conn) {
return
}
}
cfg := h.s.getConfig()
cfg, tlscfg := h.s.getConfig()
cfg.SetConnParams(conn)
conn = snet.FlowMeter(conn, h.s.flowStats)
if tlscfg := h.s.getTLSConfig(); tlscfg != nil {
if tlscfg != nil {
conn = tls.Server(conn, tlscfg)
} else {
slog.Warningf("? <= %v: connection is not encrypted", conn.RemoteAddr())
Expand Down Expand Up @@ -73,7 +73,7 @@ func (h *TLSHandler) Serve(ctx context.Context, conn net.Conn) {
_ = conn.SetDeadline(time.Time{})
h.s.stats.authorized.Add(1)

_, err = h.s.startMux(conn, req.PeerName, req.Service, false)
_, err = h.s.startMux(conn, cfg, req.PeerName, req.Service, false)
if err != nil {
slog.Errorf("%q <= %v: %s", req.PeerName, conn.RemoteAddr(), formats.Error(err))
return
Expand Down
3 changes: 2 additions & 1 deletion v3/metric.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ func (h *apiConfigHandler) Post(w http.ResponseWriter, r *http.Request) {
}

func (h *apiConfigHandler) Get(w http.ResponseWriter, r *http.Request) {
b, err := json.MarshalIndent(h.s.getConfig(), "", " ")
cfg, _ := h.s.getConfig()
b, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(formats.Error(err)))
Expand Down
45 changes: 24 additions & 21 deletions v3/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,23 @@ type Server struct {
// NewServer creates a server object
func NewServer(cfg *config.File) *Server {
g := routines.NewGroup()
return &Server{
s := &Server{
cfg: cfg,
listeners: make(map[string]net.Listener),
tunnels: make(map[string]*tunnel),
ctx: contextMgr{
timeout: cfg.Timeout,
contexts: make(map[context.Context]context.CancelFunc),
},
f: forwarder.New(cfg.MaxConn, g),
flowStats: &snet.FlowStats{},
recentEvents: eventlog.NewRecent(100),
g: g,
}
s.ctx.timeout = func() time.Duration {
cfg, _ := s.getConfig()
return cfg.Timeout()
}
return s
}

func (s *Server) addTunnel(peerName string) *tunnel {
Expand Down Expand Up @@ -165,15 +169,21 @@ func (s *Server) Serve(listener net.Listener, handler Handler) {
errors.Is(err, yamux.ErrSessionShutdown) {
return
}
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
time.Sleep(500 * time.Millisecond)
}
slog.Errorf("serve: %s", formats.Error(err))
return
}
s.serveOne(conn, handler)
if err := s.g.Go(func() {
s.serveOne(conn, handler)
}); err != nil {
slog.Errorf("serve: %s", formats.Error(err))
}
}
}

func (s *Server) startMux(conn net.Conn, peerName, service string, isDialed bool) (mux *yamux.Session, err error) {
cfg := s.getConfig()
func (s *Server) startMux(conn net.Conn, cfg *config.File, peerName, service string, isDialed bool) (mux *yamux.Session, err error) {
muxcfg := cfg.NewMuxConfig(peerName, isDialed)
if isDialed {
mux, err = yamux.Client(conn, muxcfg)
Expand All @@ -189,8 +199,11 @@ func (s *Server) startMux(conn net.Conn, peerName, service string, isDialed bool
}
}
var muxHandler Handler
dialAddr := cfg.FindService(service)
if dialAddr == "" {
if dialAddr, ok := cfg.Services[service]; ok {
muxHandler = &ForwardHandler{
s: s, tag: peerName, dial: dialAddr,
}
} else {
if service != "" {
slog.Warningf("%q <= %v: unknown service %q", peerName, conn.RemoteAddr(), service)
}
Expand All @@ -200,10 +213,6 @@ func (s *Server) startMux(conn net.Conn, peerName, service string, isDialed bool
return
}
muxHandler = &EmptyHandler{}
} else {
muxHandler = &ForwardHandler{
s: s, tag: peerName, dial: dialAddr,
}
}
err = s.g.Go(func() {
if t := s.findTunnel(peerName); t != nil {
Expand Down Expand Up @@ -238,7 +247,7 @@ func (s *Server) Start() error {
}
slog.Noticef("mux listen: %v", l.Addr())
h := &TLSHandler{s: s}
c := s.getConfig()
c, _ := s.getConfig()
s.l = hlistener.Wrap(l, &hlistener.Config{
Start: uint32(c.StartupLimitStart),
Full: uint32(c.StartupLimitFull),
Expand Down Expand Up @@ -294,7 +303,7 @@ func (s *Server) Shutdown() error {

// LoadConfig reloads the configuration file
func (s *Server) LoadConfig(cfg *config.File) error {
tlscfg, err := cfg.NewTLSConfig(cfg.ServerName)
tlscfg, err := cfg.NewTLSConfig()
if err != nil {
return err
}
Expand All @@ -305,14 +314,8 @@ func (s *Server) LoadConfig(cfg *config.File) error {
return nil
}

func (s *Server) getConfig() *config.File {
s.cfgMu.RLock()
defer s.cfgMu.RUnlock()
return s.cfg
}

func (s *Server) getTLSConfig() *tls.Config {
func (s *Server) getConfig() (*config.File, *tls.Config) {
s.cfgMu.RLock()
defer s.cfgMu.RUnlock()
return s.tlscfg
return s.cfg, s.tlscfg
}
19 changes: 9 additions & 10 deletions v3/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ type tunnel struct {
lastChanged time.Time
}

func (t *tunnel) getConfig() *config.Tunnel {
cfg := t.s.getConfig()
return cfg.GetTunnel(t.peerName)
func (t *tunnel) getConfig() (*config.File, *tls.Config, *config.Tunnel) {
cfg, tlscfg := t.s.getConfig()
return cfg, tlscfg, cfg.GetTunnel(t.peerName)
}

func (t *tunnel) Start() error {
c := t.getConfig()
_, _, c := t.getConfig()
if c.Listen != "" {
l, err := t.s.Listen(c.Listen)
if err != nil {
Expand Down Expand Up @@ -63,15 +63,15 @@ func (t *tunnel) redial() {
if redialCount > t.redialCount {
t.redialCount = redialCount
}
c := t.getConfig()
_, _, c := t.getConfig()
slog.Warningf("tunnel %q: redial #%d to %s: %s", t.peerName, t.redialCount, c.MuxDial, formats.Error(err))
return
}
t.redialCount = 0
}

func (t *tunnel) scheduleRedial() <-chan time.Time {
c := t.getConfig()
_, _, c := t.getConfig()
if !c.NoRedial || c.MuxDial == "" || t.redialCount < 1 {
return make(<-chan time.Time)
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func (t *tunnel) dial(ctx context.Context) (*yamux.Session, error) {
if mux := t.getMux(); mux != nil {
return mux, nil
}
tuncfg := t.getConfig()
cfg, tlscfg, tuncfg := t.getConfig()
if tuncfg.MuxDial == "" {
return nil, ErrNoDialAddress
}
Expand All @@ -211,10 +211,9 @@ func (t *tunnel) dial(ctx context.Context) (*yamux.Session, error) {
return nil, err
}
}
cfg := t.s.getConfig()
cfg.SetConnParams(conn)
conn = snet.FlowMeter(conn, t.s.flowStats)
if tlscfg := t.s.getTLSConfig(); tlscfg != nil {
if tlscfg != nil {
conn = tls.Client(conn, tlscfg)
} else {
slog.Warningf("%q => %v: connection is not encrypted", t.peerName, conn.RemoteAddr())
Expand All @@ -231,7 +230,7 @@ func (t *tunnel) dial(ctx context.Context) (*yamux.Session, error) {
}
_ = conn.SetDeadline(time.Time{})

mux, err := t.s.startMux(conn, rsp.PeerName, rsp.Service, true)
mux, err := t.s.startMux(conn, cfg, rsp.PeerName, rsp.Service, true)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit de7c7e6

Please sign in to comment.