From 92aa9d38dce6f48691b26a11236cd928de5ff607 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Tue, 1 Oct 2024 07:03:17 +0530 Subject: [PATCH] Unit test for transformer panic Signed-off-by: Sreekanth --- pkg/sourcetransformer/service.go | 36 +++++++++++----- pkg/sourcetransformer/service_test.go | 60 +++++++++++++++++++++------ 2 files changed, 74 insertions(+), 22 deletions(-) diff --git a/pkg/sourcetransformer/service.go b/pkg/sourcetransformer/service.go index 2b749012..bbd91f19 100644 --- a/pkg/sourcetransformer/service.go +++ b/pkg/sourcetransformer/service.go @@ -38,6 +38,8 @@ func (fs *Service) IsReady(context.Context, *emptypb.Empty) (*v1.ReadyResponse, return &v1.ReadyResponse{Ready: true}, nil } +var errTransformerPanic = errors.New("transformer function panicked") + // SourceTransformFn applies a function to each request element. // In addition to map function, SourceTransformFn also supports assigning a new event time to response. // SourceTransformFn can be used only at source vertex by source data transformer. @@ -69,27 +71,36 @@ func (fs *Service) SourceTransformFn(stream v1.SourceTransform_SourceTransformFn } ctx := stream.Context() - ctx, cancel := context.WithCancel(ctx) - defer cancel() + // We depend on grpCtx to cancel all goroutines, since it will be automatically cancelled when the first function returns a non-nil error. 
+ // This error will be caught later with grp.Wait() grp, grpCtx := errgroup.WithContext(ctx) senderCh := make(chan *v1.SourceTransformResponse, 500) // FIXME: identify the right buffer size // goroutine to send the response to the stream grp.Go(func() error { for { + var resp *v1.SourceTransformResponse select { case <-grpCtx.Done(): return grpCtx.Err() - default: + case resp = <-senderCh: } - if err := stream.Send(<-senderCh); err != nil { - cancel() - return err + if err := stream.Send(resp); err != nil { + return fmt.Errorf("failed to send response to client: %w", err) } } }) +outer: for { + // Stop reading new messages when we are shutting down + select { + case <-grpCtx.Done(): + // If the context was cancelled while this loop is running, it will be caught and returned in one of the errgroup's goroutines. + break outer + default: + } + d, err := stream.Recv() if err != nil { if errors.Is(err, io.EOF) { @@ -99,12 +110,17 @@ func (fs *Service) SourceTransformFn(stream v1.SourceTransform_SourceTransformFn } req := d.Request - grp.Go(func() error { + grp.Go(func() (err error) { defer func() { if r := recover(); r != nil { - log.Printf("panic inside source handler: %v %v", r, string(debug.Stack())) - cancel() - fs.shutdownCh <- struct{}{} + log.Printf("Panic inside source transform handler: %v %v", r, string(debug.Stack())) + // We only listen for 1 message on the shutdown channel. If multiple requests panic, only the first one will succeed. + // The one that succeeds returns the errTransformerPanic. This causes grpCtx to be cancelled. 
+ select { + case fs.shutdownCh <- struct{}{}: + case <-grpCtx.Done(): + } + err = errTransformerPanic } }() var hd = NewHandlerDatum(req.GetValue(), req.EventTime.AsTime(), req.Watermark.AsTime(), req.Headers) diff --git a/pkg/sourcetransformer/service_test.go b/pkg/sourcetransformer/service_test.go index b9f89167..79449c9a 100644 --- a/pkg/sourcetransformer/service_test.go +++ b/pkg/sourcetransformer/service_test.go @@ -11,14 +11,16 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" proto "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1" "google.golang.org/protobuf/types/known/timestamppb" ) -func newServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientConn { +func newTestServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientConn { lis := bufconn.Listen(1024 * 1024) t.Cleanup(func() { _ = lis.Close() @@ -43,12 +45,7 @@ func newServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientCon return lis.Dial() } - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - t.Cleanup(func() { - cancel() - }) - - conn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials())) + conn, err := grpc.NewClient("passthrough://", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials())) t.Cleanup(func() { _ = conn.Close() }) @@ -170,21 +167,21 @@ func TestService_sourceTransformFn(t *testing.T) { Transformer: tt.handler, } - conn := newServer(t, func(server *grpc.Server) { + conn := newTestServer(t, func(server *grpc.Server) { proto.RegisterSourceTransformServer(server, svc) }) client := proto.NewSourceTransformClient(conn) stream, err := client.SourceTransformFn(context.Background()) - 
assert.NoError(t, err, "Creating stream") + require.NoError(t, err, "Creating stream") doHandshake(t, stream) err = stream.Send(tt.args.d) - assert.NoError(t, err, "Sending message over the stream") + require.NoError(t, err, "Sending message over the stream") got, err := stream.Recv() - assert.NoError(t, err, "Receiving message from the stream") + require.NoError(t, err, "Receiving message from the stream") assert.Equal(t, got.Results, tt.want.Results) }) @@ -215,7 +212,7 @@ func TestService_SourceTransformFn_Multiple_Messages(t *testing.T) { return MessagesBuilder().Append(NewMessage(msg, testTime).WithKeys([]string{keys[0] + "_test"})) }), } - conn := newServer(t, func(server *grpc.Server) { + conn := newTestServer(t, func(server *grpc.Server) { proto.RegisterSourceTransformServer(server, svc) }) @@ -260,3 +257,42 @@ func TestService_SourceTransformFn_Multiple_Messages(t *testing.T) { } require.ElementsMatch(t, results, expectedResults) } + +func TestService_SourceTransformFn_Panic(t *testing.T) { + svc := &Service{ + Transformer: SourceTransformFunc(func(ctx context.Context, keys []string, datum Datum) Messages { + panic("transformer panicked") + }), + // panic in the transformer causes the server to send a shutdown signal to shutdownCh channel. + // The function that errgroup runs in a goroutine will be blocked until this shutdown signal is received somewhere else. + // Since we don't listen for shutdown signal in the tests, we use buffered channel to unblock the server function. 
+ shutdownCh: make(chan<- struct{}, 1), + } + conn := newTestServer(t, func(server *grpc.Server) { + proto.RegisterSourceTransformServer(server, svc) + }) + + client := proto.NewSourceTransformClient(conn) + stream, err := client.SourceTransformFn(context.Background()) + require.NoError(t, err, "Creating stream") + + doHandshake(t, stream) + + msg := proto.SourceTransformRequest{ + Request: &proto.SourceTransformRequest_Request{ + Keys: []string{"client"}, + Value: []byte("test"), + EventTime: timestamppb.New(time.Time{}), + Watermark: timestamppb.New(time.Time{}), + }, + } + err = stream.Send(&msg) + require.NoError(t, err, "Sending message over the stream") + err = stream.CloseSend() + require.NoError(t, err, "Closing the send direction of the stream") + _, err = stream.Recv() + require.Error(t, err, "Expected error while receiving message from the stream") + gotStatus, _ := status.FromError(err) + expectedStatus := status.Convert(status.Errorf(codes.Internal, errTransformerPanic.Error())) + require.Equal(t, expectedStatus, gotStatus) +}