diff --git a/pkg/isb/tracker/message_tracker.go b/pkg/isb/tracker/message_tracker.go index c8cac63c07..dfd608e5bf 100644 --- a/pkg/isb/tracker/message_tracker.go +++ b/pkg/isb/tracker/message_tracker.go @@ -14,8 +14,8 @@ type MessageTracker struct { m map[string]*isb.ReadMessage } -// NewTracker initializes a new instance of a Tracker -func NewTracker(messages []*isb.ReadMessage) *MessageTracker { +// NewMessageTracker initializes a new instance of a Tracker +func NewMessageTracker(messages []*isb.ReadMessage) *MessageTracker { m := make(map[string]*isb.ReadMessage, len(messages)) for _, msg := range messages { id := msg.ReadOffset.String() @@ -47,3 +47,10 @@ func (t *MessageTracker) IsEmpty() bool { defer t.lock.RUnlock() return len(t.m) == 0 } + +// Len returns the number of messages currently stored in the tracker +func (t *MessageTracker) Len() int { + t.lock.RLock() + defer t.lock.RUnlock() + return len(t.m) +} diff --git a/pkg/isb/tracker/message_tracker_test.go b/pkg/isb/tracker/message_tracker_test.go index bd33d03c47..3c2ae767d0 100644 --- a/pkg/isb/tracker/message_tracker_test.go +++ b/pkg/isb/tracker/message_tracker_test.go @@ -16,7 +16,7 @@ func TestTracker_AddRequest(t *testing.T) { for i, msg := range readMessages { messages[i] = &msg } - tr := NewTracker(messages) + tr := NewMessageTracker(messages) id := readMessages[0].ReadOffset.String() m := tr.Remove(id) assert.NotNil(t, m) @@ -29,7 +29,7 @@ func TestTracker_RemoveRequest(t *testing.T) { for i, msg := range readMessages { messages[i] = &msg } - tr := NewTracker(messages) + tr := NewMessageTracker(messages) id := readMessages[0].ReadOffset.String() m := tr.Remove(id) assert.NotNil(t, m) diff --git a/pkg/sdkclient/sourcetransformer/client.go b/pkg/sdkclient/sourcetransformer/client.go index 179a3e6f8a..7f3327ac5f 100644 --- a/pkg/sdkclient/sourcetransformer/client.go +++ b/pkg/sdkclient/sourcetransformer/client.go @@ -64,7 +64,7 @@ waitUntilReady: for { select { case <-ctx.Done(): - return nil, fmt.Errorf("waiting for transformer gRPC server to be ready: %w", context.Cause(ctx)) + return nil, fmt.Errorf("waiting for transformer gRPC server to be ready: %w", ctx.Err()) default: _, err := c.IsReady(ctx, &emptypb.Empty{}) if err != nil { @@ -148,6 +148,7 @@ func (c *client) IsReady(ctx context.Context, in *emptypb.Empty) (bool, error) { } // SourceTransformFn SourceTransformerFn applies a function to each request element. +// Response channel will not be closed. Caller can select on response and error channel to exit on first error. func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transformpb.SourceTransformRequest) (<-chan *transformpb.SourceTransformResponse, <-chan error) { clientErrCh := make(chan error) responseCh := make(chan *transformpb.SourceTransformResponse) @@ -157,6 +158,8 @@ func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transfor // If both goroutines were sending error message to this channel (eg. stream failure), one of them will be stuck in sending can not shutdown cleanly. errCh := make(chan error, 1) + logger := logging.FromContext(ctx) + // Receive responses from the stream go func() { for { @@ -170,6 +173,8 @@ func (c *client) SourceTransformFn(ctx context.Context, request <-chan *transfor select { case <-ctx.Done(): + logger.Warnf("Context cancelled. Stopping retrieving messages from the stream") + return case responseCh <- resp: } } diff --git a/pkg/sources/transformer/grpc_transformer.go b/pkg/sources/transformer/grpc_transformer.go index edbef1e955..e578687fb3 100644 --- a/pkg/sources/transformer/grpc_transformer.go +++ b/pkg/sources/transformer/grpc_transformer.go @@ -81,7 +81,9 @@ func (u *GRPCBasedTransformer) ApplyTransform(ctx context.Context, messages []*i inputChan := make(chan *v1.SourceTransformRequest) respChan, errChan := u.client.SourceTransformFn(ctx, inputChan) - msgTracker := tracker.NewTracker(messages) + logger := logging.FromContext(ctx) + + msgTracker := tracker.NewMessageTracker(messages) go func() { defer close(inputChan) @@ -117,6 +119,7 @@ loop: return nil, err case resp, ok := <-respChan: if !ok { + logger.Warn("Response channel from source transform client was closed.") break loop } msgId := resp.GetId() @@ -169,7 +172,7 @@ loop: } if !msgTracker.IsEmpty() { - return nil, errors.New("transform response for all requests were not received from UDF") + return nil, fmt.Errorf("transform response for all requests were not received from UDF. Remaining=%d", msgTracker.Len()) } return transformResults, nil } diff --git a/pkg/udf/rpc/grpc_batch_map.go b/pkg/udf/rpc/grpc_batch_map.go index efd2ba1e57..ce65d201fb 100644 --- a/pkg/udf/rpc/grpc_batch_map.go +++ b/pkg/udf/rpc/grpc_batch_map.go @@ -89,7 +89,7 @@ func (u *GRPCBasedBatchMap) ApplyBatchMap(ctx context.Context, messages []*isb.R // key is the read offset and the reference to read message as the value. // Once the results are received from the UDF, we map the responses to the corresponding request // using a lookup on this Tracker. - trackerReq := tracker.NewTracker(messages) + trackerReq := tracker.NewMessageTracker(messages) // Read routine: this goroutine iterates over the input messages and sends each // of the read messages to the grpc client after transforming it to a BatchMapRequest. @@ -140,7 +140,7 @@ loop: // This means that either the UDF added an incorrect ID // This cannot be processed further and should result in an error // Can there be another case for this? - logger.Error("Request missing from Tracker, ", msgId) + logger.Error("Request missing from message tracker, ", msgId) return nil, fmt.Errorf("incorrect ID found during batch map processing") } // parse the responses received diff --git a/rust/monovertex/src/transformer.rs b/rust/monovertex/src/transformer.rs index a8b0e4d878..8f7cc05091 100644 --- a/rust/monovertex/src/transformer.rs +++ b/rust/monovertex/src/transformer.rs @@ -40,17 +40,10 @@ impl SourceTransformer { Error::TransformerError(format!("failed to send handshake request: {}", e)) })?; - tracing::info!("Sending stream request"); - let transform_fn_with_timeout = tokio::time::timeout( - Duration::from_secs(3), - client.source_transform_fn(Request::new(read_stream)), - ); - let Ok(resp_stream) = transform_fn_with_timeout.await else { - return Err(Error::TransformerError( - "connection to transformer gRPC server timed out".to_string(), - )); - }; - let mut resp_stream = resp_stream?.into_inner(); + let mut resp_stream = client + .source_transform_fn(Request::new(read_stream)) + .await? + .into_inner(); // first response from the server will be the handshake response. We need to check if the // server has accepted the handshake.