@@ -18,6 +18,7 @@ import (
18
18
19
19
type fakeChannel struct {
20
20
stdErr io.ReadWriter
21
+ stdOut io.ReadWriter
21
22
sentRequestName string
22
23
sentRequestPayload []byte
23
24
}
@@ -27,7 +28,7 @@ func (f *fakeChannel) Read(data []byte) (int, error) {
27
28
}
28
29
29
30
func (f * fakeChannel ) Write (data []byte ) (int , error ) {
30
- return 0 , nil
31
+ return f . stdOut . Write ( data )
31
32
}
32
33
33
34
func (f * fakeChannel ) Close () error {
@@ -145,8 +146,9 @@ func TestHandleExec(t *testing.T) {
145
146
},
146
147
}
147
148
for _ , s := range sessions {
148
- out := & bytes.Buffer {}
149
- f := & fakeChannel {stdErr : out }
149
+ stdErr := & bytes.Buffer {}
150
+ stdOut := & bytes.Buffer {}
151
+ f := & fakeChannel {stdErr : stdErr , stdOut : stdOut }
150
152
r := & ssh.Request {Payload : tc .payload }
151
153
152
154
s .channel = f
@@ -163,12 +165,14 @@ func TestHandleExec(t *testing.T) {
163
165
164
166
func TestHandleShell (t * testing.T ) {
165
167
testCases := []struct {
166
- desc string
167
- cmd string
168
- errMsg string
169
- gitlabKeyId string
170
- expectedErrString string
171
- expectedExitCode uint32
168
+ desc string
169
+ cmd string
170
+ errMsg string
171
+ gitlabKeyId string
172
+ expectedOutString string
173
+ expectedErrString string
174
+ expectedExitCode uint32
175
+ expectedWrittenBytes int64
172
176
}{
173
177
{
174
178
desc : "fails to parse command" ,
@@ -177,57 +181,70 @@ func TestHandleShell(t *testing.T) {
177
181
gitlabKeyId : "root" ,
178
182
expectedErrString : "Invalid SSH command: invalid command line string" ,
179
183
expectedExitCode : 128 ,
180
- }, {
184
+ },
185
+ {
181
186
desc : "specified command is unknown" ,
182
187
cmd : "unknown-command" ,
183
188
errMsg : "ERROR: Unknown command: unknown-command\n " ,
184
189
gitlabKeyId : "root" ,
185
190
expectedErrString : "Disallowed command" ,
186
191
expectedExitCode : 128 ,
187
- }, {
192
+ },
193
+ {
188
194
desc : "fails to parse command" ,
189
195
cmd : "discover" ,
190
196
gitlabKeyId : "" ,
191
197
errMsg : "ERROR: Failed to get username: who='' is invalid\n " ,
192
198
expectedErrString : "Failed to get username: who='' is invalid" ,
193
199
expectedExitCode : 1 ,
194
- }, {
195
- desc : "fails to parse command" ,
196
- cmd : "discover" ,
197
- errMsg : "" ,
198
- gitlabKeyId : "root" ,
199
- expectedErrString : "" ,
200
- expectedExitCode : 0 ,
200
+ },
201
+ {
202
+ desc : "parses command" ,
203
+ cmd : "discover" ,
204
+ errMsg : "" ,
205
+ gitlabKeyId : "root" ,
206
+ expectedOutString : "Welcome to GitLab, @test-user!\n " ,
207
+ expectedErrString : "" ,
208
+ expectedExitCode : 0 ,
209
+ expectedWrittenBytes : 31 ,
201
210
},
202
211
}
203
212
204
213
url := testserver .StartHttpServer (t , requests )
205
214
206
215
for _ , tc := range testCases {
207
216
t .Run (tc .desc , func (t * testing.T ) {
208
- out := & bytes.Buffer {}
217
+ stdOut := & bytes.Buffer {}
218
+ stdErr := & bytes.Buffer {}
209
219
s := & session {
210
220
gitlabKeyId : tc .gitlabKeyId ,
211
221
execCmd : tc .cmd ,
212
- channel : & fakeChannel {stdErr : out },
222
+ channel : & fakeChannel {stdErr : stdErr , stdOut : stdOut },
213
223
cfg : & config.Config {GitlabUrl : url },
214
224
}
215
225
r := & ssh.Request {}
216
226
217
- _ , exitCode , err := s .handleShell (context .Background (), r )
227
+ ctxWithLogData , exitCode , err := s .handleShell (context .Background (), r )
228
+
229
+ logData := extractDataFromContext (ctxWithLogData )
230
+
231
+ if tc .expectedOutString != "" {
232
+ require .Equal (t , tc .expectedOutString , stdOut .String ())
233
+ }
218
234
219
235
if tc .expectedErrString != "" {
220
236
require .Equal (t , tc .expectedErrString , err .Error ())
221
237
}
222
238
223
239
require .Equal (t , tc .expectedExitCode , exitCode )
240
+ require .Equal (t , tc .expectedWrittenBytes , logData .WrittenBytes )
224
241
225
242
formattedErr := & bytes.Buffer {}
226
243
if tc .errMsg != "" {
227
244
console .DisplayWarningMessage (tc .errMsg , formattedErr )
228
- require .Equal (t , formattedErr .String (), out .String ())
245
+ require .Equal (t , formattedErr .String (), stdErr .String ())
229
246
} else {
230
- require .Equal (t , tc .errMsg , out .String ())
247
+ require .Equal (t , tc .errMsg , stdErr .String ())
231
248
}
232
249
})
233
250
}
0 commit comments