Skip to content

Commit

Permalink
Introduce client timeout for session start request.
Browse files Browse the repository at this point in the history
  • Loading branch information
YujiaozhAws authored and nitikaaws committed Jun 9, 2021
1 parent 65933d1 commit 34bde90
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
31 changes: 20 additions & 11 deletions src/sessionmanagerplugin/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"io"
"os"
"time"

"github.com/aws/SSMCLI/src/datachannel"
"github.com/aws/SSMCLI/src/log"
Expand All @@ -34,9 +35,8 @@ import (

const (
LegacyArgumentLength = 4
ArgumentLength = 7
StartSessionOperation = "StartSession"
VersionFile = "VERSION"
ClientTimeoutSecond = time.Duration(10 * time.Second)
)

var SessionRegistry = map[string]ISessionPlugin{}
Expand Down Expand Up @@ -81,6 +81,7 @@ type Session struct {
SessionType string
SessionProperties interface{}
DisplayMode sessionutil.DisplayMode
Timeout time.Duration
}

//startSession create the datachannel for session
Expand Down Expand Up @@ -175,6 +176,7 @@ func ValidateInputAndStartSession(args []string, out io.Writer) {
session.ClientId = clientId
session.TargetId = target
session.DataChannel = &datachannel.DataChannel{}
session.Timeout = ClientTimeoutSecond

default:
fmt.Fprint(out, "Invalid Operation")
Expand All @@ -200,17 +202,24 @@ func (s *Session) Execute(log log.T) (err error) {
return
}

select {
// The session type is set either by handshake or the first packet received.
if !<-s.DataChannel.IsSessionTypeSet() {
log.Errorf("unable to SessionType for session %s", s.SessionId)
return errors.New("unable to determine SessionType")
} else {
s.SessionType = s.DataChannel.GetSessionType()
s.SessionProperties = s.DataChannel.GetSessionProperties()
if err = setSessionHandlersWithSessionType(s, log); err != nil {
log.Errorf("Session ending with error: %v", err)
return
case isSessionTypeSet := <-s.DataChannel.IsSessionTypeSet():
if !isSessionTypeSet {
log.Errorf("unable to SessionType for session %s", s.SessionId)
return errors.New("unable to determine SessionType")
} else {
s.SessionType = s.DataChannel.GetSessionType()
s.SessionProperties = s.DataChannel.GetSessionProperties()
if err = setSessionHandlersWithSessionType(s, log); err != nil {
log.Errorf("Session ending with error: %v", err)
return
}
}
case <-time.After(s.Timeout):
log.Errorf("client timeout: unable to receive message %s", s.SessionId)
s.TerminateSession(log)
return errors.New("client timeout: unable to receive message " + s.SessionId)
}
return
}
15 changes: 15 additions & 0 deletions src/sessionmanagerplugin/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"bytes"
"fmt"
"testing"
"time"

wsChannelMock "github.com/aws/SSMCLI/src/communicator/mocks"
dataChannelMock "github.com/aws/SSMCLI/src/datachannel/mocks"
Expand Down Expand Up @@ -80,6 +81,20 @@ func TestExecute(t *testing.T) {
assert.Contains(t, err.Error(), "start session error for Standard_Stream")
}

func TestExecuteWhenNoResponseFromDataChannel(t *testing.T) {
sessionMock := &Session{}
sessionMock.DataChannel = mockDataChannel
sessionMock.Timeout = time.Duration(5 * time.Millisecond)
SetupMockActions()
mockDataChannel.On("Open", mock.Anything).Return(nil)
isSessionTypeSetMock := make(chan bool, 1)
mockDataChannel.On("IsSessionTypeSet").Return(isSessionTypeSetMock)

err := sessionMock.Execute(logger)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "client timeout")
}

func SetupMockActions() {
mockDataChannel.On("Initialize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return()
mockDataChannel.On("SetWebsocket", mock.Anything, mock.Anything, mock.Anything).Return()
Expand Down
1 change: 1 addition & 0 deletions src/ssmclicommands/startsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ func (s *StartSessionCommand) Execute(parameters map[string][]string) (error, st
ClientId: clientId,
TargetId: instanceId,
DataChannel: &datachannel.DataChannel{},
Timeout: session.ClientTimeoutSecond,
}

if err = executeSession(log, &session); err != nil {
Expand Down
1 change: 0 additions & 1 deletion src/ssmclicommands/startsession_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/aws/SSMCLI/src/log"
"github.com/aws/SSMCLI/src/sessionmanagerplugin/session"
"github.com/aws/aws-sdk-go/service/ssm"

"github.com/stretchr/testify/assert"
)

Expand Down

0 comments on commit 34bde90

Please sign in to comment.