Skip to content

Commit

Permalink
make ackFn bidirectional streaming
Browse files Browse the repository at this point in the history
Signed-off-by: Yashash H L <[email protected]>
  • Loading branch information
yhl25 committed Sep 18, 2024
1 parent ee2f308 commit 59b54dd
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 40 deletions.
38 changes: 19 additions & 19 deletions pkg/apis/proto/source/v1/source.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/apis/proto/source/v1/source.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ service Source {
// The caller (numa) expects the AckFn to be successful, and it does not expect any errors.
// If there are some irrecoverable errors when the callee (UDSource) is processing the AckFn request,
// then it is best to crash because there are no other retry mechanisms possible.
rpc AckFn(stream AckRequest) returns (AckResponse);
rpc AckFn(stream AckRequest) returns (stream AckResponse);

// PendingFn returns the number of pending records at the user defined source.
rpc PendingFn(google.protobuf.Empty) returns (PendingResponse);
Expand Down
12 changes: 5 additions & 7 deletions pkg/apis/proto/source/v1/source_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 16 additions & 5 deletions pkg/sourcer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,11 @@ func (fs *Service) ReadFn(stream sourcepb.Source_ReadFnServer) error {
// Receive read requests from the stream
req, err := stream.Recv()
if err == io.EOF {
log.Printf("end of read stream")
return
}
if err != nil {
log.Printf("error receiving from stream: %v", err)
log.Printf("error receiving from read stream: %v", err)
errCh <- err
return
}
Expand Down Expand Up @@ -172,19 +173,29 @@ func (fs *Service) AckFn(stream sourcepb.Source_AckFnServer) error {
// Receive ack requests from the stream
req, err := stream.Recv()
if err == io.EOF {
return stream.SendAndClose(&sourcepb.AckResponse{
Result: &sourcepb.AckResponse_Result{},
})
log.Printf("end of ack stream")
return nil
}
if err != nil {
log.Printf("error receiving from stream: %v", err)
log.Printf("error receiving from ack stream: %v", err)
return err
}

request := ackRequest{
offset: NewOffset(req.Request.Offset.GetOffset(), req.Request.Offset.GetPartitionId()),
}
fs.Source.Ack(ctx, &request)

// Send ack response
ackResponse := &sourcepb.AckResponse{
Result: &sourcepb.AckResponse_Result{
Success: &emptypb.Empty{},
},
}
if err := stream.Send(ackResponse); err != nil {
log.Printf("error sending ack response: %v", err)
return err
}
}
}

Expand Down
31 changes: 23 additions & 8 deletions pkg/sourcer/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ func (te *ReadFnServerErrTest) Context() context.Context {
}

type AckFnServerTest struct {
ctx context.Context
offsets []*sourcepb.Offset
ctx context.Context
offsets []*sourcepb.Offset
responses []*sourcepb.AckResponse
grpc.ServerStream
index int
}
Expand All @@ -136,20 +137,22 @@ func (a *AckFnServerTest) Recv() (*sourcepb.AckRequest, error) {
}, nil
}

func (a *AckFnServerTest) Send(response *sourcepb.AckResponse) error {
a.responses = append(a.responses, response)
return nil
}

func NewAckFnServerTest(
ctx context.Context,
offsets []*sourcepb.Offset,
) *AckFnServerTest {
return &AckFnServerTest{
ctx: ctx,
offsets: offsets,
ctx: ctx,
offsets: offsets,
responses: make([]*sourcepb.AckResponse, 0),
}
}

func (a *AckFnServerTest) SendAndClose(*sourcepb.AckResponse) error {
return nil
}

func (a *AckFnServerTest) Context() context.Context {
return a.ctx
}
Expand Down Expand Up @@ -269,6 +272,18 @@ func TestService_AckFn(t *testing.T) {

err := fs.AckFn(ackFnStream)
assert.NoError(t, err)

expectedResponses := []*sourcepb.AckResponse{
{
Result: &sourcepb.AckResponse_Result{
Success: &emptypb.Empty{},
},
},
}

if !reflect.DeepEqual(ackFnStream.responses, expectedResponses) {
t.Errorf("AckFn() responses = %v, want %v", ackFnStream.responses, expectedResponses)
}
}

func TestService_PendingFn(t *testing.T) {
Expand Down

0 comments on commit 59b54dd

Please sign in to comment.