Skip to content

Commit b368379

Browse files
authored
credentials/alts: Optimize reads (#8204)
1 parent 4b5505d commit b368379

File tree

4 files changed

+90
-24
lines changed

4 files changed

+90
-24
lines changed

credentials/alts/internal/conn/common.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,10 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
5454
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
5555
// If the size field is not complete, return the provided buffer as
5656
// remaining buffer.
57-
if len(b) < MsgLenFieldSize {
57+
length, sufficientBytes := parseMessageLength(b)
58+
if !sufficientBytes {
5859
return nil, b, nil
5960
}
60-
msgLenField := b[:MsgLenFieldSize]
61-
length := binary.LittleEndian.Uint32(msgLenField)
6261
if length > maxLen {
6362
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
6463
}
@@ -68,3 +67,14 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
6867
}
6968
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
7069
}
70+
71+
// parseMessageLength returns the message length based on frame header. It also
72+
// returns a boolean indicating if the buffer contains sufficient bytes to parse
73+
// the length header. If there are insufficient bytes, (0, false) is returned.
74+
func parseMessageLength(b []byte) (uint32, bool) {
75+
if len(b) < MsgLenFieldSize {
76+
return 0, false
77+
}
78+
msgLenField := b[:MsgLenFieldSize]
79+
return binary.LittleEndian.Uint32(msgLenField), true
80+
}

credentials/alts/internal/conn/record.go

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ const (
6363
// The maximum write buffer size. This *must* be multiple of
6464
// altsRecordDefaultLength.
6565
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
66+
// The initial buffer used to read from the network.
67+
altsReadBufferInitialSize = 32 * 1024 // 32KiB
6668
)
6769

6870
var (
@@ -83,7 +85,7 @@ type conn struct {
8385
net.Conn
8486
crypto ALTSRecordCrypto
8587
// buf holds data that has been read from the connection and decrypted,
86-
// but has not yet been returned by Read.
88+
// but has not yet been returned by Read. It is a sub-slice of protected.
8789
buf []byte
8890
payloadLengthLimit int
8991
// protected holds data read from the network but have not yet been
@@ -111,21 +113,13 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
111113
}
112114
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
113115
payloadLengthLimit := altsRecordDefaultLength - overhead
114-
var protectedBuf []byte
115-
if protected == nil {
116-
// We pre-allocate protected to be of size
117-
// 2*altsRecordDefaultLength-1 during initialization. We only
118-
// read from the network into protected when protected does not
119-
// contain a complete frame, which is at most
120-
// altsRecordDefaultLength-1 (bytes). And we read at most
121-
// altsRecordDefaultLength (bytes) data into protected at one
122-
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
123-
// to buffer data read from the network.
124-
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
125-
} else {
126-
protectedBuf = make([]byte, len(protected))
127-
copy(protectedBuf, protected)
128-
}
116+
// We pre-allocate protected to be of size 32KB during initialization.
117+
// We increase the size of the buffer by the required amount if it can't
118+
// hold a complete encrypted record.
119+
protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
120+
// Copy additional data from hanshaker service.
121+
copy(protectedBuf, protected)
122+
protectedBuf = protectedBuf[:len(protected)]
129123

130124
altsConn := &conn{
131125
Conn: c,
@@ -162,11 +156,21 @@ func (p *conn) Read(b []byte) (n int, err error) {
162156
// Check whether a complete frame has been received yet.
163157
for len(framedMsg) == 0 {
164158
if len(p.protected) == cap(p.protected) {
165-
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
166-
copy(tmp, p.protected)
167-
p.protected = tmp
159+
// We can parse the length header to know exactly how large
160+
// the buffer needs to be to hold the entire frame.
161+
length, didParse := parseMessageLength(p.protected)
162+
if !didParse {
163+
// The protected buffer is initialized with a capacity of
164+
// larger than 4B. It should always hold the message length
165+
// header.
166+
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
167+
}
168+
oldProtectedBuf := p.protected
169+
p.protected = make([]byte, int(length)+MsgLenFieldSize)
170+
copy(p.protected, oldProtectedBuf)
171+
p.protected = p.protected[:len(oldProtectedBuf)]
168172
}
169-
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
173+
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
170174
if err != nil {
171175
return 0, err
172176
}
@@ -185,6 +189,15 @@ func (p *conn) Read(b []byte) (n int, err error) {
185189
}
186190
ciphertext := msg[msgTypeFieldSize:]
187191

192+
// Decrypt directly into the buffer, avoiding a copy from p.buf if
193+
// possible.
194+
if len(b) >= len(ciphertext) {
195+
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
196+
if err != nil {
197+
return 0, err
198+
}
199+
return len(dec), nil
200+
}
188201
// Decrypt requires that if the dst and ciphertext alias, they
189202
// must alias exactly. Code here used to use msg[:0], but msg
190203
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than

credentials/alts/internal/conn/record_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"math"
2727
"net"
2828
"reflect"
29+
"strings"
2930
"testing"
3031

3132
core "google.golang.org/grpc/credentials/alts/internal"
@@ -188,6 +189,48 @@ func (s) TestLargeMsg(t *testing.T) {
188189
}
189190
}
190191

192+
// TestLargeRecord writes a very large ALTS record and verifies that the server
193+
// receives it correctly. The large ALTS record should cause the reader to
194+
// expand it's read buffer to hold the entire record and store the decrypted
195+
// message until the receiver reads all of the bytes.
196+
func (s) TestLargeRecord(t *testing.T) {
197+
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
198+
msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
199+
// Increase the size of ALTS records written by the client.
200+
clientConn.payloadLengthLimit = math.MaxInt32
201+
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
202+
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
203+
}
204+
rcvMsg := make([]byte, len(msg))
205+
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
206+
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
207+
}
208+
if !reflect.DeepEqual(msg, rcvMsg) {
209+
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
210+
}
211+
}
212+
213+
// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
214+
// receiving a large message.
215+
func BenchmarkLargeMessage(b *testing.B) {
216+
msgLen := 20 * 1024 * 1024 // 20 MiB
217+
msg := make([]byte, msgLen)
218+
rcvMsg := make([]byte, len(msg))
219+
b.ResetTimer()
220+
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
221+
for range b.N {
222+
// Write 20 MiB 5 times to transfer a total of 100 MiB.
223+
for range 5 {
224+
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
225+
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
226+
}
227+
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
228+
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
229+
}
230+
}
231+
}
232+
}
233+
191234
func testIncorrectMsgType(t *testing.T, rp string) {
192235
// framedMsg is an empty ciphertext with correct framing but wrong
193236
// message type.

credentials/alts/internal/handshaker/handshaker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
308308
// whatever received from the network and send it to the handshaker service.
309309
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
310310
var lastWriteTime time.Time
311+
buf := make([]byte, frameLimit)
311312
for {
312313
if len(resp.OutFrames) > 0 {
313314
lastWriteTime = time.Now()
@@ -318,7 +319,6 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
318319
if resp.Result != nil {
319320
return resp.Result, extra, nil
320321
}
321-
buf := make([]byte, frameLimit)
322322
n, err := h.conn.Read(buf)
323323
if err != nil && err != io.EOF {
324324
return nil, nil, err

0 commit comments

Comments
 (0)