@@ -18,6 +18,7 @@ import (
1818
1919type fakeChannel struct {
2020 stdErr io.ReadWriter
21+ stdOut io.ReadWriter
2122 sentRequestName string
2223 sentRequestPayload []byte
2324}
@@ -27,7 +28,7 @@ func (f *fakeChannel) Read(data []byte) (int, error) {
2728}
2829
2930func (f * fakeChannel ) Write (data []byte ) (int , error ) {
30- return 0 , nil
31+ return f . stdOut . Write ( data )
3132}
3233
3334func (f * fakeChannel ) Close () error {
@@ -145,8 +146,9 @@ func TestHandleExec(t *testing.T) {
145146 },
146147 }
147148 for _ , s := range sessions {
148- out := & bytes.Buffer {}
149- f := & fakeChannel {stdErr : out }
149+ stdErr := & bytes.Buffer {}
150+ stdOut := & bytes.Buffer {}
151+ f := & fakeChannel {stdErr : stdErr , stdOut : stdOut }
150152 r := & ssh.Request {Payload : tc .payload }
151153
152154 s .channel = f
@@ -163,12 +165,14 @@ func TestHandleExec(t *testing.T) {
163165
164166func TestHandleShell (t * testing.T ) {
165167 testCases := []struct {
166- desc string
167- cmd string
168- errMsg string
169- gitlabKeyId string
170- expectedErrString string
171- expectedExitCode uint32
168+ desc string
169+ cmd string
170+ errMsg string
171+ gitlabKeyId string
172+ expectedOutString string
173+ expectedErrString string
174+ expectedExitCode uint32
175+ expectedWrittenBytes int64
172176 }{
173177 {
174178 desc : "fails to parse command" ,
@@ -177,57 +181,70 @@ func TestHandleShell(t *testing.T) {
177181 gitlabKeyId : "root" ,
178182 expectedErrString : "Invalid SSH command: invalid command line string" ,
179183 expectedExitCode : 128 ,
180- }, {
184+ },
185+ {
181186 desc : "specified command is unknown" ,
182187 cmd : "unknown-command" ,
183188 errMsg : "ERROR: Unknown command: unknown-command\n " ,
184189 gitlabKeyId : "root" ,
185190 expectedErrString : "Disallowed command" ,
186191 expectedExitCode : 128 ,
187- }, {
192+ },
193+ {
188194 desc : "fails to parse command" ,
189195 cmd : "discover" ,
190196 gitlabKeyId : "" ,
191197 errMsg : "ERROR: Failed to get username: who='' is invalid\n " ,
192198 expectedErrString : "Failed to get username: who='' is invalid" ,
193199 expectedExitCode : 1 ,
194- }, {
195- desc : "fails to parse command" ,
196- cmd : "discover" ,
197- errMsg : "" ,
198- gitlabKeyId : "root" ,
199- expectedErrString : "" ,
200- expectedExitCode : 0 ,
200+ },
201+ {
202+ desc : "parses command" ,
203+ cmd : "discover" ,
204+ errMsg : "" ,
205+ gitlabKeyId : "root" ,
206+ expectedOutString : "Welcome to GitLab, @test-user!\n " ,
207+ expectedErrString : "" ,
208+ expectedExitCode : 0 ,
209+ expectedWrittenBytes : 31 ,
201210 },
202211 }
203212
204213 url := testserver .StartHttpServer (t , requests )
205214
206215 for _ , tc := range testCases {
207216 t .Run (tc .desc , func (t * testing.T ) {
208- out := & bytes.Buffer {}
217+ stdOut := & bytes.Buffer {}
218+ stdErr := & bytes.Buffer {}
209219 s := & session {
210220 gitlabKeyId : tc .gitlabKeyId ,
211221 execCmd : tc .cmd ,
212- channel : & fakeChannel {stdErr : out },
222+ channel : & fakeChannel {stdErr : stdErr , stdOut : stdOut },
213223 cfg : & config.Config {GitlabUrl : url },
214224 }
215225 r := & ssh.Request {}
216226
217- _ , exitCode , err := s .handleShell (context .Background (), r )
227+ ctxWithLogData , exitCode , err := s .handleShell (context .Background (), r )
228+
229+ logData := extractDataFromContext (ctxWithLogData )
230+
231+ if tc .expectedOutString != "" {
232+ require .Equal (t , tc .expectedOutString , stdOut .String ())
233+ }
218234
219235 if tc .expectedErrString != "" {
220236 require .Equal (t , tc .expectedErrString , err .Error ())
221237 }
222238
223239 require .Equal (t , tc .expectedExitCode , exitCode )
240+ require .Equal (t , tc .expectedWrittenBytes , logData .WrittenBytes )
224241
225242 formattedErr := & bytes.Buffer {}
226243 if tc .errMsg != "" {
227244 console .DisplayWarningMessage (tc .errMsg , formattedErr )
228- require .Equal (t , formattedErr .String (), out .String ())
245+ require .Equal (t , formattedErr .String (), stdErr .String ())
229246 } else {
230- require .Equal (t , tc .errMsg , out .String ())
247+ require .Equal (t , tc .errMsg , stdErr .String ())
231248 }
232249 })
233250 }
0 commit comments