Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Map Stream to Support Concurrent Requests #171

Merged
merged 2 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions pkg/mapstreamer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,43 +64,59 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error {
return err
}

ctx := stream.Context()
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()

// Use an error group to manage the goroutines. groupCtx is cancelled the first time
// any goroutine returns an error, or the first time Wait returns, whichever happens first.
g, groupCtx := errgroup.WithContext(ctx)
var readErr error
// Read requests from the stream and process them
outer:
for {
req, err := recvWithContext(ctx, stream)
req, err := recvWithContext(groupCtx, stream)
if errors.Is(err, context.Canceled) {
log.Printf("Context cancelled, stopping the MapStreamFn")
break
log.Printf("Context cancelled, stopping the MapFn")
break outer
}
if errors.Is(err, io.EOF) {
log.Printf("EOF received, stopping the MapStreamFn")
break
log.Printf("EOF received, stopping the MapFn")
break outer
}

if err != nil {
log.Printf("Failed to receive request: %v", err)
return err
readErr = err
// read loop is not part of the error group, so we need to cancel the context
// to signal the other goroutines to stop processing.
cancel()
break outer
}
g.Go(func() error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is bounded because we control the concurrency on the numaflow side, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

messageCh := make(chan Message)
workerGroup, innerCtx := errgroup.WithContext(groupCtx)

messageCh := make(chan Message)
g, groupCtx := errgroup.WithContext(ctx)
workerGroup.Go(func() error {
return fs.invokeHandler(innerCtx, req, messageCh)
})

g.Go(func() error {
return fs.invokeHandler(groupCtx, req, messageCh)
})
workerGroup.Go(func() error {
return fs.writeResponseToClient(innerCtx, stream, req.GetId(), messageCh)
})

g.Go(func() error {
return fs.writeResponseToClient(groupCtx, stream, req.GetId(), messageCh)
return workerGroup.Wait()
})
}

// Wait for the error group to finish
if err := g.Wait(); err != nil {
log.Printf("error processing requests: %v", err)
if err == io.EOF {
return nil
}
fs.shutdownCh <- struct{}{}
return status.Errorf(codes.Internal, "error processing requests: %v", err)
}
// wait for all goroutines to finish
if err := g.Wait(); err != nil {
log.Printf("Stopping the MapFn with err, %s", err)
fs.shutdownCh <- struct{}{}
return status.Errorf(codes.Internal, "error processing requests: %v", err)
}

// check if there was an error while reading from the stream
if readErr != nil {
return status.Errorf(codes.Internal, readErr.Error())
}

return nil
Expand Down
96 changes: 88 additions & 8 deletions pkg/mapstreamer/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
Expand Down Expand Up @@ -294,6 +295,23 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) {

expectedResults := make([][]*proto.MapResponse_Result, msgCount)
results := make([][]*proto.MapResponse_Result, 0)
eotCount := 0

for {
got, err := stream.Recv()
if err == io.EOF {
break
}
require.NoError(t, err, "Receiving message from the stream")

if got.Status != nil && got.Status.Eot {
eotCount++
} else {
results = append(results, got.Results)
}
}

require.Equal(t, msgCount, eotCount, "Expected number of EOT messages")

for i := 0; i < msgCount; i++ {
expectedResults[i] = []*proto.MapResponse_Result{
Expand All @@ -302,14 +320,6 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) {
Value: []byte(fmt.Sprintf("test_%d", i)),
},
}

got, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
results = append(results, got.Results)

eot, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
require.True(t, eot.Status.Eot, "Expected EOT message")
}

require.ElementsMatch(t, results, expectedResults)
Expand Down Expand Up @@ -350,3 +360,73 @@ func TestService_MapFn_Panic(t *testing.T) {
expectedStatus := status.Convert(status.Errorf(codes.Internal, "error processing requests: panic inside mapStream handler: map failed"))
require.Equal(t, expectedStatus, gotStatus)
}

func TestService_MapFn_MultipleRequestsAndResponses(t *testing.T) {
svc := &Service{
MapperStream: MapStreamerFunc(func(ctx context.Context, keys []string, datum Datum, messageCh chan<- Message) {
defer close(messageCh)
for i := 0; i < 3; i++ { // Send multiple responses for each request
msg := fmt.Sprintf("response_%d_for_%s", i, string(datum.Value()))
messageCh <- NewMessage([]byte(msg)).WithKeys([]string{keys[0] + "_test"})
}
}),
}
conn := newTestServer(t, func(server *grpc.Server) {
proto.RegisterMapServer(server, svc)
})

client := proto.NewMapClient(conn)
stream, err := client.MapFn(context.Background())
require.NoError(t, err, "Creating stream")

doHandshake(t, stream)

const msgCount = 5
for i := 0; i < msgCount; i++ {
msg := proto.MapRequest{
Request: &proto.MapRequest_Request{
Keys: []string{"client"},
Value: []byte(fmt.Sprintf("test_%d", i)),
EventTime: timestamppb.New(time.Time{}),
Watermark: timestamppb.New(time.Time{}),
},
}
err = stream.Send(&msg)
require.NoError(t, err, "Sending message over the stream")
}
err = stream.CloseSend()
require.NoError(t, err, "Closing the send direction of the stream")

expectedResults := make([][]*proto.MapResponse_Result, msgCount*3)
results := make([][]*proto.MapResponse_Result, 0)
eotCount := 0

for {
got, err := stream.Recv()
if err == io.EOF {
break
}
require.NoError(t, err, "Receiving message from the stream")

if got.Status != nil && got.Status.Eot {
eotCount++
} else {
results = append(results, got.Results)
}
}

require.Equal(t, msgCount, eotCount, "Expected number of EOT messages")

for i := 0; i < msgCount; i++ {
for j := 0; j < 3; j++ {
expectedResults[i*3+j] = []*proto.MapResponse_Result{
{
Keys: []string{"client_test"},
Value: []byte(fmt.Sprintf("response_%d_for_test_%d", j, i)),
},
}
}
}

require.ElementsMatch(t, results, expectedResults)
}
Loading