From d8c236b6d08cc82eb030c6149f0728c85f254e7b Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Sun, 22 Dec 2024 17:22:45 +0530 Subject: [PATCH] chore: Map Stream to Support Concurrent Requests (#171) Signed-off-by: Yashash H L --- pkg/mapstreamer/service.go | 64 +++++++++++++--------- pkg/mapstreamer/service_test.go | 96 ++++++++++++++++++++++++++++++--- 2 files changed, 128 insertions(+), 32 deletions(-) diff --git a/pkg/mapstreamer/service.go b/pkg/mapstreamer/service.go index 51968e14..8795ff08 100644 --- a/pkg/mapstreamer/service.go +++ b/pkg/mapstreamer/service.go @@ -64,43 +64,59 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error { return err } - ctx := stream.Context() + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + + // Use error group to manage goroutines, the groupCtx is cancelled when any of the + // goroutines return an error for the first time or the first time the wait returns. + g, groupCtx := errgroup.WithContext(ctx) + var readErr error + // Read requests from the stream and process them +outer: for { - req, err := recvWithContext(ctx, stream) + req, err := recvWithContext(groupCtx, stream) if errors.Is(err, context.Canceled) { - log.Printf("Context cancelled, stopping the MapStreamFn") - break + log.Printf("Context cancelled, stopping the MapFn") + break outer } if errors.Is(err, io.EOF) { - log.Printf("EOF received, stopping the MapStreamFn") - break + log.Printf("EOF received, stopping the MapFn") + break outer } - if err != nil { log.Printf("Failed to receive request: %v", err) - return err + readErr = err + // read loop is not part of the error group, so we need to cancel the context + // to signal the other goroutines to stop processing. + cancel() + break outer } + g.Go(func() error { + messageCh := make(chan Message) + workerGroup, innerCtx := errgroup.WithContext(groupCtx) - messageCh := make(chan Message) - g, groupCtx := errgroup.WithContext(ctx) + workerGroup.Go(func() error { + return fs.invokeHandler(innerCtx, req, messageCh) + }) - g.Go(func() error { - return fs.invokeHandler(groupCtx, req, messageCh) - }) + workerGroup.Go(func() error { + return fs.writeResponseToClient(innerCtx, stream, req.GetId(), messageCh) + }) - g.Go(func() error { - return fs.writeResponseToClient(groupCtx, stream, req.GetId(), messageCh) + return workerGroup.Wait() }) + } - // Wait for the error group to finish - if err := g.Wait(); err != nil { - log.Printf("error processing requests: %v", err) - if err == io.EOF { - return nil - } - fs.shutdownCh <- struct{}{} - return status.Errorf(codes.Internal, "error processing requests: %v", err) - } + // wait for all goroutines to finish + if err := g.Wait(); err != nil { + log.Printf("Stopping the MapFn with err, %s", err) + fs.shutdownCh <- struct{}{} + return status.Errorf(codes.Internal, "error processing requests: %v", err) + } + + // check if there was an error while reading from the stream + if readErr != nil { + return status.Errorf(codes.Internal, readErr.Error()) } return nil diff --git a/pkg/mapstreamer/service_test.go b/pkg/mapstreamer/service_test.go index c026491c..fd5b35e9 100644 --- a/pkg/mapstreamer/service_test.go +++ b/pkg/mapstreamer/service_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "testing" "time" @@ -294,6 +295,23 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) { expectedResults := make([][]*proto.MapResponse_Result, msgCount) results := make([][]*proto.MapResponse_Result, 0) + eotCount := 0 + + for { + got, err := stream.Recv() + if err == io.EOF { + break + } + require.NoError(t, err, "Receiving message from the stream") + + if got.Status != nil && got.Status.Eot { + eotCount++ + } else { + results = append(results, got.Results) + } + } + + require.Equal(t, msgCount, eotCount, "Expected number of EOT messages") for i := 0; i < msgCount; i++ { expectedResults[i] = []*proto.MapResponse_Result{ @@ -302,14 +320,6 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) { Value: []byte(fmt.Sprintf("test_%d", i)), }, } - - got, err := stream.Recv() - require.NoError(t, err, "Receiving message from the stream") - results = append(results, got.Results) - - eot, err := stream.Recv() - require.NoError(t, err, "Receiving message from the stream") - require.True(t, eot.Status.Eot, "Expected EOT message") } require.ElementsMatch(t, results, expectedResults) @@ -350,3 +360,73 @@ func TestService_MapFn_Panic(t *testing.T) { expectedStatus := status.Convert(status.Errorf(codes.Internal, "error processing requests: panic inside mapStream handler: map failed")) require.Equal(t, expectedStatus, gotStatus) } + +func TestService_MapFn_MultipleRequestsAndResponses(t *testing.T) { + svc := &Service{ + MapperStream: MapStreamerFunc(func(ctx context.Context, keys []string, datum Datum, messageCh chan<- Message) { + defer close(messageCh) + for i := 0; i < 3; i++ { // Send multiple responses for each request + msg := fmt.Sprintf("response_%d_for_%s", i, string(datum.Value())) + messageCh <- NewMessage([]byte(msg)).WithKeys([]string{keys[0] + "_test"}) + } + }), + } + conn := newTestServer(t, func(server *grpc.Server) { + proto.RegisterMapServer(server, svc) + }) + + client := proto.NewMapClient(conn) + stream, err := client.MapFn(context.Background()) + require.NoError(t, err, "Creating stream") + + doHandshake(t, stream) + + const msgCount = 5 + for i := 0; i < msgCount; i++ { + msg := proto.MapRequest{ + Request: &proto.MapRequest_Request{ + Keys: []string{"client"}, + Value: []byte(fmt.Sprintf("test_%d", i)), + EventTime: timestamppb.New(time.Time{}), + Watermark: timestamppb.New(time.Time{}), + }, + } + err = stream.Send(&msg) + require.NoError(t, err, "Sending message over the stream") + } + err = stream.CloseSend() + require.NoError(t, err, "Closing the send direction of the stream") + + expectedResults := make([][]*proto.MapResponse_Result, msgCount*3) + results := make([][]*proto.MapResponse_Result, 0) + eotCount := 0 + + for { + got, err := stream.Recv() + if err == io.EOF { + break + } + require.NoError(t, err, "Receiving message from the stream") + + if got.Status != nil && got.Status.Eot { + eotCount++ + } else { + results = append(results, got.Results) + } + } + + require.Equal(t, msgCount, eotCount, "Expected number of EOT messages") + + for i := 0; i < msgCount; i++ { + for j := 0; j < 3; j++ { + expectedResults[i*3+j] = []*proto.MapResponse_Result{ + { + Keys: []string{"client_test"}, + Value: []byte(fmt.Sprintf("response_%d_for_test_%d", j, i)), + }, + } + } + } + + require.ElementsMatch(t, results, expectedResults) +}