Skip to content

Commit

Permalink
Close stream before closing gRPC connection
Browse files Browse the repository at this point in the history
Signed-off-by: Sreekanth <[email protected]>
  • Loading branch information
BulkBeing committed Sep 29, 2024
1 parent a320559 commit c7ed77f
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 38 deletions.
18 changes: 7 additions & 11 deletions pkg/isb/tracker/message_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,18 @@ type MessageTracker struct {
}

// NewTracker initializes a new instance of a Tracker
func NewTracker() *MessageTracker {
func NewTracker(messages []*isb.ReadMessage) *MessageTracker {
m := make(map[string]*isb.ReadMessage, len(messages))
for _, msg := range messages {
id := msg.ReadOffset.String()
m[id] = msg
}
return &MessageTracker{
m: make(map[string]*isb.ReadMessage),
m: m,
lock: sync.RWMutex{},
}
}

// Add add a new entry for a given message to the Tracker.
// the key is chosen as the read offset of the message
func (t *MessageTracker) Add(msg *isb.ReadMessage) {
id := msg.ReadOffset.String()
t.lock.Lock()
defer t.lock.Unlock()
t.m[id] = msg
}

// Remove will remove the entry for a given id and return the stored value corresponding to this id.
// A `nil` return value indicates that the id doesn't exist in the tracker.
func (t *MessageTracker) Remove(id string) *isb.ReadMessage {
Expand Down
15 changes: 9 additions & 6 deletions pkg/isb/tracker/message_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,30 @@ import (

"github.com/stretchr/testify/assert"

"github.com/numaproj/numaflow/pkg/isb"
"github.com/numaproj/numaflow/pkg/isb/testutils"
)

func TestTracker_AddRequest(t *testing.T) {
tr := NewTracker()
readMessages := testutils.BuildTestReadMessages(3, time.Unix(1661169600, 0), nil)
for _, msg := range readMessages {
tr.Add(&msg)
messages := make([]*isb.ReadMessage, len(readMessages))
for i, msg := range readMessages {
messages[i] = &msg
}
tr := NewTracker(messages)
id := readMessages[0].ReadOffset.String()
m := tr.Remove(id)
assert.NotNil(t, m)
assert.Equal(t, readMessages[0], *m)
}

func TestTracker_RemoveRequest(t *testing.T) {
tr := NewTracker()
readMessages := testutils.BuildTestReadMessages(3, time.Unix(1661169600, 0), nil)
for _, msg := range readMessages {
tr.Add(&msg)
messages := make([]*isb.ReadMessage, len(readMessages))
for i, msg := range readMessages {
messages[i] = &msg
}
tr := NewTracker(messages)
id := readMessages[0].ReadOffset.String()
m := tr.Remove(id)
assert.NotNil(t, m)
Expand Down
26 changes: 18 additions & 8 deletions pkg/sdkclient/sourcetransformer/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ func NewFromClient(ctx context.Context, c transformpb.SourceTransformClient) (Cl

// CloseConn closes the grpc client connection.
func (c *client) CloseConn(_ context.Context) error {
err := c.stream.CloseSend()
if err != nil {
return err
}
if c.conn == nil {
return nil
}
Expand All @@ -145,16 +149,22 @@ func (c *client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) {

// SourceTransformFn SourceTransformerFn applies a function to each request element.
func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transformpb.SourceTransformRequest) (<-chan *transformpb.SourceTransformResponse, <-chan error) {
errCh := make(chan error)
clientErrCh := make(chan error)
responseCh := make(chan *transformpb.SourceTransformResponse)

// This channel is to send the error from the goroutine that receives messages from the stream to the goroutine that sends requests to the server.
// This ensures that we don't need to use clientErrCh in both goroutines. The caller of this function will only be listening for the first error value in clientErrCh.
// If both goroutines were sending error message to this channel (eg. stream failure), one of them will be stuck in sending can not shutdown cleanly.
errCh := make(chan error, 1)

// Receive responses from the stream
go func() {
defer close(responseCh)
for {
resp, err := c.stream.Recv()
if err != nil {
// we don't need an EOF check because we only close the stream during shutdown.
errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn", err)
close(errCh)
return
}

Expand All @@ -169,7 +179,10 @@ func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transfor
for {
select {
case <-ctx.Done():
errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", ctx.Err())
clientErrCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", ctx.Err())
return
case err := <-errCh:
clientErrCh <- err
return
case msg, ok := <-request:
if !ok {
Expand All @@ -178,15 +191,12 @@ func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transfor
}
err := c.stream.Send(msg)
if err != nil {
select {
case <-ctx.Done():
case errCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err):
}
clientErrCh <- sdkerr.ToUDFErr("c.grpcClt.SourceTransformFn stream.Send", err)
return
}
}
}
}()

return responseCh, errCh
return responseCh, clientErrCh
}
3 changes: 1 addition & 2 deletions pkg/sources/transformer/grpc_transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,11 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i
inputChan := make(chan *v1.SourceTransformRequest)
respChan, errChan := u.client.SourceTransformFn(ctx, inputChan)

msgTracker := tracker.NewTracker()
msgTracker := tracker.NewTracker(messages)

go func() {
defer close(inputChan)
for _, msg := range messages {
msgTracker.Add(msg)
req := &v1.SourceTransformRequest{
Request: &v1.SourceTransformRequest_Request{
Keys: msg.Keys,
Expand Down
2 changes: 1 addition & 1 deletion pkg/sources/transformer/grpc_transformer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func TestGRPCBasedTransformer_BasicApplyWithServer(t *testing.T) {

expectedUDFErr := &rpc.ApplyUDFErr{
UserUDFErr: false,
Message: "gRPC client.SourceTransformFn failed, Canceled: context canceled",
Message: "gRPC client.SourceTransformFn failed, NonRetryable: context canceled",
InternalErr: rpc.InternalErr{
Flag: true,
MainCarDown: false,
Expand Down
13 changes: 3 additions & 10 deletions pkg/udf/rpc/grpc_batch_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,14 @@ import (

// GRPCBasedBatchMap is a map applier that uses gRPC client to invoke the map UDF. It implements the applier.MapApplier interface.
type GRPCBasedBatchMap struct {
vertexName string
client batchmapper.Client
requestTracker *tracker.MessageTracker
vertexName string
client batchmapper.Client
}

func NewUDSgRPCBasedBatchMap(vertexName string, client batchmapper.Client) *GRPCBasedBatchMap {
return &GRPCBasedBatchMap{
vertexName: vertexName,
client: client,
// requestTracker is used to store the read messages in a key, value manner where
// key is the read offset and the reference to read message as the value.
// Once the results are received from the UDF, we map the responses to the corresponding request
// using a lookup on this Tracker.
requestTracker: tracker.NewTracker(),
}
}

Expand Down Expand Up @@ -95,7 +89,7 @@ func (u *GRPCBasedBatchMap) ApplyBatchMap(ctx context.Context, messages []*isb.R
// key is the read offset and the reference to read message as the value.
// Once the results are received from the UDF, we map the responses to the corresponding request
// using a lookup on this Tracker.
trackerReq := tracker.NewTracker()
trackerReq := tracker.NewTracker(messages)

// Read routine: this goroutine iterates over the input messages and sends each
// of the read messages to the grpc client after transforming it to a BatchMapRequest.
Expand All @@ -105,7 +99,6 @@ func (u *GRPCBasedBatchMap) ApplyBatchMap(ctx context.Context, messages []*isb.R
go func() {
defer close(inputChan)
for _, msg := range messages {
trackerReq.Add(msg)
inputChan <- u.parseInputRequest(msg)
}
}()
Expand Down

0 comments on commit c7ed77f

Please sign in to comment.