From 4dfa9f835c3dc7ee93b59c449aaba134ca142500 Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Thu, 19 Dec 2024 15:24:20 +0530 Subject: [PATCH] fix tests Signed-off-by: Yashash H L --- pkg/mapstreamer/service.go | 20 ---------- pkg/mapstreamer/service_test.go | 71 ++++++++++++++++++++++----------- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/pkg/mapstreamer/service.go b/pkg/mapstreamer/service.go index c8ac4b4a..8795ff08 100644 --- a/pkg/mapstreamer/service.go +++ b/pkg/mapstreamer/service.go @@ -70,26 +70,6 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error { // Use error group to manage goroutines, the groupCtx is cancelled when any of the // goroutines return an error for the first time or the first time the wait returns. g, groupCtx := errgroup.WithContext(ctx) - - // Channel to collect responses - responseCh := make(chan *mappb.MapResponse, 500) // FIXME: identify the right buffer size - defer close(responseCh) - - // Dedicated goroutine to send responses to the stream - g.Go(func() error { - for { - select { - case resp := <-responseCh: - if err := stream.Send(resp); err != nil { - log.Printf("Failed to send response: %v", err) - return err - } - case <-groupCtx.Done(): - return groupCtx.Err() - } - } - }) - var readErr error // Read requests from the stream and process them outer: diff --git a/pkg/mapstreamer/service_test.go b/pkg/mapstreamer/service_test.go index fade1fd1..fd5b35e9 100644 --- a/pkg/mapstreamer/service_test.go +++ b/pkg/mapstreamer/service_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "testing" "time" @@ -294,6 +295,23 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) { expectedResults := make([][]*proto.MapResponse_Result, msgCount) results := make([][]*proto.MapResponse_Result, 0) + eotCount := 0 + + for { + got, err := stream.Recv() + if err == io.EOF { + break + } + require.NoError(t, err, "Receiving message from the stream") + + if got.Status != nil && got.Status.Eot { + eotCount++ + } else { + results = append(results, got.Results) + } + } + + require.Equal(t, msgCount, eotCount, "Expected number of EOT messages") for i := 0; i < msgCount; i++ { expectedResults[i] = []*proto.MapResponse_Result{ @@ -302,14 +320,6 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) { Value: []byte(fmt.Sprintf("test_%d", i)), }, } - - got, err := stream.Recv() - require.NoError(t, err, "Receiving message from the stream") - results = append(results, got.Results) - - eot, err := stream.Recv() - require.NoError(t, err, "Receiving message from the stream") - require.True(t, eot.Status.Eot, "Expected EOT message") } require.ElementsMatch(t, results, expectedResults) @@ -351,12 +361,14 @@ func TestService_MapFn_Panic(t *testing.T) { require.Equal(t, expectedStatus, gotStatus) } -func TestService_MapFn_MultipleRequestsTogether(t *testing.T) { +func TestService_MapFn_MultipleRequestsAndResponses(t *testing.T) { svc := &Service{ MapperStream: MapStreamerFunc(func(ctx context.Context, keys []string, datum Datum, messageCh chan<- Message) { defer close(messageCh) - msg := datum.Value() - messageCh <- NewMessage(msg).WithKeys([]string{keys[0] + "_test"}) + for i := 0; i < 3; i++ { // Send multiple responses for each request + msg := fmt.Sprintf("response_%d_for_%s", i, string(datum.Value())) + messageCh <- NewMessage([]byte(msg)).WithKeys([]string{keys[0] + "_test"}) + } }), } conn := newTestServer(t, func(server *grpc.Server) { @@ -385,24 +397,35 @@ func TestService_MapFn_MultipleRequestsTogether(t *testing.T) { err = stream.CloseSend() require.NoError(t, err, "Closing the send direction of the stream") - expectedResults := make([][]*proto.MapResponse_Result, msgCount) + expectedResults := make([][]*proto.MapResponse_Result, msgCount*3) results := make([][]*proto.MapResponse_Result, 0) + eotCount := 0 - for i := 0; i < msgCount; i++ { - expectedResults[i] = []*proto.MapResponse_Result{ - { - Keys: []string{"client_test"}, - Value: []byte(fmt.Sprintf("test_%d", i)), - }, - } - + for { got, err := stream.Recv() + if err == io.EOF { + break + } require.NoError(t, err, "Receiving message from the stream") - results = append(results, got.Results) - eot, err := stream.Recv() - require.NoError(t, err, "Receiving message from the stream") - require.True(t, eot.Status.Eot, "Expected EOT message") + if got.Status != nil && got.Status.Eot { + eotCount++ + } else { + results = append(results, got.Results) + } + } + + require.Equal(t, msgCount, eotCount, "Expected number of EOT messages") + + for i := 0; i < msgCount; i++ { + for j := 0; j < 3; j++ { + expectedResults[i*3+j] = []*proto.MapResponse_Result{ + { + Keys: []string{"client_test"}, + Value: []byte(fmt.Sprintf("response_%d_for_test_%d", j, i)), + }, + } + } } require.ElementsMatch(t, results, expectedResults)