Skip to content

Commit

Permalink
chore: Map Stream to Support Concurrent Requests (#171)
Browse files Browse the repository at this point in the history
Signed-off-by: Yashash H L <[email protected]>
  • Loading branch information
yhl25 authored Dec 22, 2024
1 parent 24c4ba9 commit d8c236b
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 32 deletions.
64 changes: 40 additions & 24 deletions pkg/mapstreamer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,43 +64,59 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error {
return err
}

ctx := stream.Context()
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()

// Use error group to manage goroutines, the groupCtx is cancelled when any of the
// goroutines return an error for the first time or the first time the wait returns.
g, groupCtx := errgroup.WithContext(ctx)
var readErr error
// Read requests from the stream and process them
outer:
for {
req, err := recvWithContext(ctx, stream)
req, err := recvWithContext(groupCtx, stream)
if errors.Is(err, context.Canceled) {
log.Printf("Context cancelled, stopping the MapStreamFn")
break
log.Printf("Context cancelled, stopping the MapFn")
break outer
}
if errors.Is(err, io.EOF) {
log.Printf("EOF received, stopping the MapStreamFn")
break
log.Printf("EOF received, stopping the MapFn")
break outer
}

if err != nil {
log.Printf("Failed to receive request: %v", err)
return err
readErr = err
// read loop is not part of the error group, so we need to cancel the context
// to signal the other goroutines to stop processing.
cancel()
break outer
}
g.Go(func() error {
messageCh := make(chan Message)
workerGroup, innerCtx := errgroup.WithContext(groupCtx)

messageCh := make(chan Message)
g, groupCtx := errgroup.WithContext(ctx)
workerGroup.Go(func() error {
return fs.invokeHandler(innerCtx, req, messageCh)
})

g.Go(func() error {
return fs.invokeHandler(groupCtx, req, messageCh)
})
workerGroup.Go(func() error {
return fs.writeResponseToClient(innerCtx, stream, req.GetId(), messageCh)
})

g.Go(func() error {
return fs.writeResponseToClient(groupCtx, stream, req.GetId(), messageCh)
return workerGroup.Wait()
})
}

// Wait for the error group to finish
if err := g.Wait(); err != nil {
log.Printf("error processing requests: %v", err)
if err == io.EOF {
return nil
}
fs.shutdownCh <- struct{}{}
return status.Errorf(codes.Internal, "error processing requests: %v", err)
}
// wait for all goroutines to finish
if err := g.Wait(); err != nil {
log.Printf("Stopping the MapFn with err, %s", err)
fs.shutdownCh <- struct{}{}
return status.Errorf(codes.Internal, "error processing requests: %v", err)
}

// check if there was an error while reading from the stream
if readErr != nil {
return status.Errorf(codes.Internal, readErr.Error())
}

return nil
Expand Down
96 changes: 88 additions & 8 deletions pkg/mapstreamer/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
Expand Down Expand Up @@ -294,6 +295,23 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) {

expectedResults := make([][]*proto.MapResponse_Result, msgCount)
results := make([][]*proto.MapResponse_Result, 0)
eotCount := 0

for {
got, err := stream.Recv()
if err == io.EOF {
break
}
require.NoError(t, err, "Receiving message from the stream")

if got.Status != nil && got.Status.Eot {
eotCount++
} else {
results = append(results, got.Results)
}
}

require.Equal(t, msgCount, eotCount, "Expected number of EOT messages")

for i := 0; i < msgCount; i++ {
expectedResults[i] = []*proto.MapResponse_Result{
Expand All @@ -302,14 +320,6 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) {
Value: []byte(fmt.Sprintf("test_%d", i)),
},
}

got, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
results = append(results, got.Results)

eot, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
require.True(t, eot.Status.Eot, "Expected EOT message")
}

require.ElementsMatch(t, results, expectedResults)
Expand Down Expand Up @@ -350,3 +360,73 @@ func TestService_MapFn_Panic(t *testing.T) {
expectedStatus := status.Convert(status.Errorf(codes.Internal, "error processing requests: panic inside mapStream handler: map failed"))
require.Equal(t, expectedStatus, gotStatus)
}

func TestService_MapFn_MultipleRequestsAndResponses(t *testing.T) {
svc := &Service{
MapperStream: MapStreamerFunc(func(ctx context.Context, keys []string, datum Datum, messageCh chan<- Message) {
defer close(messageCh)
for i := 0; i < 3; i++ { // Send multiple responses for each request
msg := fmt.Sprintf("response_%d_for_%s", i, string(datum.Value()))
messageCh <- NewMessage([]byte(msg)).WithKeys([]string{keys[0] + "_test"})
}
}),
}
conn := newTestServer(t, func(server *grpc.Server) {
proto.RegisterMapServer(server, svc)
})

client := proto.NewMapClient(conn)
stream, err := client.MapFn(context.Background())
require.NoError(t, err, "Creating stream")

doHandshake(t, stream)

const msgCount = 5
for i := 0; i < msgCount; i++ {
msg := proto.MapRequest{
Request: &proto.MapRequest_Request{
Keys: []string{"client"},
Value: []byte(fmt.Sprintf("test_%d", i)),
EventTime: timestamppb.New(time.Time{}),
Watermark: timestamppb.New(time.Time{}),
},
}
err = stream.Send(&msg)
require.NoError(t, err, "Sending message over the stream")
}
err = stream.CloseSend()
require.NoError(t, err, "Closing the send direction of the stream")

expectedResults := make([][]*proto.MapResponse_Result, msgCount*3)
results := make([][]*proto.MapResponse_Result, 0)
eotCount := 0

for {
got, err := stream.Recv()
if err == io.EOF {
break
}
require.NoError(t, err, "Receiving message from the stream")

if got.Status != nil && got.Status.Eot {
eotCount++
} else {
results = append(results, got.Results)
}
}

require.Equal(t, msgCount, eotCount, "Expected number of EOT messages")

for i := 0; i < msgCount; i++ {
for j := 0; j < 3; j++ {
expectedResults[i*3+j] = []*proto.MapResponse_Result{
{
Keys: []string{"client_test"},
Value: []byte(fmt.Sprintf("response_%d_for_test_%d", j, i)),
},
}
}
}

require.ElementsMatch(t, results, expectedResults)
}

0 comments on commit d8c236b

Please sign in to comment.