Skip to content

Commit 37076c4

Browse files
Igor Drozdovashmckenzie
Igor Drozdov
andcommitted
Merge branch '648-gitlab-sshd-should-include-data-transfer-bytes-in-logs-3' into 'main'
Resolve "GitLab sshd should include data transfer bytes in logs" Closes #648 See merge request https://gitlab.com/gitlab-org/gitlab-shell/-/merge_requests/831 Merged-by: Igor Drozdov <[email protected]> Approved-by: Igor Drozdov <[email protected]> Reviewed-by: Igor Drozdov <[email protected]> Reviewed-by: Ash McKenzie <[email protected]> Co-authored-by: Ash McKenzie <[email protected]>
2 parents 7898d8e + 4bf9c83 commit 37076c4

File tree

10 files changed

+104
-34
lines changed

10 files changed

+104
-34
lines changed

cmd/check/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func main() {
2323
command.CheckForVersionFlag(os.Args, Version, BuildTime)
2424

2525
readWriter := &readwriter.ReadWriter{
26-
Out: os.Stdout,
26+
Out: &readwriter.CountingWriter{W: os.Stdout},
2727
In: os.Stdin,
2828
ErrOut: os.Stderr,
2929
}

cmd/gitlab-shell-authorized-keys-check/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func main() {
2424
command.CheckForVersionFlag(os.Args, Version, BuildTime)
2525

2626
readWriter := &readwriter.ReadWriter{
27-
Out: os.Stdout,
27+
Out: &readwriter.CountingWriter{W: os.Stdout},
2828
In: os.Stdin,
2929
ErrOut: os.Stderr,
3030
}

cmd/gitlab-shell-authorized-principals-check/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func main() {
2424
command.CheckForVersionFlag(os.Args, Version, BuildTime)
2525

2626
readWriter := &readwriter.ReadWriter{
27-
Out: os.Stdout,
27+
Out: &readwriter.CountingWriter{W: os.Stdout},
2828
In: os.Stdin,
2929
ErrOut: os.Stderr,
3030
}

cmd/gitlab-shell/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func main() {
3232
command.CheckForVersionFlag(os.Args, Version, BuildTime)
3333

3434
readWriter := &readwriter.ReadWriter{
35-
Out: os.Stdout,
35+
Out: &readwriter.CountingWriter{W: os.Stdout},
3636
In: os.Stdin,
3737
ErrOut: os.Stderr,
3838
}

internal/command/command.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ type LogMetadata struct {
2222
}
2323

2424
type LogData struct {
25-
Username string `json:"username"`
26-
Meta LogMetadata `json:"meta"`
25+
Username string `json:"username"`
26+
WrittenBytes int64 `json:"written_bytes"`
27+
Meta LogMetadata `json:"meta"`
2728
}
2829

2930
func CheckForVersionFlag(osArgs []string, version, buildTime string) {
@@ -87,7 +88,8 @@ func NewLogData(project, username string) LogData {
8788
}
8889

8990
return LogData{
90-
Username: username,
91+
Username: username,
92+
WrittenBytes: 0,
9193
Meta: LogMetadata{
9294
Project: project,
9395
RootNamespace: rootNameSpace,
Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
package readwriter
22

3-
import "io"
3+
import (
4+
"io"
5+
)
46

57
type ReadWriter struct {
68
Out io.Writer
79
In io.Reader
810
ErrOut io.Writer
911
}
12+
13+
// CountingWriter wraps an io.Writer and counts all the writes. Accessing
14+
// the count N is not thread-safe.
15+
type CountingWriter struct {
16+
W io.Writer
17+
N int64
18+
}
19+
20+
func (cw *CountingWriter) Write(p []byte) (int, error) {
21+
n, err := cw.W.Write(p)
22+
cw.N += int64(n)
23+
return n, err
24+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package readwriter
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestCountingWriter_Write(t *testing.T) {
11+
testString := []byte("test string")
12+
buffer := &bytes.Buffer{}
13+
14+
cw := &CountingWriter{
15+
W: buffer,
16+
}
17+
18+
n, err := cw.Write(testString)
19+
20+
require.NoError(t, err)
21+
require.Equal(t, 11, n)
22+
require.Equal(t, int64(11), cw.N)
23+
24+
cw.Write(testString)
25+
require.Equal(t, int64(22), cw.N)
26+
}

internal/sshd/session.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,10 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
167167
NamespacePath: s.namespace,
168168
}
169169

170+
countingWriter := &readwriter.CountingWriter{W: s.channel}
171+
170172
rw := &readwriter.ReadWriter{
171-
Out: s.channel,
173+
Out: countingWriter,
172174
In: s.channel,
173175
ErrOut: s.channel.Stderr(),
174176
}
@@ -183,6 +185,7 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
183185
} else {
184186
cmd, err = shellCmd.NewWithKey(s.gitlabKeyId, env, s.cfg, rw)
185187
}
188+
186189
if err != nil {
187190
if errors.Is(err, disallowedcommand.Error) {
188191
s.toStderr(ctx, "ERROR: Unknown command: %v\n", s.execCmd)
@@ -202,6 +205,12 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) (context.Co
202205
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)
203206

204207
ctxWithLogData, err := cmd.Execute(ctx)
208+
209+
logData := extractDataFromContext(ctxWithLogData)
210+
logData.WrittenBytes = countingWriter.N
211+
212+
ctxWithLogData = context.WithValue(ctx, "logData", logData)
213+
205214
if err != nil {
206215
grpcStatus := grpcstatus.Convert(err)
207216
if grpcStatus.Code() != grpccodes.Internal {

internal/sshd/session_test.go

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
type fakeChannel struct {
2020
stdErr io.ReadWriter
21+
stdOut io.ReadWriter
2122
sentRequestName string
2223
sentRequestPayload []byte
2324
}
@@ -27,7 +28,7 @@ func (f *fakeChannel) Read(data []byte) (int, error) {
2728
}
2829

2930
func (f *fakeChannel) Write(data []byte) (int, error) {
30-
return 0, nil
31+
return f.stdOut.Write(data)
3132
}
3233

3334
func (f *fakeChannel) Close() error {
@@ -145,8 +146,9 @@ func TestHandleExec(t *testing.T) {
145146
},
146147
}
147148
for _, s := range sessions {
148-
out := &bytes.Buffer{}
149-
f := &fakeChannel{stdErr: out}
149+
stdErr := &bytes.Buffer{}
150+
stdOut := &bytes.Buffer{}
151+
f := &fakeChannel{stdErr: stdErr, stdOut: stdOut}
150152
r := &ssh.Request{Payload: tc.payload}
151153

152154
s.channel = f
@@ -163,12 +165,14 @@ func TestHandleExec(t *testing.T) {
163165

164166
func TestHandleShell(t *testing.T) {
165167
testCases := []struct {
166-
desc string
167-
cmd string
168-
errMsg string
169-
gitlabKeyId string
170-
expectedErrString string
171-
expectedExitCode uint32
168+
desc string
169+
cmd string
170+
errMsg string
171+
gitlabKeyId string
172+
expectedOutString string
173+
expectedErrString string
174+
expectedExitCode uint32
175+
expectedWrittenBytes int64
172176
}{
173177
{
174178
desc: "fails to parse command",
@@ -177,57 +181,70 @@ func TestHandleShell(t *testing.T) {
177181
gitlabKeyId: "root",
178182
expectedErrString: "Invalid SSH command: invalid command line string",
179183
expectedExitCode: 128,
180-
}, {
184+
},
185+
{
181186
desc: "specified command is unknown",
182187
cmd: "unknown-command",
183188
errMsg: "ERROR: Unknown command: unknown-command\n",
184189
gitlabKeyId: "root",
185190
expectedErrString: "Disallowed command",
186191
expectedExitCode: 128,
187-
}, {
192+
},
193+
{
188194
desc: "fails to parse command",
189195
cmd: "discover",
190196
gitlabKeyId: "",
191197
errMsg: "ERROR: Failed to get username: who='' is invalid\n",
192198
expectedErrString: "Failed to get username: who='' is invalid",
193199
expectedExitCode: 1,
194-
}, {
195-
desc: "fails to parse command",
196-
cmd: "discover",
197-
errMsg: "",
198-
gitlabKeyId: "root",
199-
expectedErrString: "",
200-
expectedExitCode: 0,
200+
},
201+
{
202+
desc: "parses command",
203+
cmd: "discover",
204+
errMsg: "",
205+
gitlabKeyId: "root",
206+
expectedOutString: "Welcome to GitLab, @test-user!\n",
207+
expectedErrString: "",
208+
expectedExitCode: 0,
209+
expectedWrittenBytes: 31,
201210
},
202211
}
203212

204213
url := testserver.StartHttpServer(t, requests)
205214

206215
for _, tc := range testCases {
207216
t.Run(tc.desc, func(t *testing.T) {
208-
out := &bytes.Buffer{}
217+
stdOut := &bytes.Buffer{}
218+
stdErr := &bytes.Buffer{}
209219
s := &session{
210220
gitlabKeyId: tc.gitlabKeyId,
211221
execCmd: tc.cmd,
212-
channel: &fakeChannel{stdErr: out},
222+
channel: &fakeChannel{stdErr: stdErr, stdOut: stdOut},
213223
cfg: &config.Config{GitlabUrl: url},
214224
}
215225
r := &ssh.Request{}
216226

217-
_, exitCode, err := s.handleShell(context.Background(), r)
227+
ctxWithLogData, exitCode, err := s.handleShell(context.Background(), r)
228+
229+
logData := extractDataFromContext(ctxWithLogData)
230+
231+
if tc.expectedOutString != "" {
232+
require.Equal(t, tc.expectedOutString, stdOut.String())
233+
}
218234

219235
if tc.expectedErrString != "" {
220236
require.Equal(t, tc.expectedErrString, err.Error())
221237
}
222238

223239
require.Equal(t, tc.expectedExitCode, exitCode)
240+
require.Equal(t, tc.expectedWrittenBytes, logData.WrittenBytes)
224241

225242
formattedErr := &bytes.Buffer{}
226243
if tc.errMsg != "" {
227244
console.DisplayWarningMessage(tc.errMsg, formattedErr)
228-
require.Equal(t, formattedErr.String(), out.String())
245+
require.Equal(t, formattedErr.String(), stdErr.String())
229246
} else {
230-
require.Equal(t, tc.errMsg, out.String())
247+
require.Equal(t, tc.errMsg, stdErr.String())
231248
}
232249
})
233250
}

internal/sshd/sshd.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,9 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
217217
logData := extractDataFromContext(ctxWithLogData)
218218

219219
ctxlog.WithFields(log.Fields{
220-
"duration_s": time.Since(started).Seconds(),
221-
"meta": logData.Meta,
220+
"duration_s": time.Since(started).Seconds(),
221+
"written_bytes": logData.WrittenBytes,
222+
"meta": logData.Meta,
222223
}).Info("access: finish")
223224
}
224225

0 commit comments

Comments
 (0)