From 4f0515bd1421e0ed5b8a29119fa47d060247e679 Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Thu, 19 Dec 2024 15:03:30 +0530 Subject: [PATCH] chore: Map Stream to Support Concurrent Requests Signed-off-by: Yashash H L --- pkg/mapstreamer/service.go | 84 +++++++++++++++++++++++---------- pkg/mapstreamer/service_test.go | 57 ++++++++++++++++++++++ 2 files changed, 117 insertions(+), 24 deletions(-) diff --git a/pkg/mapstreamer/service.go b/pkg/mapstreamer/service.go index 51968e14..c8ac4b4a 100644 --- a/pkg/mapstreamer/service.go +++ b/pkg/mapstreamer/service.go @@ -64,43 +64,79 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error { return err } - ctx := stream.Context() + ctx, cancel := context.WithCancel(stream.Context()) + defer cancel() + + // Use error group to manage goroutines, the groupCtx is cancelled when any of the + // goroutines return an error for the first time or the first time the wait returns. + g, groupCtx := errgroup.WithContext(ctx) + + // Channel to collect responses + responseCh := make(chan *mappb.MapResponse, 500) // FIXME: identify the right buffer size + defer close(responseCh) + + // Dedicated goroutine to send responses to the stream + g.Go(func() error { + for { + select { + case resp := <-responseCh: + if err := stream.Send(resp); err != nil { + log.Printf("Failed to send response: %v", err) + return err + } + case <-groupCtx.Done(): + return groupCtx.Err() + } + } + }) + + var readErr error + // Read requests from the stream and process them +outer: for { - req, err := recvWithContext(ctx, stream) + req, err := recvWithContext(groupCtx, stream) if errors.Is(err, context.Canceled) { - log.Printf("Context cancelled, stopping the MapStreamFn") - break + log.Printf("Context cancelled, stopping the MapFn") + break outer } if errors.Is(err, io.EOF) { - log.Printf("EOF received, stopping the MapStreamFn") - break + log.Printf("EOF received, stopping the MapFn") + break outer } - if err != nil { log.Printf("Failed to receive request: %v", err) - return err + readErr = err + // read loop is not part of the error group, so we need to cancel the context + // to signal the other goroutines to stop processing. + cancel() + break outer } + g.Go(func() error { + messageCh := make(chan Message) + workerGroup, innerCtx := errgroup.WithContext(groupCtx) - messageCh := make(chan Message) - g, groupCtx := errgroup.WithContext(ctx) + workerGroup.Go(func() error { + return fs.invokeHandler(innerCtx, req, messageCh) + }) - g.Go(func() error { - return fs.invokeHandler(groupCtx, req, messageCh) - }) + workerGroup.Go(func() error { + return fs.writeResponseToClient(innerCtx, stream, req.GetId(), messageCh) + }) - g.Go(func() error { - return fs.writeResponseToClient(groupCtx, stream, req.GetId(), messageCh) + return workerGroup.Wait() }) + } - // Wait for the error group to finish - if err := g.Wait(); err != nil { - log.Printf("error processing requests: %v", err) - if err == io.EOF { - return nil - } - fs.shutdownCh <- struct{}{} - return status.Errorf(codes.Internal, "error processing requests: %v", err) - } + // wait for all goroutines to finish + if err := g.Wait(); err != nil { + log.Printf("Stopping the MapFn with err, %s", err) + fs.shutdownCh <- struct{}{} + return status.Errorf(codes.Internal, "error processing requests: %v", err) + } + + // check if there was an error while reading from the stream + if readErr != nil { + return status.Errorf(codes.Internal, readErr.Error()) } return nil diff --git a/pkg/mapstreamer/service_test.go b/pkg/mapstreamer/service_test.go index c026491c..fade1fd1 100644 --- a/pkg/mapstreamer/service_test.go +++ b/pkg/mapstreamer/service_test.go @@ -350,3 +350,60 @@ func TestService_MapFn_Panic(t *testing.T) { expectedStatus := status.Convert(status.Errorf(codes.Internal, "error processing requests: panic inside mapStream handler: map failed")) require.Equal(t, expectedStatus, gotStatus) } + +func TestService_MapFn_MultipleRequestsTogether(t *testing.T) { + svc := &Service{ + MapperStream: MapStreamerFunc(func(ctx context.Context, keys []string, datum Datum, messageCh chan<- Message) { + defer close(messageCh) + msg := datum.Value() + messageCh <- NewMessage(msg).WithKeys([]string{keys[0] + "_test"}) + }), + } + conn := newTestServer(t, func(server *grpc.Server) { + proto.RegisterMapServer(server, svc) + }) + + client := proto.NewMapClient(conn) + stream, err := client.MapFn(context.Background()) + require.NoError(t, err, "Creating stream") + + doHandshake(t, stream) + + const msgCount = 5 + for i := 0; i < msgCount; i++ { + msg := proto.MapRequest{ + Request: &proto.MapRequest_Request{ + Keys: []string{"client"}, + Value: []byte(fmt.Sprintf("test_%d", i)), + EventTime: timestamppb.New(time.Time{}), + Watermark: timestamppb.New(time.Time{}), + }, + } + err = stream.Send(&msg) + require.NoError(t, err, "Sending message over the stream") + } + err = stream.CloseSend() + require.NoError(t, err, "Closing the send direction of the stream") + + expectedResults := make([][]*proto.MapResponse_Result, msgCount) + results := make([][]*proto.MapResponse_Result, 0) + + for i := 0; i < msgCount; i++ { + expectedResults[i] = []*proto.MapResponse_Result{ + { + Keys: []string{"client_test"}, + Value: []byte(fmt.Sprintf("test_%d", i)), + }, + } + + got, err := stream.Recv() + require.NoError(t, err, "Receiving message from the stream") + results = append(results, got.Results) + + eot, err := stream.Recv() + require.NoError(t, err, "Receiving message from the stream") + require.True(t, eot.Status.Eot, "Expected EOT message") + } + + require.ElementsMatch(t, results, expectedResults) +}