4 changes: 2 additions & 2 deletions client/cmd/debug.go
@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {

client := proto.NewDaemonServiceClient(conn)

stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
@@ -303,7 +303,7 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {

func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
statusResp, err := getStatus(cmd.Context(), true)
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
6 changes: 3 additions & 3 deletions client/cmd/status.go
@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {

ctx := internal.CtxInitState(cmd.Context())

resp, err := getStatus(ctx)
resp, err := getStatus(ctx, false)
if err != nil {
return err
}
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil
}

func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
}
defer conn.Close()

resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
20 changes: 10 additions & 10 deletions client/internal/engine.go
@@ -202,6 +202,8 @@ type Engine struct {
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup

probeStunTurn *relay.StunTurnProbe
}

// Peer is an instance of the Connection Peer
@@ -244,6 +246,7 @@ func NewEngine(
statusRecorder: statusRecorder,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
}

sm := profilemanager.NewServiceManager("")
@@ -1663,7 +1666,7 @@ func (e *Engine) getRosenpassAddr() string {

// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool {
func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock()

signalHealthy := e.signal.IsHealthy()
@@ -1695,8 +1698,12 @@
}

e.syncMsgMux.Unlock()

results := e.probeICE(stuns, turns)
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)

relayHealthy := true
return allHealthy
}

func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
)
}

// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
e.syncMsgMux.Lock()
211 changes: 189 additions & 22 deletions client/internal/relay/relay.go
@@ -2,6 +2,8 @@ package relay

import (
"context"
"crypto/sha256"
"errors"
"fmt"
"net"
"sync"
@@ -15,15 +17,180 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)

const (
DefaultCacheTTL = 20 * time.Second
probeTimeout = 6 * time.Second
)

var (
ErrCheckInProgress = errors.New("probe check is already in progress")
)

// ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct {
URI string
Err error
Addr string
}

type StunTurnProbe struct {
cacheResults []ProbeResult
cacheTimestamp time.Time
cacheKey string
cacheTTL time.Duration
probeInProgress bool
probeDone chan struct{}
mu sync.Mutex
}

func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
return &StunTurnProbe{
cacheTTL: cacheTTL,
}
}

func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)

p.mu.Lock()
if p.probeInProgress {
doneChan := p.probeDone
p.mu.Unlock()

select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-doneChan:
return p.getCachedResults(cacheKey, stuns, turns)
}
}

p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
p.mu.Unlock()

p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)

return p.getCachedResults(cacheKey, stuns, turns)
}

// ProbeAll probes all given servers asynchronously and returns the results
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)

p.mu.Lock()

if results := p.checkCache(cacheKey); results != nil {
p.mu.Unlock()
return results
}

if p.probeInProgress {
p.mu.Unlock()
return createErrorResults(stuns, turns)
}

p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
log.Infof("started new probe for STUN, TURN servers")
go func() {
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
}()

p.mu.Unlock()

timer := time.NewTimer(1300 * time.Millisecond)
defer timer.Stop()

select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-probeDone:
// when the probe returns quickly, return the results right away
return p.getCachedResults(cacheKey, stuns, turns)
case <-timer.C:
// if the probe takes longer than 1.3s, return error results to avoid blocking
return createErrorResults(stuns, turns)
}
}

func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
age := time.Since(p.cacheTimestamp)
if age < p.cacheTTL {
results := append([]ProbeResult(nil), p.cacheResults...)
log.Debugf("returning cached probe results (age: %v)", age)
return results
}
}
return nil
}

func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
p.mu.Lock()
defer p.mu.Unlock()

if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
return append([]ProbeResult(nil), p.cacheResults...)
}
return createErrorResults(stuns, turns)
}

func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
defer func() {
p.mu.Lock()
p.probeInProgress = false
p.mu.Unlock()
}()
results := make([]ProbeResult, len(stuns)+len(turns))

var wg sync.WaitGroup
for i, uri := range stuns {
wg.Add(1)
go func(idx int, stunURI *stun.URI) {
defer wg.Done()

probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()

results[idx].URI = stunURI.String()
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
}(i, uri)
}

stunOffset := len(stuns)
for i, uri := range turns {
wg.Add(1)
go func(idx int, turnURI *stun.URI) {
defer wg.Done()

probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()

results[idx].URI = turnURI.String()
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
}(stunOffset+i, uri)
}

wg.Wait()

p.mu.Lock()
p.cacheResults = results
p.cacheTimestamp = time.Now()
p.cacheKey = cacheKey
p.mu.Unlock()

log.Debug("Stored new probe results in cache")
}

// ProbeSTUN tries binding to the given STUN uri and acquiring an address
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() {
if probeErr != nil {
log.Debugf("stun probe error from %s: %s", uri, probeErr)
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
}

// ProbeTURN tries allocating a session from the given TURN URI
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() {
if probeErr != nil {
log.Debugf("turn probe error from %s: %s", uri, probeErr)
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
return relayConn.LocalAddr().String(), nil
}

// ProbeAll probes all given servers asynchronously and returns the results
func ProbeAll(
ctx context.Context,
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
relays []*stun.URI,
) []ProbeResult {
results := make([]ProbeResult, len(relays))

var wg sync.WaitGroup
for i, uri := range relays {
ctx, cancel := context.WithTimeout(ctx, 6*time.Second)
defer cancel()
func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
total := len(stuns) + len(turns)
results := make([]ProbeResult, total)

wg.Add(1)
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
allURIs := append(append([]*stun.URI{}, stuns...), turns...)
for i, uri := range allURIs {
results[i] = ProbeResult{
URI: uri.String(),
Err: ErrCheckInProgress,
}
}

wg.Wait()

return results
}

func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
h := sha256.New()
for _, uri := range stuns {
h.Write([]byte(uri.String()))
}
for _, uri := range turns {
h.Write([]byte(uri.String()))
}
return fmt.Sprintf("%x", h.Sum(nil))
}
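
For reference, a minimal usage sketch of the new StunTurnProbe API introduced above. The standalone main function, the example STUN server, and the pion/stun import path are illustrative assumptions, not part of this change; within the repository the type is only reachable from the client module.

package main

import (
	"context"
	"log"

	"github.com/pion/stun/v2" // assumed import path; use the pion/stun version pinned by the client module

	"github.com/netbirdio/netbird/client/internal/relay"
)

func main() {
	// One shared probe instance; results are cached for DefaultCacheTTL (20s).
	probe := relay.NewStunTurnProbe(relay.DefaultCacheTTL)

	// Illustrative STUN server, not part of this change.
	stunURI, err := stun.ParseURI("stun:stun.example.com:3478")
	if err != nil {
		log.Fatal(err)
	}
	stuns := []*stun.URI{stunURI}
	var turns []*stun.URI

	ctx := context.Background()

	// Non-blocking path (periodic status): serves cached results while they are fresh,
	// otherwise starts a background probe and returns ErrCheckInProgress placeholders
	// if that probe takes longer than roughly 1.3 seconds.
	for _, r := range probe.ProbeAll(ctx, stuns, turns) {
		log.Printf("quick: %s addr=%q err=%v", r.URI, r.Addr, r.Err)
	}

	// Blocking path (debug bundle): waits for a complete probe, joining an in-flight
	// probe for the same server set instead of starting a second one.
	for _, r := range probe.ProbeAllWaitResult(ctx, stuns, turns) {
		log.Printf("fresh: %s addr=%q err=%v", r.URI, r.Addr, r.Err)
	}
}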
9 changes: 3 additions & 6 deletions client/server/server.go
@@ -1057,10 +1057,7 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)

if msg.GetFullPeerStatus {
if msg.ShouldRunProbes {
s.runProbes()
}

s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
@@ -1070,7 +1067,7 @@
return &statusResponse, nil
}

func (s *Server) runProbes() {
func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
}
@@ -1081,7 +1078,7 @@
}

if time.Since(s.lastProbe) > probeThreshold {
if engine.RunHealthProbes() {
if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now()
}
}