From 34bde908f4157707c686a10b679d2aa2b63ad7e3 Mon Sep 17 00:00:00 2001 From: Zhang Date: Tue, 4 May 2021 20:49:40 +0000 Subject: [PATCH] Introduce client timeout for session start request. cr https://code.amazon.com/reviews/CR-51021781 --- src/sessionmanagerplugin/session/session.go | 31 ++++++++++++------- .../session/session_test.go | 15 +++++++++ src/ssmclicommands/startsession.go | 1 + src/ssmclicommands/startsession_test.go | 1 - 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/sessionmanagerplugin/session/session.go b/src/sessionmanagerplugin/session/session.go index e9a41117..b5b57e16 100644 --- a/src/sessionmanagerplugin/session/session.go +++ b/src/sessionmanagerplugin/session/session.go @@ -20,6 +20,7 @@ import ( "fmt" "io" "os" + "time" "github.com/aws/SSMCLI/src/datachannel" "github.com/aws/SSMCLI/src/log" @@ -34,9 +35,8 @@ import ( const ( LegacyArgumentLength = 4 - ArgumentLength = 7 StartSessionOperation = "StartSession" - VersionFile = "VERSION" + ClientTimeoutSecond = time.Duration(10 * time.Second) ) var SessionRegistry = map[string]ISessionPlugin{} @@ -81,6 +81,7 @@ type Session struct { SessionType string SessionProperties interface{} DisplayMode sessionutil.DisplayMode + Timeout time.Duration } //startSession create the datachannel for session @@ -175,6 +176,7 @@ func ValidateInputAndStartSession(args []string, out io.Writer) { session.ClientId = clientId session.TargetId = target session.DataChannel = &datachannel.DataChannel{} + session.Timeout = ClientTimeoutSecond default: fmt.Fprint(out, "Invalid Operation") @@ -200,17 +202,24 @@ func (s *Session) Execute(log log.T) (err error) { return } + select { // The session type is set either by handshake or the first packet received. - if !<-s.DataChannel.IsSessionTypeSet() { - log.Errorf("unable to SessionType for session %s", s.SessionId) - return errors.New("unable to determine SessionType") - } else { - s.SessionType = s.DataChannel.GetSessionType() - s.SessionProperties = s.DataChannel.GetSessionProperties() - if err = setSessionHandlersWithSessionType(s, log); err != nil { - log.Errorf("Session ending with error: %v", err) - return + case isSessionTypeSet := <-s.DataChannel.IsSessionTypeSet(): + if !isSessionTypeSet { + log.Errorf("unable to SessionType for session %s", s.SessionId) + return errors.New("unable to determine SessionType") + } else { + s.SessionType = s.DataChannel.GetSessionType() + s.SessionProperties = s.DataChannel.GetSessionProperties() + if err = setSessionHandlersWithSessionType(s, log); err != nil { + log.Errorf("Session ending with error: %v", err) + return + } } + case <-time.After(s.Timeout): + log.Errorf("client timeout: unable to receive message %s", s.SessionId) + s.TerminateSession(log) + return errors.New("client timeout: unable to receive message " + s.SessionId) } return } diff --git a/src/sessionmanagerplugin/session/session_test.go b/src/sessionmanagerplugin/session/session_test.go index a8c770c7..8fc13814 100644 --- a/src/sessionmanagerplugin/session/session_test.go +++ b/src/sessionmanagerplugin/session/session_test.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" "testing" + "time" wsChannelMock "github.com/aws/SSMCLI/src/communicator/mocks" dataChannelMock "github.com/aws/SSMCLI/src/datachannel/mocks" @@ -80,6 +81,20 @@ func TestExecute(t *testing.T) { assert.Contains(t, err.Error(), "start session error for Standard_Stream") } +func TestExecuteWhenNoResponseFromDataChannel(t *testing.T) { + sessionMock := &Session{} + sessionMock.DataChannel = mockDataChannel + sessionMock.Timeout = time.Duration(5 * time.Millisecond) + SetupMockActions() + mockDataChannel.On("Open", mock.Anything).Return(nil) + isSessionTypeSetMock := make(chan bool, 1) + mockDataChannel.On("IsSessionTypeSet").Return(isSessionTypeSetMock) + + err := sessionMock.Execute(logger) + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "client timeout") +} + func SetupMockActions() { mockDataChannel.On("Initialize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return() mockDataChannel.On("SetWebsocket", mock.Anything, mock.Anything, mock.Anything).Return() diff --git a/src/ssmclicommands/startsession.go b/src/ssmclicommands/startsession.go index 6f8474fc..b7e2b49a 100644 --- a/src/ssmclicommands/startsession.go +++ b/src/ssmclicommands/startsession.go @@ -192,6 +192,7 @@ func (s *StartSessionCommand) Execute(parameters map[string][]string) (error, st ClientId: clientId, TargetId: instanceId, DataChannel: &datachannel.DataChannel{}, + Timeout: session.ClientTimeoutSecond, } if err = executeSession(log, &session); err != nil { diff --git a/src/ssmclicommands/startsession_test.go b/src/ssmclicommands/startsession_test.go index e895e196..763e04d0 100644 --- a/src/ssmclicommands/startsession_test.go +++ b/src/ssmclicommands/startsession_test.go @@ -21,7 +21,6 @@ import ( "github.com/aws/SSMCLI/src/log" "github.com/aws/SSMCLI/src/sessionmanagerplugin/session" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/stretchr/testify/assert" )