From ab1c9b4018cd374068b21dbe4740b0a6b95f5368 Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Mon, 23 Sep 2024 12:18:22 +0530 Subject: [PATCH] refactor source and sink Signed-off-by: Yashash H L --- pkg/sinker/service.go | 164 ++++++++++++----------- pkg/sourcer/service.go | 290 ++++++++++++++++++++++------------------- 2 files changed, 245 insertions(+), 209 deletions(-) diff --git a/pkg/sinker/service.go b/pkg/sinker/service.go index dc76b583..60405e71 100644 --- a/pkg/sinker/service.go +++ b/pkg/sinker/service.go @@ -75,7 +75,36 @@ func (fs *Service) IsReady(context.Context, *emptypb.Empty) (*sinkpb.ReadyRespon // SinkFn applies a sink function to a every element. func (fs *Service) SinkFn(stream sinkpb.Sink_SinkFnServer) error { ctx := stream.Context() + // Perform handshake before entering the main loop + if err := fs.performHandshake(stream); err != nil { + return err + } + + for { + datumStreamCh := make(chan Datum) + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + return fs.receiveRequests(stream, datumStreamCh) + }) + + g.Go(func() error { + return fs.processData(ctx, stream, datumStreamCh) + }) + + // Wait for the goroutines to finish + if err := g.Wait(); err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err + } + } +} + +// performHandshake performs the handshake with the client. +func (fs *Service) performHandshake(stream sinkpb.Sink_SinkFnServer) error { req, err := stream.Recv() if err != nil { log.Printf("error receiving handshake from stream: %v", err) @@ -86,7 +115,6 @@ func (fs *Service) SinkFn(stream sinkpb.Sink_SinkFnServer) error { return fmt.Errorf("expected handshake message") } - // Send handshake response handshakeResponse := &sinkpb.SinkResponse{ Result: &sinkpb.SinkResponse_Result{ Status: sinkpb.Status_SUCCESS, @@ -99,86 +127,72 @@ func (fs *Service) SinkFn(stream sinkpb.Sink_SinkFnServer) error { return err } + return nil +} + +// receiveRequests receives the requests from the client writes them to the datumStreamCh channel. +func (fs *Service) receiveRequests(stream sinkpb.Sink_SinkFnServer, datumStreamCh chan<- Datum) error { + defer close(datumStreamCh) + defer func() { + if r := recover(); r != nil { + log.Printf("panic inside sink handler: %v %v", r, string(debug.Stack())) + fs.shutdownCh <- struct{}{} + } + }() + for { - datumStreamCh := make(chan Datum) - g, ctx := errgroup.WithContext(ctx) + req, err := stream.Recv() + if err == io.EOF { + log.Printf("end of sink stream") + return err + } + if err != nil { + log.Printf("error receiving from sink stream: %v", err) + return err + } - g.Go(func() error { - defer close(datumStreamCh) - defer func() { - if r := recover(); r != nil { - log.Printf("panic inside sink handler: %v %v", r, string(debug.Stack())) - fs.shutdownCh <- struct{}{} - } - }() - - for { - // Receive sink requests from the stream - req, err := stream.Recv() - if err == io.EOF { - log.Printf("end of sink stream") - return err - } - if err != nil { - log.Printf("error receiving from sink stream: %v", err) - return err - } - - if req.Status != nil && req.Status.Eot { - // End of transmission, break to start a new sink invocation - break - } - - datum := &handlerDatum{ - id: req.GetRequest().GetId(), - value: req.GetRequest().GetValue(), - keys: req.GetRequest().GetKeys(), - eventTime: req.GetRequest().GetEventTime().AsTime(), - watermark: req.GetRequest().GetWatermark().AsTime(), - headers: req.GetRequest().GetHeaders(), - } - - // Send datum to the channel - datumStreamCh <- datum - } - return nil - }) + if req.Status != nil && req.Status.Eot { + break + } - // invoke the sink function, and send the responses back to the client - g.Go(func() error { - responses := fs.Sinker.Sink(ctx, datumStreamCh) - for _, response := range responses { - var status sinkpb.Status - if response.Fallback { - status = sinkpb.Status_FALLBACK - } else if response.Success { - status = sinkpb.Status_SUCCESS - } else { - status = sinkpb.Status_FAILURE - } - - sinkResponse := &sinkpb.SinkResponse{ - Result: &sinkpb.SinkResponse_Result{ - Id: response.ID, - Status: status, - ErrMsg: response.Err, - }, - } - if err := stream.Send(sinkResponse); err != nil { - log.Printf("error sending sink response: %v", err) - return err - } - } - return nil - }) + datum := &handlerDatum{ + id: req.GetRequest().GetId(), + value: req.GetRequest().GetValue(), + keys: req.GetRequest().GetKeys(), + eventTime: req.GetRequest().GetEventTime().AsTime(), + watermark: req.GetRequest().GetWatermark().AsTime(), + headers: req.GetRequest().GetHeaders(), + } - // Wait for the goroutines to finish - err := g.Wait() - if errors.Is(err, io.EOF) { - return nil + datumStreamCh <- datum + } + return nil +} + +// processData invokes the sinker to process the data and sends the response back to the client. +func (fs *Service) processData(ctx context.Context, stream sinkpb.Sink_SinkFnServer, datumStreamCh chan Datum) error { + responses := fs.Sinker.Sink(ctx, datumStreamCh) + for _, response := range responses { + var status sinkpb.Status + if response.Fallback { + status = sinkpb.Status_FALLBACK + } else if response.Success { + status = sinkpb.Status_SUCCESS + } else { + status = sinkpb.Status_FAILURE } - if err != nil { + + sinkResponse := &sinkpb.SinkResponse{ + Result: &sinkpb.SinkResponse_Result{ + Id: response.ID, + Status: status, + ErrMsg: response.Err, + }, + } + if err := stream.Send(sinkResponse); err != nil { + log.Printf("error sending sink response: %v", err) return err } } + return nil } diff --git a/pkg/sourcer/service.go b/pkg/sourcer/service.go index 9fc6ee60..d3f0fe7c 100644 --- a/pkg/sourcer/service.go +++ b/pkg/sourcer/service.go @@ -2,13 +2,14 @@ package sourcer import ( "context" + "errors" "fmt" "io" "log" "runtime/debug" - "sync" "time" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" @@ -32,119 +33,119 @@ type Service struct { // ReadFn reads the data from the source. func (fs *Service) ReadFn(stream sourcepb.Source_ReadFnServer) error { ctx := stream.Context() - errCh := make(chan error, 1) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - // handle panic - defer func() { - if r := recover(); r != nil { - log.Printf("panic inside source handler: %v %v", r, string(debug.Stack())) - fs.shutdownCh <- struct{}{} - errCh <- fmt.Errorf("panic: %v", r) + if err := fs.performReadHandshake(stream); err != nil { + return err + } + + for { + if err := fs.receiveReadRequests(ctx, stream); err != nil { + // If the error is EOF, it means the stream has been closed. + if errors.Is(err, io.EOF) { + return nil } - }() - - // Do a handshake with the client before starting to read - req, err := stream.Recv() - if err != nil { - log.Printf("error receiving handshake from stream: %v", err) - errCh <- err - return + return err } + } +} - if req.Handshake == nil || !req.Handshake.Sot { - errCh <- fmt.Errorf("expected handshake message") - return - } +// performReadHandshake performs the handshake with the client before starting the read process. +func (fs *Service) performReadHandshake(stream sourcepb.Source_ReadFnServer) error { + req, err := stream.Recv() + if err != nil { + log.Printf("error receiving handshake from stream: %v", err) + return err + } - // Send handshake response - handshakeResponse := &sourcepb.ReadResponse{ - Status: &sourcepb.ReadResponse_Status{ - Eot: false, - Code: sourcepb.ReadResponse_Status_SUCCESS, - }, - Handshake: &sourcepb.Handshake{ - Sot: true, - }, - } - if err := stream.Send(handshakeResponse); err != nil { - errCh <- err - return + if req.Handshake == nil || !req.Handshake.Sot { + return fmt.Errorf("expected handshake message") + } + + handshakeResponse := &sourcepb.ReadResponse{ + Status: &sourcepb.ReadResponse_Status{ + Eot: false, + Code: sourcepb.ReadResponse_Status_SUCCESS, + }, + Handshake: &sourcepb.Handshake{ + Sot: true, + }, + } + if err := stream.Send(handshakeResponse); err != nil { + return err + } + + return nil +} + +// receiveReadRequests receives read requests from the client and invokes the source Read method. +// writes the read data to the message channel. +func (fs *Service) receiveReadRequests(ctx context.Context, stream sourcepb.Source_ReadFnServer) error { + messageCh := make(chan Message) + // handle panic + defer func() { + if r := recover(); r != nil { + log.Printf("panic inside source handler: %v %v", r, string(debug.Stack())) + fs.shutdownCh <- struct{}{} } + }() - for { - // Receive read requests from the stream - req, err := stream.Recv() - if err == io.EOF { - log.Printf("end of read stream") - return - } - if err != nil { - log.Printf("error receiving from read stream: %v", err) - errCh <- err - return - } + req, err := stream.Recv() + if err == io.EOF { + log.Printf("end of read stream") + return err + } + if err != nil { + log.Printf("error receiving from read stream: %v", err) + return err + } - messageCh := make(chan Message) - go func() { - defer close(messageCh) - request := readRequest{ - count: req.Request.GetNumRecords(), - timeout: time.Duration(req.Request.GetTimeoutInMs()) * time.Millisecond, - } - fs.Source.Read(ctx, &request, messageCh) - }() - - // Read messages from the channel and send them to the stream. - for msg := range messageCh { - offset := &sourcepb.Offset{ - Offset: msg.Offset().Value(), - PartitionId: msg.Offset().PartitionId(), - } - element := &sourcepb.ReadResponse{ - Result: &sourcepb.ReadResponse_Result{ - Payload: msg.Value(), - Offset: offset, - EventTime: timestamppb.New(msg.EventTime()), - Keys: msg.Keys(), - Headers: msg.Headers(), - }, - Status: &sourcepb.ReadResponse_Status{ - Eot: false, - Code: 0, - }, - } - // The error here is returned by the stream, which is already a gRPC error - if err := stream.Send(element); err != nil { - errCh <- err - return - } - } - err = stream.Send(&sourcepb.ReadResponse{ - Status: &sourcepb.ReadResponse_Status{ - Eot: true, - Code: 0, - }, - }) - if err != nil { - errCh <- err - return - } + go func() { + defer close(messageCh) + request := readRequest{ + count: req.Request.GetNumRecords(), + timeout: time.Duration(req.Request.GetTimeoutInMs()) * time.Millisecond, } + fs.Source.Read(ctx, &request, messageCh) }() - // Wait for the goroutine to finish - wg.Wait() - // Check if there was any error in the goroutine - select { - case err := <-errCh: + // invoke the processReadData method to send the read data to the client. + return fs.processReadData(stream, messageCh) +} + +// processReadData processes the read data and sends it to the client. +func (fs *Service) processReadData(stream sourcepb.Source_ReadFnServer, messageCh <-chan Message) error { + for msg := range messageCh { + offset := &sourcepb.Offset{ + Offset: msg.Offset().Value(), + PartitionId: msg.Offset().PartitionId(), + } + element := &sourcepb.ReadResponse{ + Result: &sourcepb.ReadResponse_Result{ + Payload: msg.Value(), + Offset: offset, + EventTime: timestamppb.New(msg.EventTime()), + Keys: msg.Keys(), + Headers: msg.Headers(), + }, + Status: &sourcepb.ReadResponse_Status{ + Eot: false, + Code: 0, + }, + } + if err := stream.Send(element); err != nil { + return err + } + } + err := stream.Send(&sourcepb.ReadResponse{ + Status: &sourcepb.ReadResponse_Status{ + Eot: true, + Code: 0, + }, + }) + if err != nil { return err - default: - return nil } + return nil } // ackRequest implements the AckRequest interface and is used in the ack handler. @@ -161,16 +162,29 @@ func (a *ackRequest) Offset() Offset { func (fs *Service) AckFn(stream sourcepb.Source_AckFnServer) error { ctx := stream.Context() - // handle panic - defer func() { - if r := recover(); r != nil { - log.Printf("panic inside source handler: %v %v", r, string(debug.Stack())) - fs.shutdownCh <- struct{}{} + if err := fs.performAckHandshake(stream); err != nil { + return err + } + + for { + g, ctx := errgroup.WithContext(ctx) + + g.Go(func() error { + return fs.receiveAckRequests(ctx, stream) + }) + + // Wait for the goroutines to finish + if err := g.Wait(); err != nil { + if errors.Is(err, io.EOF) { + return nil + } + return err } - }() + } +} - // Do a handshake with the client before starting to ack - // first message should be a handshake +// performAckHandshake performs the handshake with the client before starting the ack process. +func (fs *Service) performAckHandshake(stream sourcepb.Source_AckFnServer) error { req, err := stream.Recv() if err != nil { log.Printf("error receiving handshake from stream: %v", err) @@ -181,7 +195,6 @@ func (fs *Service) AckFn(stream sourcepb.Source_AckFnServer) error { return fmt.Errorf("expected handshake message") } - // Send handshake response handshakeResponse := &sourcepb.AckResponse{ Result: &sourcepb.AckResponse_Result{ Success: &emptypb.Empty{}, @@ -194,34 +207,43 @@ func (fs *Service) AckFn(stream sourcepb.Source_AckFnServer) error { return err } - for { - // Receive ack requests from the stream - req, err := stream.Recv() - if err == io.EOF { - log.Printf("end of ack stream") - return nil - } - if err != nil { - log.Printf("error receiving from ack stream: %v", err) - return err - } + return nil +} - request := ackRequest{ - offset: NewOffset(req.Request.Offset.GetOffset(), req.Request.Offset.GetPartitionId()), +// receiveAckRequests receives ack requests from the client and invokes the source Ack method. +func (fs *Service) receiveAckRequests(ctx context.Context, stream sourcepb.Source_AckFnServer) error { + defer func() { + if r := recover(); r != nil { + log.Printf("panic inside source handler: %v %v", r, string(debug.Stack())) + fs.shutdownCh <- struct{}{} } - fs.Source.Ack(ctx, &request) + }() - // Send ack response - ackResponse := &sourcepb.AckResponse{ - Result: &sourcepb.AckResponse_Result{ - Success: &emptypb.Empty{}, - }, - } - if err := stream.Send(ackResponse); err != nil { - log.Printf("error sending ack response: %v", err) - return err - } + req, err := stream.Recv() + if err == io.EOF { + log.Printf("end of ack stream") + return err + } + if err != nil { + log.Printf("error receiving from ack stream: %v", err) + return err + } + + request := ackRequest{ + offset: NewOffset(req.Request.Offset.GetOffset(), req.Request.Offset.GetPartitionId()), + } + fs.Source.Ack(ctx, &request) + + ackResponse := &sourcepb.AckResponse{ + Result: &sourcepb.AckResponse_Result{ + Success: &emptypb.Empty{}, + }, + } + if err := stream.Send(ackResponse); err != nil { + log.Printf("error sending ack response: %v", err) + return err } + return nil } // IsReady returns true to indicate the gRPC connection is ready.