Skip to content

Commit

Permalink
chore: Map Stream to Support Concurrent Requests
Browse files Browse the repository at this point in the history
Signed-off-by: Yashash H L <[email protected]>
  • Loading branch information
yhl25 committed Dec 19, 2024
1 parent 24c4ba9 commit 4f0515b
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 24 deletions.
84 changes: 60 additions & 24 deletions pkg/mapstreamer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,43 +64,79 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error {
return err
}

ctx := stream.Context()
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()

// Use error group to manage goroutines, the groupCtx is cancelled when any of the
// goroutines return an error for the first time or the first time the wait returns.
g, groupCtx := errgroup.WithContext(ctx)

// Channel to collect responses
responseCh := make(chan *mappb.MapResponse, 500) // FIXME: identify the right buffer size
defer close(responseCh)

// Dedicated goroutine to send responses to the stream
g.Go(func() error {
for {
select {
case resp := <-responseCh:
if err := stream.Send(resp); err != nil {
log.Printf("Failed to send response: %v", err)
return err
}
case <-groupCtx.Done():
return groupCtx.Err()
}
}
})

var readErr error
// Read requests from the stream and process them
outer:
for {
req, err := recvWithContext(ctx, stream)
req, err := recvWithContext(groupCtx, stream)
if errors.Is(err, context.Canceled) {
log.Printf("Context cancelled, stopping the MapStreamFn")
break
log.Printf("Context cancelled, stopping the MapFn")
break outer
}
if errors.Is(err, io.EOF) {
log.Printf("EOF received, stopping the MapStreamFn")
break
log.Printf("EOF received, stopping the MapFn")
break outer
}

if err != nil {
log.Printf("Failed to receive request: %v", err)
return err
readErr = err
// read loop is not part of the error group, so we need to cancel the context
// to signal the other goroutines to stop processing.
cancel()
break outer
}
g.Go(func() error {
messageCh := make(chan Message)
workerGroup, innerCtx := errgroup.WithContext(groupCtx)

messageCh := make(chan Message)
g, groupCtx := errgroup.WithContext(ctx)
workerGroup.Go(func() error {
return fs.invokeHandler(innerCtx, req, messageCh)
})

g.Go(func() error {
return fs.invokeHandler(groupCtx, req, messageCh)
})
workerGroup.Go(func() error {
return fs.writeResponseToClient(innerCtx, stream, req.GetId(), messageCh)
})

g.Go(func() error {
return fs.writeResponseToClient(groupCtx, stream, req.GetId(), messageCh)
return workerGroup.Wait()
})
}

// Wait for the error group to finish
if err := g.Wait(); err != nil {
log.Printf("error processing requests: %v", err)
if err == io.EOF {
return nil
}
fs.shutdownCh <- struct{}{}
return status.Errorf(codes.Internal, "error processing requests: %v", err)
}
// wait for all goroutines to finish
if err := g.Wait(); err != nil {
log.Printf("Stopping the MapFn with err, %s", err)
fs.shutdownCh <- struct{}{}
return status.Errorf(codes.Internal, "error processing requests: %v", err)
}

// check if there was an error while reading from the stream
if readErr != nil {
return status.Errorf(codes.Internal, readErr.Error())
}

return nil
Expand Down
57 changes: 57 additions & 0 deletions pkg/mapstreamer/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,60 @@ func TestService_MapFn_Panic(t *testing.T) {
expectedStatus := status.Convert(status.Errorf(codes.Internal, "error processing requests: panic inside mapStream handler: map failed"))
require.Equal(t, expectedStatus, gotStatus)
}

func TestService_MapFn_MultipleRequestsTogether(t *testing.T) {
svc := &Service{
MapperStream: MapStreamerFunc(func(ctx context.Context, keys []string, datum Datum, messageCh chan<- Message) {
defer close(messageCh)
msg := datum.Value()
messageCh <- NewMessage(msg).WithKeys([]string{keys[0] + "_test"})
}),
}
conn := newTestServer(t, func(server *grpc.Server) {
proto.RegisterMapServer(server, svc)
})

client := proto.NewMapClient(conn)
stream, err := client.MapFn(context.Background())
require.NoError(t, err, "Creating stream")

doHandshake(t, stream)

const msgCount = 5
for i := 0; i < msgCount; i++ {
msg := proto.MapRequest{
Request: &proto.MapRequest_Request{
Keys: []string{"client"},
Value: []byte(fmt.Sprintf("test_%d", i)),
EventTime: timestamppb.New(time.Time{}),
Watermark: timestamppb.New(time.Time{}),
},
}
err = stream.Send(&msg)
require.NoError(t, err, "Sending message over the stream")
}
err = stream.CloseSend()
require.NoError(t, err, "Closing the send direction of the stream")

expectedResults := make([][]*proto.MapResponse_Result, msgCount)
results := make([][]*proto.MapResponse_Result, 0)

for i := 0; i < msgCount; i++ {
expectedResults[i] = []*proto.MapResponse_Result{
{
Keys: []string{"client_test"},
Value: []byte(fmt.Sprintf("test_%d", i)),
},
}

got, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
results = append(results, got.Results)

eot, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
require.True(t, eot.Status.Eot, "Expected EOT message")
}

require.ElementsMatch(t, results, expectedResults)
}

0 comments on commit 4f0515b

Please sign in to comment.