Skip to content

Commit 4bf9c83

Browse files
committed
Set logData.WrittenBytes from CountingWriter.N
1 parent a509a44 commit 4bf9c83

File tree

2 files changed

+50
-24
lines changed

2 files changed

+50
-24
lines changed

internal/sshd/session.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,10 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
167167
NamespacePath: s.namespace,
168168
}
169169

170+
countingWriter := &readwriter.CountingWriter{W: s.channel}
171+
170172
rw := &readwriter.ReadWriter{
171-
Out: &readwriter.CountingWriter{W: s.channel},
173+
Out: countingWriter,
172174
In: s.channel,
173175
ErrOut: s.channel.Stderr(),
174176
}
@@ -183,6 +185,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
183185
} else {
184186
cmd, err = shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
185187
}
188+
186189
if err != nil {
187190
if errors.Is(err, disallowedcommand.Error) {
188191
s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
@@ -202,6 +205,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
202205
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
203206

204207
ctxWithLogData, err := cmd.Execute(ctx)
208+
209+
logData := extractDataFromContext(ctxWithLogData)
210+
logData.WrittenBytes = countingWriter.N
211+
212+
ctxWithLogData = context.WithValue(ctx, "logData", logData)
213+
205214
if err != nil {
206215
grpcStatus := grpcstatus.Convert(err)
207216
if grpcStatus.Code() != grpccodes.Internal {

internal/sshd/session_test.go

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
type fakeChannel struct {
2020
stdErr io.ReadWriter
21+
stdOut io.ReadWriter
2122
sentRequestName string
2223
sentRequestPayload []byte
2324
}
@@ -27,7 +28,7 @@ func (f *fakeChannel) Read(data []byte) (int, error) {
2728
}
2829

2930
func (f *fakeChannel) Write(data []byte) (int, error) {
30-
return 0, nil
31+
return f.stdOut.Write(data)
3132
}
3233

3334
func (f *fakeChannel) Close() error {
@@ -145,8 +146,9 @@ func TestHandleExec(t *testing.T) {
145146
},
146147
}
147148
for _, s := range sessions {
148-
out := &bytes.Buffer{}
149-
f := &fakeChannel{stdErr: out}
149+
stdErr := &bytes.Buffer{}
150+
stdOut := &bytes.Buffer{}
151+
f := &fakeChannel{stdErr: stdErr, stdOut: stdOut}
150152
r := &ssh.Request{Payload: tc.payload}
151153

152154
s.channel = f
@@ -163,12 +165,14 @@ func TestHandleExec(t *testing.T) {
163165

164166
func TestHandleShell(t *testing.T) {
165167
testCases := []struct {
166-
desc string
167-
cmd string
168-
errMsg string
169-
gitlabKeyId string
170-
expectedErrString string
171-
expectedExitCode uint32
168+
desc string
169+
cmd string
170+
errMsg string
171+
gitlabKeyId string
172+
expectedOutString string
173+
expectedErrString string
174+
expectedExitCode uint32
175+
expectedWrittenBytes int64
172176
}{
173177
{
174178
desc: "fails to parse command",
@@ -177,57 +181,70 @@ func TestHandleShell(t *testing.T) {
177181
gitlabKeyId: "root",
178182
expectedErrString: "Invalid SSH command: invalid command line string",
179183
expectedExitCode: 128,
180-
}, {
184+
},
185+
{
181186
desc: "specified command is unknown",
182187
cmd: "unknown-command",
183188
errMsg: "ERROR: Unknown command: unknown-command\n",
184189
gitlabKeyId: "root",
185190
expectedErrString: "Disallowed command",
186191
expectedExitCode: 128,
187-
}, {
192+
},
193+
{
188194
desc: "fails to parse command",
189195
cmd: "discover",
190196
gitlabKeyId: "",
191197
errMsg: "ERROR: Failed to get username: who='' is invalid\n",
192198
expectedErrString: "Failed to get username: who='' is invalid",
193199
expectedExitCode: 1,
194-
}, {
195-
desc: "fails to parse command",
196-
cmd: "discover",
197-
errMsg: "",
198-
gitlabKeyId: "root",
199-
expectedErrString: "",
200-
expectedExitCode: 0,
200+
},
201+
{
202+
desc: "parses command",
203+
cmd: "discover",
204+
errMsg: "",
205+
gitlabKeyId: "root",
206+
expectedOutString: "Welcome to GitLab, @test-user!\n",
207+
expectedErrString: "",
208+
expectedExitCode: 0,
209+
expectedWrittenBytes: 31,
201210
},
202211
}
203212

204213
url := testserver.StartHttpServer(t, requests)
205214

206215
for _, tc := range testCases {
207216
t.Run(tc.desc, func(t *testing.T) {
208-
out := &bytes.Buffer{}
217+
stdOut := &bytes.Buffer{}
218+
stdErr := &bytes.Buffer{}
209219
s := &session{
210220
gitlabKeyId: tc.gitlabKeyId,
211221
execCmd: tc.cmd,
212-
channel: &fakeChannel{stdErr: out},
222+
channel: &fakeChannel{stdErr: stdErr, stdOut: stdOut},
213223
cfg: &config.Config{GitlabUrl: url},
214224
}
215225
r := &ssh.Request{}
216226

217-
_, exitCode, err := s.handleShell(context.Background(), r)
227+
ctxWithLogData, exitCode, err := s.handleShell(context.Background(), r)
228+
229+
logData := extractDataFromContext(ctxWithLogData)
230+
231+
if tc.expectedOutString != "" {
232+
require.Equal(t, tc.expectedOutString, stdOut.String())
233+
}
218234

219235
if tc.expectedErrString != "" {
220236
require.Equal(t, tc.expectedErrString, err.Error())
221237
}
222238

223239
require.Equal(t, tc.expectedExitCode, exitCode)
240+
require.Equal(t, tc.expectedWrittenBytes, logData.WrittenBytes)
224241

225242
formattedErr := &bytes.Buffer{}
226243
if tc.errMsg != "" {
227244
console.DisplayWarningMessage(tc.errMsg, formattedErr)
228-
require.Equal(t, formattedErr.String(), out.String())
245+
require.Equal(t, formattedErr.String(), stdErr.String())
229246
} else {
230-
require.Equal(t, tc.errMsg, out.String())
247+
require.Equal(t, tc.errMsg, stdErr.String())
231248
}
232249
})
233250
}

0 commit comments

Comments
 (0)