diff --git a/internal/dispatcher/dispatcher.go b/internal/dispatcher/dispatcher.go index 27f95384..fabf5a3d 100644 --- a/internal/dispatcher/dispatcher.go +++ b/internal/dispatcher/dispatcher.go @@ -87,27 +87,41 @@ func (d *Dispatcher) StartSession(ctx context.Context, domain universal.Domain) return err } for { - recv, err := d.RequestSessionInfo(ctx, domain) - if err != nil { + if retry, err := d.tryStartSession(ctx, s, domain); !retry { return err } - defer recv.Close() - select { - case reply := <-recv.Recv(): - if err = protocol.GetError(reply); err != nil { - return err - } - case <-ctx.Done(): - return ctx.Err() - case <-s.readySignal: - return nil - } - select { - case <-time.After(d.conn.RetryInterval()): - case <-ctx.Done(): - return ctx.Err() + } +} + +func (d *Dispatcher) tryStartSession(ctx context.Context, s *session, domain universal.Domain) (retry bool, err error) { + recv, err := d.RequestSessionInfo(ctx, domain) + if err != nil { + return false, err + } + defer recv.Close() + // Request sent + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-s.readySignal: + return false, nil + case <-time.After(d.RetryInterval()): + return true, nil + case reply := <-recv.Recv(): + if err = protocol.GetError(reply); err != nil { + return false, err } } + // Reply received. Normally, the dispatcher will clear readySignal after processing the reply; + // the other branches handle malformed vehicle responses. + select { + case <-s.readySignal: + return false, nil + case <-ctx.Done(): + return false, ctx.Err() + case <-time.After(d.RetryInterval()): + return true, nil + } } // StartSessions starts sessions with the provided vehicle domains (or all supported domains, if diff --git a/internal/dispatcher/dispatcher_test.go b/internal/dispatcher/dispatcher_test.go index 57bb5429..37d868cf 100644 --- a/internal/dispatcher/dispatcher_test.go +++ b/internal/dispatcher/dispatcher_test.go @@ -670,7 +670,7 @@ func TestWaitForAllSessions(t *testing.T) { // Configure the Connector to only respond to the first of two handshakes conn.EnqueueSendError(nil) - conn.EnqueueSendError(errDropMessage) + conn.dropReplies = true key, err := authentication.NewECDHPrivateKey(rand.Reader) if err != nil { @@ -928,6 +928,51 @@ func TestNoValidHandshakeResponse(t *testing.T) { } } +func TestRetryNonresponsive(t *testing.T) { + // Verifies that the client tries to resend session info requests to non-responsive domains + conn := newDummyConnector(t) + defer conn.Close() + + key, err := authentication.NewECDHPrivateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dispatcher, err := New(conn, key) + if err != nil { + t.Fatal(err) + } + if err := dispatcher.Start(ctx); err != nil { + t.Fatal(err) + } + defer dispatcher.Stop() + + const maxCallbacks = 5 + callbackCount := 0 + + conn.callback = func(_ *dummyConnector, _ *universal.RoutableMessage) ([]byte, bool) { + t.Log("Received callback") + callbackCount++ // caller holds conn.lock + if callbackCount >= maxCallbacks { + cancel() + } + return nil, false + } + + if err := dispatcher.StartSession(ctx, testDomain); !errors.Is(err, context.Canceled) { + t.Errorf("Expected key not paired but got %s", err) + } + + conn.lock.Lock() + defer conn.lock.Unlock() + if callbackCount < maxCallbacks { + t.Errorf("Expected %d callbacks, got %d", maxCallbacks, callbackCount) + } +} + func TestCache(t *testing.T) { conn := newDummyConnector(t) key, err := authentication.NewECDHPrivateKey(rand.Reader) diff --git a/pkg/account/version.txt b/pkg/account/version.txt index d15723fb..1c09c74e 100644 --- a/pkg/account/version.txt +++ b/pkg/account/version.txt @@ -1 +1 @@ -0.3.2 +0.3.3