Skip to content

Commit

Permalink
Unit test for transformer panick
Browse files Browse the repository at this point in the history
Signed-off-by: Sreekanth <[email protected]>
  • Loading branch information
BulkBeing committed Oct 1, 2024
1 parent 8ff0657 commit 92aa9d3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 22 deletions.
36 changes: 26 additions & 10 deletions pkg/sourcetransformer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ func (fs *Service) IsReady(context.Context, *emptypb.Empty) (*v1.ReadyResponse,
return &v1.ReadyResponse{Ready: true}, nil
}

// errTransformerPanic is the error a request-handling goroutine returns after
// it recovers from a panic raised while running the transformer.
var errTransformerPanic = errors.New("transformer function panicked")

// SourceTransformFn applies a function to each request element.
// In addition to map function, SourceTransformFn also supports assigning a new event time to response.
// SourceTransformFn can be used only at source vertex by source data transformer.
Expand Down Expand Up @@ -69,27 +71,36 @@ func (fs *Service) SourceTransformFn(stream v1.SourceTransform_SourceTransformFn
}

ctx := stream.Context()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// We depend on grpCtx to cancel all goroutines, since it is cancelled automatically when the first function returns a non-nil error.
// This error will be caught later with grp.Wait()
grp, grpCtx := errgroup.WithContext(ctx)

senderCh := make(chan *v1.SourceTransformResponse, 500) // FIXME: identify the right buffer size
// goroutine to send the response to the stream
grp.Go(func() error {
for {
var resp *v1.SourceTransformResponse
select {
case <-grpCtx.Done():
return grpCtx.Err()
default:
case resp = <-senderCh:
}
if err := stream.Send(<-senderCh); err != nil {
cancel()
return err
if err := stream.Send(resp); err != nil {
return fmt.Errorf("failed to send response to client: %w", err)
}
}
})

outer:
for {
// Stop reading new messages when we are shutting down
select {
case <-grpCtx.Done():
// If the context was cancelled while this loop is running, it will be caught and returned in one of the errgroup's goroutines.
break outer
default:
}

d, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
Expand All @@ -99,12 +110,17 @@ func (fs *Service) SourceTransformFn(stream v1.SourceTransform_SourceTransformFn
}

req := d.Request
grp.Go(func() error {
grp.Go(func() (err error) {
defer func() {
if r := recover(); r != nil {
log.Printf("panic inside source handler: %v %v", r, string(debug.Stack()))
cancel()
fs.shutdownCh <- struct{}{}
log.Printf("Panic inside source transform handler: %v %v", r, string(debug.Stack()))
// We only listen for 1 message on the shutdown channel. If multiple requests panic, only the first one will succeed.
// The one that succeeds returns the errTransformerPanic. This causes grpCtx to be cancelled.
select {
case fs.shutdownCh <- struct{}{}:
case <-grpCtx.Done():
}
err = errTransformerPanic
}
}()
var hd = NewHandlerDatum(req.GetValue(), req.EventTime.AsTime(), req.Watermark.AsTime(), req.Headers)
Expand Down
60 changes: 48 additions & 12 deletions pkg/sourcetransformer/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"

proto "github.com/numaproj/numaflow-go/pkg/apis/proto/sourcetransform/v1"
"google.golang.org/protobuf/types/known/timestamppb"
)

func newServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientConn {
func newTestServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientConn {
lis := bufconn.Listen(1024 * 1024)
t.Cleanup(func() {
_ = lis.Close()
Expand All @@ -43,12 +45,7 @@ func newServer(t *testing.T, register func(server *grpc.Server)) *grpc.ClientCon
return lis.Dial()
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
t.Cleanup(func() {
cancel()
})

conn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err := grpc.NewClient("passthrough://", grpc.WithContextDialer(dialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
t.Cleanup(func() {
_ = conn.Close()
})
Expand Down Expand Up @@ -170,21 +167,21 @@ func TestService_sourceTransformFn(t *testing.T) {
Transformer: tt.handler,
}

conn := newServer(t, func(server *grpc.Server) {
conn := newTestServer(t, func(server *grpc.Server) {
proto.RegisterSourceTransformServer(server, svc)
})

client := proto.NewSourceTransformClient(conn)
stream, err := client.SourceTransformFn(context.Background())
assert.NoError(t, err, "Creating stream")
require.NoError(t, err, "Creating stream")

doHandshake(t, stream)

err = stream.Send(tt.args.d)
assert.NoError(t, err, "Sending message over the stream")
require.NoError(t, err, "Sending message over the stream")

got, err := stream.Recv()
assert.NoError(t, err, "Receiving message from the stream")
require.NoError(t, err, "Receiving message from the stream")

assert.Equal(t, got.Results, tt.want.Results)
})
Expand Down Expand Up @@ -215,7 +212,7 @@ func TestService_SourceTransformFn_Multiple_Messages(t *testing.T) {
return MessagesBuilder().Append(NewMessage(msg, testTime).WithKeys([]string{keys[0] + "_test"}))
}),
}
conn := newServer(t, func(server *grpc.Server) {
conn := newTestServer(t, func(server *grpc.Server) {
proto.RegisterSourceTransformServer(server, svc)
})

Expand Down Expand Up @@ -260,3 +257,42 @@ func TestService_SourceTransformFn_Multiple_Messages(t *testing.T) {
}
require.ElementsMatch(t, results, expectedResults)
}

// TestService_SourceTransformFn_Panic verifies that when the transformer
// panics while handling a request, the client's next Recv fails with a gRPC
// status whose code is Internal and whose message is errTransformerPanic's text.
func TestService_SourceTransformFn_Panic(t *testing.T) {
	svc := &Service{
		Transformer: SourceTransformFunc(func(ctx context.Context, keys []string, datum Datum) Messages {
			panic("transformer panicked")
		}),
		// A panic in the transformer makes the server send a shutdown signal on shutdownCh.
		// The goroutine run by the errgroup stays blocked until that signal is received somewhere else.
		// Nothing listens for the shutdown signal in this test, so the channel is buffered to let
		// the server goroutine proceed.
		shutdownCh: make(chan<- struct{}, 1),
	}
	conn := newTestServer(t, func(server *grpc.Server) {
		proto.RegisterSourceTransformServer(server, svc)
	})

	client := proto.NewSourceTransformClient(conn)
	stream, err := client.SourceTransformFn(context.Background())
	require.NoError(t, err, "Creating stream")

	doHandshake(t, stream)

	msg := proto.SourceTransformRequest{
		Request: &proto.SourceTransformRequest_Request{
			Keys:      []string{"client"},
			Value:     []byte("test"),
			EventTime: timestamppb.New(time.Time{}),
			Watermark: timestamppb.New(time.Time{}),
		},
	}
	err = stream.Send(&msg)
	require.NoError(t, err, "Sending message over the stream")
	err = stream.CloseSend()
	require.NoError(t, err, "Closing the send direction of the stream")

	_, err = stream.Recv()
	require.Error(t, err, "Expected error while receiving message from the stream")

	// Check that the error really is a gRPC status instead of discarding the flag.
	gotStatus, ok := status.FromError(err)
	require.True(t, ok, "Expected a gRPC status error, got: %v", err)

	// Assert on code and message rather than deep-comparing Status values: a
	// status decoded from the wire wraps a protobuf message whose internal
	// bookkeeping is not meant to be compared with reflection-based equality.
	// This also avoids passing a non-constant string as a format to status.Errorf.
	require.Equal(t, codes.Internal, gotStatus.Code())
	require.Equal(t, errTransformerPanic.Error(), gotStatus.Message())
}

0 comments on commit 92aa9d3

Please sign in to comment.