diff --git a/pkg/batchmapper/service.go b/pkg/batchmapper/service.go index 1b8576aa..676e2349 100644 --- a/pkg/batchmapper/service.go +++ b/pkg/batchmapper/service.go @@ -44,14 +44,14 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error { for { datumStreamCh := make(chan Datum) - g, ctx := errgroup.WithContext(ctx) + g, groupCtx := errgroup.WithContext(ctx) g.Go(func() error { - return fs.receiveRequests(ctx, stream, datumStreamCh) + return fs.receiveRequests(groupCtx, stream, datumStreamCh) }) g.Go(func() error { - return fs.processData(ctx, stream, datumStreamCh) + return fs.processData(groupCtx, stream, datumStreamCh) }) // Wait for the goroutines to finish @@ -91,19 +91,39 @@ func (fs *Service) performHandshake(stream mappb.Map_MapFnServer) error { return nil } +// recvWithContext wraps stream.Recv() to respect context cancellation. +func recvWithContext(ctx context.Context, stream mappb.Map_MapFnServer) (*mappb.MapRequest, error) { + type recvResult struct { + req *mappb.MapRequest + err error + } + + resultCh := make(chan recvResult, 1) + go func() { + req, err := stream.Recv() + resultCh <- recvResult{req: req, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultCh: + return result.req, result.err + } +} + // receiveRequests receives the requests from the client and writes them to the datumStreamCh channel. func (fs *Service) receiveRequests(ctx context.Context, stream mappb.Map_MapFnServer, datumStreamCh chan<- Datum) error { defer close(datumStreamCh) for { - select { - case <-ctx.Done(): - return nil - default: + req, err := recvWithContext(ctx, stream) + if errors.Is(err, context.Canceled) { + log.Printf("Context cancelled, stopping the MapBatchFn") + return err } - req, err := stream.Recv() - if err == io.EOF { - log.Printf("end of batch map stream") + if errors.Is(err, io.EOF) { + log.Printf("EOF received, stopping the MapBatchFn") return err } if err != nil { @@ -123,8 +143,11 @@ func (fs *Service) receiveRequests(ctx context.Context, stream mappb.Map_MapFnSe watermark: req.GetRequest().GetWatermark().AsTime(), headers: req.GetRequest().GetHeaders(), } - - datumStreamCh <- datum + select { + case <-ctx.Done(): + return ctx.Err() + case datumStreamCh <- datum: + } } return nil } diff --git a/pkg/mapper/service.go b/pkg/mapper/service.go index a3fbdfae..bc1cb64d 100644 --- a/pkg/mapper/service.go +++ b/pkg/mapper/service.go @@ -2,6 +2,7 @@ package mapper import ( "context" + "errors" "fmt" "io" "log" @@ -35,6 +36,27 @@ func (fs *Service) IsReady(context.Context, *emptypb.Empty) (*mappb.ReadyRespons return &mappb.ReadyResponse{Ready: true}, nil } +// recvWithContext wraps stream.Recv() to respect context cancellation. +func recvWithContext(ctx context.Context, stream mappb.Map_MapFnServer) (*mappb.MapRequest, error) { + type recvResult struct { + req *mappb.MapRequest + err error + } + + resultCh := make(chan recvResult, 1) + go func() { + req, err := stream.Recv() + resultCh <- recvResult{req: req, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultCh: + return result.req, result.err + } +} + // MapFn applies a user defined function to each request element and returns a list of results. func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error { // perform handshake with client before processing requests @@ -72,13 +94,13 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error { // Read requests from the stream and process them outer: for { - select { - case <-groupCtx.Done(): + req, err := recvWithContext(groupCtx, stream) + if errors.Is(err, context.Canceled) { + log.Printf("Context cancelled, stopping the MapFn") break outer - default: } - req, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { + log.Printf("EOF received, stopping the MapFn") break outer } if err != nil { @@ -96,6 +118,7 @@ outer: // wait for all goroutines to finish if err := g.Wait(); err != nil { + log.Printf("Stopping the MapFn with err, %s", err) fs.shutdownCh <- struct{}{} return status.Errorf(codes.Internal, "error processing requests: %v", err) } diff --git a/pkg/mapstreamer/service.go b/pkg/mapstreamer/service.go index 80c2d52d..51968e14 100644 --- a/pkg/mapstreamer/service.go +++ b/pkg/mapstreamer/service.go @@ -2,6 +2,7 @@ package mapstreamer import ( "context" + "errors" "fmt" "io" "log" @@ -35,6 +36,27 @@ func (fs *Service) IsReady(context.Context, *emptypb.Empty) (*mappb.ReadyRespons return &mappb.ReadyResponse{Ready: true}, nil } +// recvWithContext wraps stream.Recv() to respect context cancellation. +func recvWithContext(ctx context.Context, stream mappb.Map_MapFnServer) (*mappb.MapRequest, error) { + type recvResult struct { + req *mappb.MapRequest + err error + } + + resultCh := make(chan recvResult, 1) + go func() { + req, err := stream.Recv() + resultCh <- recvResult{req: req, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultCh: + return result.req, result.err + } +} + // MapFn applies a function to each request element and streams the results back. func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error { // perform handshake with client before processing requests @@ -43,12 +65,17 @@ func (fs *Service) MapFn(stream mappb.Map_MapFnServer) error { } ctx := stream.Context() - for { - req, err := stream.Recv() - if err == io.EOF { + req, err := recvWithContext(ctx, stream) + if errors.Is(err, context.Canceled) { + log.Printf("Context cancelled, stopping the MapStreamFn") break } + if errors.Is(err, io.EOF) { + log.Printf("EOF received, stopping the MapStreamFn") + break + } + if err != nil { log.Printf("Failed to receive request: %v", err) return err diff --git a/pkg/sinker/service.go b/pkg/sinker/service.go index f1e3fb46..8c8c7f9a 100644 --- a/pkg/sinker/service.go +++ b/pkg/sinker/service.go @@ -83,14 +83,14 @@ func (fs *Service) SinkFn(stream sinkpb.Sink_SinkFnServer) error { for { datumStreamCh := make(chan Datum) - g, ctx := errgroup.WithContext(ctx) + g, groupCtx := errgroup.WithContext(ctx) g.Go(func() error { - return fs.receiveRequests(stream, datumStreamCh) + return fs.receiveRequests(groupCtx, stream, datumStreamCh) }) g.Go(func() error { - return fs.processData(ctx, stream, datumStreamCh) + return fs.processData(groupCtx, stream, datumStreamCh) }) // Wait for the goroutines to finish @@ -130,12 +130,33 @@ func (fs *Service) performHandshake(stream sinkpb.Sink_SinkFnServer) error { return nil } +// recvWithContext wraps stream.Recv() to respect context cancellation. +func recvWithContext(ctx context.Context, stream sinkpb.Sink_SinkFnServer) (*sinkpb.SinkRequest, error) { + type recvResult struct { + req *sinkpb.SinkRequest + err error + } + + resultCh := make(chan recvResult, 1) + go func() { + req, err := stream.Recv() + resultCh <- recvResult{req: req, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultCh: + return result.req, result.err + } +} + // receiveRequests receives the requests from the client writes them to the datumStreamCh channel. -func (fs *Service) receiveRequests(stream sinkpb.Sink_SinkFnServer, datumStreamCh chan<- Datum) error { +func (fs *Service) receiveRequests(ctx context.Context, stream sinkpb.Sink_SinkFnServer, datumStreamCh chan<- Datum) error { defer close(datumStreamCh) for { - req, err := stream.Recv() + req, err := recvWithContext(ctx, stream) if err == io.EOF { log.Printf("end of sink stream") return err @@ -158,7 +179,11 @@ func (fs *Service) receiveRequests(stream sinkpb.Sink_SinkFnServer, datumStreamC headers: req.GetRequest().GetHeaders(), } - datumStreamCh <- datum + select { + case <-ctx.Done(): + return nil + case datumStreamCh <- datum: + } } return nil } diff --git a/pkg/sourcer/examples/simple_source/go.mod b/pkg/sourcer/examples/simple_source/go.mod index 8d3e3836..f938de0f 100644 --- a/pkg/sourcer/examples/simple_source/go.mod +++ b/pkg/sourcer/examples/simple_source/go.mod @@ -5,7 +5,6 @@ go 1.21 replace github.com/numaproj/numaflow-go => ../../../.. require ( - github.com/google/uuid v1.6.0 github.com/numaproj/numaflow-go v0.8.1 github.com/stretchr/testify v1.9.0 ) @@ -16,6 +15,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect golang.org/x/net v0.29.0 // indirect + golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect golang.org/x/text v0.18.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect diff --git a/pkg/sourcer/examples/simple_source/go.sum b/pkg/sourcer/examples/simple_source/go.sum index 52a730fd..b7646a93 100644 --- a/pkg/sourcer/examples/simple_source/go.sum +++ b/pkg/sourcer/examples/simple_source/go.sum @@ -3,8 +3,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -17,6 +15,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= diff --git a/pkg/sourcer/service.go b/pkg/sourcer/service.go index 1e0fa61c..741cc28f 100644 --- a/pkg/sourcer/service.go +++ b/pkg/sourcer/service.go @@ -9,6 +9,7 @@ import ( "runtime/debug" "time" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" @@ -43,6 +44,8 @@ func (fs *Service) ReadFn(stream sourcepb.Source_ReadFnServer) error { if errors.Is(err, io.EOF) { return nil } + log.Printf("error processing requests: %v", err) + fs.shutdownCh <- struct{}{} return err } } @@ -76,19 +79,33 @@ func (fs *Service) performReadHandshake(stream sourcepb.Source_ReadFnServer) err return nil } +// recvWithContext wraps stream.Recv() to respect context cancellation for ReadFn. +func recvWithContextRead(ctx context.Context, stream sourcepb.Source_ReadFnServer) (*sourcepb.ReadRequest, error) { + type recvResult struct { + req *sourcepb.ReadRequest + err error + } + + resultCh := make(chan recvResult, 1) + go func() { + req, err := stream.Recv() + resultCh <- recvResult{req: req, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultCh: + return result.req, result.err + } +} + // receiveReadRequests receives read requests from the client and invokes the source Read method. // writes the read data to the message channel. func (fs *Service) receiveReadRequests(ctx context.Context, stream sourcepb.Source_ReadFnServer) error { messageCh := make(chan Message) - // handle panic - defer func() { - if r := recover(); r != nil { - log.Printf("panic inside source handler: %v %v", r, string(debug.Stack())) - fs.shutdownCh <- struct{}{} - } - }() - - req, err := stream.Recv() + eg, groupCtx := errgroup.WithContext(ctx) + req, err := recvWithContextRead(groupCtx, stream) if err == io.EOF { log.Printf("end of read stream") return err @@ -98,22 +115,44 @@ func (fs *Service) receiveReadRequests(ctx context.Context, stream sourcepb.Sour return err } - go func() { - defer close(messageCh) + eg.Go(func() (err error) { + // handle panic + defer func() { + if r := recover(); r != nil { + log.Printf("panic inside source handler: %v %v", r, string(debug.Stack())) + err = fmt.Errorf("panic inside source handler: %v", r) + return + } + close(messageCh) + }() request := readRequest{ count: req.Request.GetNumRecords(), timeout: time.Duration(req.Request.GetTimeoutInMs()) * time.Millisecond, } fs.Source.Read(ctx, &request, messageCh) - }() + return nil + }) // invoke the processReadData method to send the read data to the client. - return fs.processReadData(stream, messageCh) + eg.Go(func() error { + return fs.processReadData(groupCtx, stream, messageCh) + }) + + if err := eg.Wait(); err != nil { + return err + } + return nil } // processReadData processes the read data and sends it to the client. -func (fs *Service) processReadData(stream sourcepb.Source_ReadFnServer, messageCh <-chan Message) error { - for msg := range messageCh { +func (fs *Service) processReadData(ctx context.Context, stream sourcepb.Source_ReadFnServer, messageCh <-chan Message) error { + select { + case <-ctx.Done(): + return ctx.Err() + case msg, ok := <-messageCh: + if !ok { + break + } offset := &sourcepb.Offset{ Offset: msg.Offset().Value(), PartitionId: msg.Offset().PartitionId(), @@ -204,16 +243,38 @@ func (fs *Service) performAckHandshake(stream sourcepb.Source_AckFnServer) error return nil } +// recvWithContext wraps stream.Recv() to respect context cancellation for AckFn. +func recvWithContextAck(ctx context.Context, stream sourcepb.Source_AckFnServer) (*sourcepb.AckRequest, error) { + type recvResult struct { + req *sourcepb.AckRequest + err error + } + + resultCh := make(chan recvResult, 1) + go func() { + req, err := stream.Recv() + resultCh <- recvResult{req: req, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultCh: + return result.req, result.err + } +} + // receiveAckRequests receives ack requests from the client and invokes the source Ack method. -func (fs *Service) receiveAckRequests(ctx context.Context, stream sourcepb.Source_AckFnServer) error { +func (fs *Service) receiveAckRequests(ctx context.Context, stream sourcepb.Source_AckFnServer) (err error) { defer func() { if r := recover(); r != nil { log.Printf("panic inside source handler: %v %v", r, string(debug.Stack())) fs.shutdownCh <- struct{}{} + err = fmt.Errorf("panic inside source handler: %v", r) } }() - req, err := stream.Recv() + req, err := recvWithContextAck(ctx, stream) if err == io.EOF { log.Printf("end of ack stream") return err diff --git a/pkg/sourcetransformer/service.go b/pkg/sourcetransformer/service.go index c98ce5dd..e19690f3 100644 --- a/pkg/sourcetransformer/service.go +++ b/pkg/sourcetransformer/service.go @@ -40,6 +40,27 @@ func (fs *Service) IsReady(context.Context, *emptypb.Empty) (*v1.ReadyResponse, var errTransformerPanic = errors.New("transformer function panicked") +// recvWithContext wraps stream.Recv() to respect context cancellation. +func recvWithContext(ctx context.Context, stream v1.SourceTransform_SourceTransformFnServer) (*v1.SourceTransformRequest, error) { + type recvResult struct { + req *v1.SourceTransformRequest + err error + } + + resultCh := make(chan recvResult, 1) + go func() { + req, err := stream.Recv() + resultCh <- recvResult{req: req, err: err} + }() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case result := <-resultCh: + return result.req, result.err + } +} + // SourceTransformFn applies a function to each request element. // In addition to map function, SourceTransformFn also supports assigning a new event time to response. // SourceTransformFn can be used only at source vertex by source data transformer. @@ -76,15 +97,13 @@ func (fs *Service) SourceTransformFn(stream v1.SourceTransform_SourceTransformFn var readErr error outer: for { - select { - case <-groupCtx.Done(): - // Stop reading new messages when we are shutting down + d, err := recvWithContext(groupCtx, stream) + if errors.Is(err, context.Canceled) { + log.Printf("Context cancelled, stopping the SourceTransformFn") break outer - default: - // get out of select and process } - d, err := stream.Recv() - if err == io.EOF { + if errors.Is(err, io.EOF) { + log.Printf("EOF received, stopping the SourceTransformFn") break outer } if err != nil { @@ -102,9 +121,9 @@ outer: // wait for all the goroutines to finish, if any of the goroutines return an error, wait will return that error immediately. if err := grp.Wait(); err != nil { + log.Printf("Stopping the SourceTransformFn with err, %s", err) fs.shutdownCh <- struct{}{} - statusErr := status.Errorf(codes.Internal, err.Error()) - return statusErr + return status.Errorf(codes.Internal, err.Error()) } // check if there was an error while reading from the stream