Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Yashash H L <[email protected]>
  • Loading branch information
yhl25 committed Dec 19, 2024
1 parent 4f0515b commit 4dfa9f8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 44 deletions.
20 changes: 0 additions & 20 deletions pkg/mapstreamer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,6 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error {
// Use error group to manage goroutines, the groupCtx is cancelled when any of the
// goroutines return an error for the first time or the first time the wait returns.
g, groupCtx := errgroup.WithContext(ctx)

// Channel to collect responses
responseCh := make(chan *mappb.MapResponse, 500) // FIXME: identify the right buffer size
defer close(responseCh)

// Dedicated goroutine to send responses to the stream
g.Go(func() error {
for {
select {
case resp := <-responseCh:
if err := stream.Send(resp); err != nil {
log.Printf("Failed to send response: %v", err)
return err
}
case <-groupCtx.Done():
return groupCtx.Err()
}
}
})

var readErr error
// Read requests from the stream and process them
outer:
Expand Down
71 changes: 47 additions & 24 deletions pkg/mapstreamer/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
Expand Down Expand Up @@ -294,6 +295,23 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) {

expectedResults := make([][]*proto.MapResponse_Result, msgCount)
results := make([][]*proto.MapResponse_Result, 0)
eotCount := 0

for {
got, err := stream.Recv()
if err == io.EOF {
break
}
require.NoError(t, err, "Receiving message from the stream")

if got.Status != nil && got.Status.Eot {
eotCount++
} else {
results = append(results, got.Results)
}
}

require.Equal(t, msgCount, eotCount, "Expected number of EOT messages")

for i := 0; i < msgCount; i++ {
expectedResults[i] = []*proto.MapResponse_Result{
Expand All @@ -302,14 +320,6 @@ func TestService_MapFn_Multiple_Messages(t *testing.T) {
Value: []byte(fmt.Sprintf("test_%d", i)),
},
}

got, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
results = append(results, got.Results)

eot, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
require.True(t, eot.Status.Eot, "Expected EOT message")
}

require.ElementsMatch(t, results, expectedResults)
Expand Down Expand Up @@ -351,12 +361,14 @@ func TestService_MapFn_Panic(t *testing.T) {
require.Equal(t, expectedStatus, gotStatus)
}

func TestService_MapFn_MultipleRequestsTogether(t *testing.T) {
func TestService_MapFn_MultipleRequestsAndResponses(t *testing.T) {
svc := &Service{
MapperStream: MapStreamerFunc(func(ctx context.Context, keys []string, datum Datum, messageCh chan<- Message) {
defer close(messageCh)
msg := datum.Value()
messageCh <- NewMessage(msg).WithKeys([]string{keys[0] + "_test"})
for i := 0; i < 3; i++ { // Send multiple responses for each request
msg := fmt.Sprintf("response_%d_for_%s", i, string(datum.Value()))
messageCh <- NewMessage([]byte(msg)).WithKeys([]string{keys[0] + "_test"})
}
}),
}
conn := newTestServer(t, func(server *grpc.Server) {
Expand Down Expand Up @@ -385,24 +397,35 @@ func TestService_MapFn_MultipleRequestsTogether(t *testing.T) {
err = stream.CloseSend()
require.NoError(t, err, "Closing the send direction of the stream")

expectedResults := make([][]*proto.MapResponse_Result, msgCount)
expectedResults := make([][]*proto.MapResponse_Result, msgCount*3)
results := make([][]*proto.MapResponse_Result, 0)
eotCount := 0

for i := 0; i < msgCount; i++ {
expectedResults[i] = []*proto.MapResponse_Result{
{
Keys: []string{"client_test"},
Value: []byte(fmt.Sprintf("test_%d", i)),
},
}

for {
got, err := stream.Recv()
if err == io.EOF {
break
}
require.NoError(t, err, "Receiving message from the stream")
results = append(results, got.Results)

eot, err := stream.Recv()
require.NoError(t, err, "Receiving message from the stream")
require.True(t, eot.Status.Eot, "Expected EOT message")
if got.Status != nil && got.Status.Eot {
eotCount++
} else {
results = append(results, got.Results)
}
}

require.Equal(t, msgCount, eotCount, "Expected number of EOT messages")

for i := 0; i < msgCount; i++ {
for j := 0; j < 3; j++ {
expectedResults[i*3+j] = []*proto.MapResponse_Result{
{
Keys: []string{"client_test"},
Value: []byte(fmt.Sprintf("response_%d_for_test_%d", j, i)),
},
}
}
}

require.ElementsMatch(t, results, expectedResults)
Expand Down

0 comments on commit 4dfa9f8

Please sign in to comment.