diff --git a/src/main/java/io/numaproj/numaflow/sinker/Server.java b/src/main/java/io/numaproj/numaflow/sinker/Server.java index d1c23eac..53dbb1a2 100644 --- a/src/main/java/io/numaproj/numaflow/sinker/Server.java +++ b/src/main/java/io/numaproj/numaflow/sinker/Server.java @@ -75,7 +75,7 @@ public void start() throws Exception { log.info( "Server started, listening on {}", grpcConfig.isLocal() ? - "localhost:" + grpcConfig.getPort():grpcConfig.getSocketPath()); + "localhost:" + grpcConfig.getPort() : grpcConfig.getSocketPath()); // register shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(() -> { diff --git a/src/main/java/io/numaproj/numaflow/sinker/Service.java b/src/main/java/io/numaproj/numaflow/sinker/Service.java index ff04eadb..9c18ed97 100644 --- a/src/main/java/io/numaproj/numaflow/sinker/Service.java +++ b/src/main/java/io/numaproj/numaflow/sinker/Service.java @@ -1,6 +1,7 @@ package io.numaproj.numaflow.sinker; import com.google.protobuf.Empty; +import io.grpc.Status; import io.grpc.stub.StreamObserver; import io.numaproj.numaflow.sink.v1.SinkGrpc; import io.numaproj.numaflow.sink.v1.SinkOuterClass; @@ -8,14 +9,11 @@ import java.time.Instant; import java.util.List; -import java.util.concurrent.ExecutionException; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; -import static io.numaproj.numaflow.sink.v1.SinkGrpc.getSinkFnMethod; - @Slf4j class Service extends SinkGrpc.SinkImplBase { // sinkTaskExecutor is the executor for the sinker. It is a fixed size thread pool @@ -24,12 +22,6 @@ class Service extends SinkGrpc.SinkImplBase { private final ExecutorService sinkTaskExecutor = Executors .newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2); - // SHUTDOWN_TIME is the time to wait for the sinker to shut down, in seconds. - // We use 30 seconds as the default value because it provides a balance between giving tasks enough time to complete - // and not delaying program termination unduly. - private final long SHUTDOWN_TIME = 30; - - private final Sinker sinker; public Service(Sinker sinker) { @@ -41,25 +33,60 @@ public Service(Sinker sinker) { */ @Override public StreamObserver<SinkOuterClass.SinkRequest> sinkFn(StreamObserver<SinkOuterClass.SinkResponse> responseObserver) { - if (this.sinker == null) { - return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall( - getSinkFnMethod(), - responseObserver); - } + return new StreamObserver<>() { + private boolean startOfStream = true; + private CompletableFuture<ResponseList> result; + private DatumIteratorImpl datumStream; + private boolean handshakeDone = false; - DatumIteratorImpl datumStream = new DatumIteratorImpl(); + @Override + public void onNext(SinkOuterClass.SinkRequest request) { + // make sure the handshake is done before processing the messages + if (!handshakeDone) { + if (!request.hasHandshake() || !request.getHandshake().getSot()) { + responseObserver.onError(Status.INVALID_ARGUMENT + .withDescription("Handshake request not received") + .asException()); + return; + } + responseObserver.onNext(SinkOuterClass.SinkResponse.newBuilder() + .setHandshake(request.getHandshake()) + .build()); + handshakeDone = true; + return; + } - Future<ResponseList> result = sinkTaskExecutor.submit(() -> this.sinker.processMessages( - datumStream)); + // Create a DatumIterator to write the messages to the sinker + // and start the sinker if it is the start of the stream + if (startOfStream) { + datumStream = new DatumIteratorImpl(); + result = CompletableFuture.supplyAsync( + () -> sinker.processMessages(datumStream), + sinkTaskExecutor); + startOfStream = false; + } - return new StreamObserver<SinkOuterClass.SinkRequest>() { - @Override - public void onNext(SinkOuterClass.SinkRequest d) { try { - datumStream.writeMessage(constructHandlerDatum(d)); - } catch (InterruptedException e) { - Thread.interrupted(); - onError(e); + if (request.getStatus().getEot()) { + // End of transmission, write EOF datum to the stream + // Wait for the result and send the response back to the client + datumStream.writeMessage(HandlerDatum.EOF_DATUM); + + ResponseList responses = result.join(); + responses.getResponses().forEach(response -> { + SinkOuterClass.SinkResponse sinkResponse = buildResponse(response); + responseObserver.onNext(sinkResponse); + }); + + // reset the startOfStream flag, since the stream has ended + // so that the next request will be treated as the start of the stream + startOfStream = true; + } else { + datumStream.writeMessage(constructHandlerDatum(request)); + } + } catch (Exception e) { + log.error("Encountered error in sinkFn - {}", e.getMessage()); + responseObserver.onError(e); } } @@ -71,26 +98,23 @@ public void onError(Throwable throwable) { @Override public void onCompleted() { - SinkOuterClass.SinkResponse response = SinkOuterClass.SinkResponse - .newBuilder() - .build(); - try { - datumStream.writeMessage(HandlerDatum.EOF_DATUM); - // wait until the sink handler returns, result.get() is a blocking call - ResponseList responses = result.get(); - // construct responseList from responses - response = buildResponseList(responses); - - } catch (InterruptedException | ExecutionException e) { - e.printStackTrace(); - onError(e); - } - responseObserver.onNext(response); responseObserver.onCompleted(); } }; } + private SinkOuterClass.SinkResponse buildResponse(Response response) { + SinkOuterClass.Status status = response.getFallback() ? SinkOuterClass.Status.FALLBACK : + response.getSuccess() ? SinkOuterClass.Status.SUCCESS : SinkOuterClass.Status.FAILURE; + return SinkOuterClass.SinkResponse.newBuilder() + .setResult(SinkOuterClass.SinkResponse.Result.newBuilder() + .setId(response.getId() == null ? "" : response.getId()) + .setErrMsg(response.getErr() == null ? "" : response.getErr()) + .setStatus(status) + .build()) + .build(); + } + /** * IsReady is the heartbeat endpoint for gRPC. */ @@ -104,37 +128,28 @@ public void isReady( private HandlerDatum constructHandlerDatum(SinkOuterClass.SinkRequest d) { return new HandlerDatum( - d.getKeysList().toArray(new String[0]), - d.getValue().toByteArray(), + d.getRequest().getKeysList().toArray(new String[0]), + d.getRequest().getValue().toByteArray(), Instant.ofEpochSecond( - d.getWatermark().getSeconds(), - d.getWatermark().getNanos()), + d.getRequest().getWatermark().getSeconds(), + d.getRequest().getWatermark().getNanos()), Instant.ofEpochSecond( - d.getEventTime().getSeconds(), - d.getEventTime().getNanos()), - d.getId(), - d.getHeadersMap() + d.getRequest().getEventTime().getSeconds(), + d.getRequest().getEventTime().getNanos()), + d.getRequest().getId(), + d.getRequest().getHeadersMap() ); } - public SinkOuterClass.SinkResponse buildResponseList(ResponseList responses) { - var responseBuilder = SinkOuterClass.SinkResponse.newBuilder(); - responses.getResponses().forEach(response -> { - SinkOuterClass.Status status = response.getFallback() ? SinkOuterClass.Status.FALLBACK : - response.getSuccess() ? SinkOuterClass.Status.SUCCESS : SinkOuterClass.Status.FAILURE; - responseBuilder.addResults(SinkOuterClass.SinkResponse.Result.newBuilder() - .setId(response.getId() == null ? "" : response.getId()) - .setErrMsg(response.getErr() == null ? "" : response.getErr()) - .setStatus(status) - .build()); - }); - return responseBuilder.build(); - } - // shuts down the executor service which is used for reduce public void shutDown() { this.sinkTaskExecutor.shutdown(); try { + // SHUTDOWN_TIME is the time to wait for the sinker to shut down, in seconds. + // We use 30 seconds as the default value because it provides a balance between giving tasks enough time to complete + // and not delaying program termination unduly. + long SHUTDOWN_TIME = 30; + if (!sinkTaskExecutor.awaitTermination(SHUTDOWN_TIME, TimeUnit.SECONDS)) { log.error("Sink executor did not terminate in the specified time."); List<Runnable> droppedTasks = sinkTaskExecutor.shutdownNow(); diff --git a/src/main/java/io/numaproj/numaflow/sinker/SinkerTestKit.java b/src/main/java/io/numaproj/numaflow/sinker/SinkerTestKit.java index 0f976e27..78e0fa9a 100644 --- a/src/main/java/io/numaproj/numaflow/sinker/SinkerTestKit.java +++ b/src/main/java/io/numaproj/numaflow/sinker/SinkerTestKit.java @@ -113,11 +113,13 @@ public Client(String host, int port) { * @return response from the server as a ResponseList */ public ResponseList sendRequest(DatumIterator datumIterator) { - CompletableFuture<SinkOuterClass.SinkResponse> future = new CompletableFuture<>(); + List<SinkOuterClass.SinkResponse> responses = new ArrayList<>(); + CompletableFuture<List<SinkOuterClass.SinkResponse>> future = new CompletableFuture<>(); + StreamObserver<SinkOuterClass.SinkResponse> responseObserver = new StreamObserver<>() { @Override public void onNext(SinkOuterClass.SinkResponse response) { - future.complete(response); + responses.add(response); } @Override @@ -127,16 +129,19 @@ public void onError(Throwable t) { @Override public void onCompleted() { - if (!future.isDone()) { - future.completeExceptionally(new RuntimeException( - "Server completed without a response")); - } + future.complete(responses); } }; StreamObserver<SinkOuterClass.SinkRequest> requestObserver = sinkStub.sinkFn( responseObserver); + // send handshake request + requestObserver.onNext(SinkOuterClass.SinkRequest.newBuilder() + .setHandshake(SinkOuterClass.Handshake.newBuilder().setSot(true).build()) + .build()); + + // send actual requests while (true) { Datum datum = null; try { @@ -148,7 +153,8 @@ public void onCompleted() { if (datum == null) { break; } - SinkOuterClass.SinkRequest request = SinkOuterClass.SinkRequest.newBuilder() + SinkOuterClass.SinkRequest.Request request = SinkOuterClass.SinkRequest.Request + .newBuilder() .addAllKeys( datum.getKeys() == null ? new ArrayList<>() : List.of(datum.getKeys())) @@ -168,28 +174,39 @@ public void onCompleted() { .putAllHeaders( datum.getHeaders() == null ? new HashMap<>() : datum.getHeaders()) .build(); - requestObserver.onNext(request); + requestObserver.onNext(SinkOuterClass.SinkRequest + .newBuilder() + .setRequest(request) + .build()); } + // send end of transmission message + requestObserver.onNext(SinkOuterClass.SinkRequest.newBuilder().setStatus( + SinkOuterClass.SinkRequest.Status.newBuilder().setEot(true)).build()); requestObserver.onCompleted(); - SinkOuterClass.SinkResponse response; + List<SinkOuterClass.SinkResponse> outputResponses; try { - response = future.get(); + outputResponses = future.get(); } catch (Exception e) { throw new RuntimeException(e); } ResponseList.ResponseListBuilder responseListBuilder = ResponseList.newBuilder(); - for (SinkOuterClass.SinkResponse.Result result : response.getResultsList()) { - if (result.getStatus() == SinkOuterClass.Status.SUCCESS) { - responseListBuilder.addResponse(Response.responseOK(result.getId())); - } else if (result.getStatus() == SinkOuterClass.Status.FALLBACK) { + for (SinkOuterClass.SinkResponse result : outputResponses) { + if (result.getHandshake().getSot()) { + continue; + } + if (result.getResult().getStatus() == SinkOuterClass.Status.SUCCESS) { + responseListBuilder.addResponse(Response.responseOK(result + .getResult() + .getId())); + } else if (result.getResult().getStatus() == SinkOuterClass.Status.FALLBACK) { responseListBuilder.addResponse(Response.responseFallback( - result.getId())); + result.getResult().getId())); } else { responseListBuilder.addResponse(Response.responseFailure( - result.getId(), result.getErrMsg())); + result.getResult().getId(), result.getResult().getErrMsg())); } } diff --git a/src/main/java/io/numaproj/numaflow/sourcer/Service.java b/src/main/java/io/numaproj/numaflow/sourcer/Service.java index 58205a0a..90f6881f 100644 --- a/src/main/java/io/numaproj/numaflow/sourcer/Service.java +++ b/src/main/java/io/numaproj/numaflow/sourcer/Service.java @@ -30,18 +30,25 @@ public Service(Sourcer sourcer) { @Override public StreamObserver<SourceOuterClass.ReadRequest> readFn(final StreamObserver<SourceOuterClass.ReadResponse> responseObserver) { return new StreamObserver<>() { + private boolean handshakeDone = false; + @Override public void onNext(SourceOuterClass.ReadRequest request) { - // if the request is a handshake, send handshake response. - if (request.hasHandshake() && request.getHandshake().getSot()) { + // make sure that the handshake is done before processing the read requests + if (!handshakeDone) { + if (!request.hasHandshake() || !request.getHandshake().getSot()) { + responseObserver.onError(Status.INVALID_ARGUMENT + .withDescription("Handshake request not received") + .asException()); + return; + } responseObserver.onNext(SourceOuterClass.ReadResponse.newBuilder() .setHandshake(request.getHandshake()) - .setStatus(SourceOuterClass.ReadResponse.Status.newBuilder() - .setCode(SourceOuterClass.ReadResponse.Status.Code.SUCCESS) - .build()) .build()); + handshakeDone = true; return; } + ReadRequestImpl readRequest = new ReadRequestImpl( request.getRequest().getNumRecords(), Duration.ofMillis(request.getRequest().getTimeoutInMs())); @@ -89,16 +96,22 @@ public void onCompleted() { @Override public StreamObserver<SourceOuterClass.AckRequest> ackFn(final StreamObserver<SourceOuterClass.AckResponse> responseObserver) { return new StreamObserver<>() { + private boolean handshakeDone = false; + @Override public void onNext(SourceOuterClass.AckRequest request) { - // if the request is a handshake, send a handshake response - if (request.hasHandshake() && request.getHandshake().getSot()) { + // make sure that the handshake is done before processing the ack requests + if (!handshakeDone) { + if (!request.hasHandshake() || !request.getHandshake().getSot()) { + responseObserver.onError(Status.INVALID_ARGUMENT + .withDescription("Handshake request not received") + .asException()); + return; + } responseObserver.onNext(SourceOuterClass.AckResponse.newBuilder() .setHandshake(request.getHandshake()) - .setResult(SourceOuterClass.AckResponse.Result.newBuilder().setSuccess( - Empty.newBuilder().build())) .build()); - return; + handshakeDone = true; } SourceOuterClass.Offset offset = request.getRequest().getOffset(); diff --git a/src/main/proto/sink/v1/sink.proto b/src/main/proto/sink/v1/sink.proto index 1dde5c1e..f47d2e9e 100644 --- a/src/main/proto/sink/v1/sink.proto +++ b/src/main/proto/sink/v1/sink.proto @@ -9,7 +9,7 @@ package sink.v1; service Sink { // SinkFn writes the request to a user defined sink. - rpc SinkFn(stream SinkRequest) returns (SinkResponse); + rpc SinkFn(stream SinkRequest) returns (stream SinkResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); @@ -19,12 +19,32 @@ service Sink { * SinkRequest represents a request element. */ message SinkRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - string id = 5; - map<string, string> headers = 6; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + string id = 5; + map<string, string> headers = 6; + } + message Status { + bool eot = 1; + } + // Required field indicating the request. + Request request = 1; + // Required field indicating the status of the request. + // If eot is set to true, it indicates the end of transmission. + Status status = 2; + // optional field indicating the handshake message. + optional Handshake handshake = 3; +} + +/* + * Handshake message between client and server to indicate the start of transmission. + */ +message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; } /** @@ -55,5 +75,6 @@ message SinkResponse { // err_msg is the error message, set it if success is set to false. string err_msg = 3; } - repeated Result results = 1; + Result result = 1; + optional Handshake handshake = 2; } diff --git a/src/test/java/io/numaproj/numaflow/sinker/ServerErrTest.java b/src/test/java/io/numaproj/numaflow/sinker/ServerErrTest.java index 158b1355..b8e484fb 100644 --- a/src/test/java/io/numaproj/numaflow/sinker/ServerErrTest.java +++ b/src/test/java/io/numaproj/numaflow/sinker/ServerErrTest.java @@ -19,6 +19,8 @@ import java.util.List; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @Slf4j @@ -60,7 +62,7 @@ public void tearDown() throws Exception { } @Test - public void sinkerSuccess() throws InterruptedException { + public void sinkerException() { //create an output stream observer SinkOutputStreamObserver outputStreamObserver = new SinkOutputStreamObserver(); @@ -83,6 +85,11 @@ public void sinkerSuccess() throws InterruptedException { .sinkFn(outputStreamObserver); String actualId = "sink_test_id"; + // send handshake request + inputStreamObserver.onNext(SinkOuterClass.SinkRequest.newBuilder() + .setHandshake(SinkOuterClass.Handshake.newBuilder().setSot(true).build()) + .build()); + for (int i = 1; i <= 100; i++) { String[] keys; if (i < 100) { @@ -90,14 +97,22 @@ public void sinkerSuccess() throws InterruptedException { } else { keys = new String[]{"invalid-key"}; } - SinkOuterClass.SinkRequest sinkRequest = SinkOuterClass.SinkRequest.newBuilder() + SinkOuterClass.SinkRequest.Request request = SinkOuterClass.SinkRequest.Request + .newBuilder() .setValue(ByteString.copyFromUtf8(String.valueOf(i))) .setId(actualId) .addAllKeys(List.of(keys)) .build(); - inputStreamObserver.onNext(sinkRequest); + inputStreamObserver.onNext(SinkOuterClass.SinkRequest + .newBuilder() + .setRequest(request) + .build()); } + // send eot message + inputStreamObserver.onNext(SinkOuterClass.SinkRequest.newBuilder() + .setStatus(SinkOuterClass.SinkRequest.Status.newBuilder().setEot(true)).build()); + inputStreamObserver.onCompleted(); try { @@ -107,13 +122,40 @@ public void sinkerSuccess() throws InterruptedException { } } + @Test + public void sinkerNoHandshake() { + // Create an output stream observer + SinkOutputStreamObserver outputStreamObserver = new SinkOutputStreamObserver(); + + StreamObserver<SinkOuterClass.SinkRequest> inputStreamObserver = SinkGrpc + .newStub(inProcessChannel) + .sinkFn(outputStreamObserver); + + // Send a request without sending a handshake request + SinkOuterClass.SinkRequest request = SinkOuterClass.SinkRequest.newBuilder() + .setRequest(SinkOuterClass.SinkRequest.Request.newBuilder() + .setValue(ByteString.copyFromUtf8("test")) + .setId("test_id") + .addKeys("test_key") + .build()) + .build(); + inputStreamObserver.onNext(request); + + // Wait for the server to process the request + while (!outputStreamObserver.completed.get()) ; + + // Check if an error was received + assertNotNull(outputStreamObserver.t); + assertEquals( + "INVALID_ARGUMENT: Handshake request not received", + outputStreamObserver.t.getMessage()); + } + @Slf4j private static class TestSinkFnErr extends Sinker { - @Override public ResponseList processMessages(DatumIterator datumIterator) { throw new RuntimeException("unknown exception"); } - } } diff --git a/src/test/java/io/numaproj/numaflow/sinker/ServerTest.java b/src/test/java/io/numaproj/numaflow/sinker/ServerTest.java index 5292ad4a..65043c75 100644 --- a/src/test/java/io/numaproj/numaflow/sinker/ServerTest.java +++ b/src/test/java/io/numaproj/numaflow/sinker/ServerTest.java @@ -20,6 +20,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; @Slf4j @RunWith(JUnit4.class) @@ -72,6 +73,12 @@ public void sinkerSuccess() { String actualId = "sink_test_id"; String expectedId = actualId + processedIdSuffix; + // Send a handshake request + SinkOuterClass.SinkRequest handshakeRequest = SinkOuterClass.SinkRequest.newBuilder() + .setHandshake(SinkOuterClass.Handshake.newBuilder().setSot(true).build()) + .build(); + inputStreamObserver.onNext(handshakeRequest); + for (int i = 1; i <= 100; i++) { String[] keys; if (i < 100) { @@ -79,38 +86,55 @@ public void sinkerSuccess() { } else { keys = new String[]{"invalid-key"}; } - SinkOuterClass.SinkRequest sinkRequest = SinkOuterClass.SinkRequest.newBuilder() + + SinkOuterClass.SinkRequest.Request request = SinkOuterClass.SinkRequest.Request + .newBuilder() .setValue(ByteString.copyFromUtf8(String.valueOf(i))) .setId(actualId) .addAllKeys(List.of(keys)) .build(); + + SinkOuterClass.SinkRequest sinkRequest = SinkOuterClass.SinkRequest.newBuilder() + .setRequest(request).build(); inputStreamObserver.onNext(sinkRequest); + + // If it's the end of the batch, send an EOT message + if (i % 10 == 0) { + SinkOuterClass.SinkRequest eotRequest = SinkOuterClass.SinkRequest.newBuilder() + .setStatus(SinkOuterClass.SinkRequest.Status + .newBuilder() + .setEot(true) + .build()) + .build(); + inputStreamObserver.onNext(eotRequest); + } } inputStreamObserver.onCompleted(); while (!outputStreamObserver.completed.get()) ; - SinkOuterClass.SinkResponse responseList = outputStreamObserver.getSinkResponse(); - assertEquals(100, responseList.getResultsCount()); - responseList.getResultsList().forEach((response -> { - assertEquals(response.getId(), expectedId); - })); - - assertEquals( - responseList.getResults(responseList.getResultsCount() - 1).getErrMsg(), - "error message"); + List<SinkOuterClass.SinkResponse> responseList = outputStreamObserver.getSinkResponse(); + assertEquals(101, responseList.size()); + // first response is the handshake response + assertTrue(responseList.get(0).getHandshake().getSot()); + + responseList = responseList.subList(1, responseList.size()); + responseList.forEach(response -> { + assertEquals(response.getResult().getId(), expectedId); + if (response.getResult().getStatus() == SinkOuterClass.Status.FAILURE) { + assertEquals(response.getResult().getErrMsg(), "error message"); + } + }); } @Slf4j private static class TestSinkFn extends Sinker { - - @Override public ResponseList processMessages(DatumIterator datumIterator) { ResponseList.ResponseListBuilder builder = ResponseList.newBuilder(); while (true) { - Datum datum = null; + Datum datum; try { datum = datumIterator.next(); } catch (InterruptedException e) { diff --git a/src/test/java/io/numaproj/numaflow/sinker/SinkOutputStreamObserver.java b/src/test/java/io/numaproj/numaflow/sinker/SinkOutputStreamObserver.java index e9979494..0422d9e9 100644 --- a/src/test/java/io/numaproj/numaflow/sinker/SinkOutputStreamObserver.java +++ b/src/test/java/io/numaproj/numaflow/sinker/SinkOutputStreamObserver.java @@ -4,25 +4,28 @@ import io.grpc.stub.StreamObserver; import io.numaproj.numaflow.sink.v1.SinkOuterClass; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.atomic.AtomicReference; public class SinkOutputStreamObserver implements StreamObserver<SinkOuterClass.SinkResponse> { + private final List<SinkOuterClass.SinkResponse> sinkResponses = new ArrayList<>(); public AtomicReference<Boolean> completed = new AtomicReference<>(false); public Throwable t; - private SinkOuterClass.SinkResponse sinkResponse; - public SinkOuterClass.SinkResponse getSinkResponse() { - return sinkResponse; + public List<SinkOuterClass.SinkResponse> getSinkResponse() { + return sinkResponses; } @Override public void onNext(SinkOuterClass.SinkResponse datum) { - sinkResponse = datum; + sinkResponses.add(datum); } @Override public void onError(Throwable throwable) { t = throwable; + this.completed.set(true); } @Override diff --git a/src/test/java/io/numaproj/numaflow/sourcer/AckOutputStreamObserver.java b/src/test/java/io/numaproj/numaflow/sourcer/AckOutputStreamObserver.java new file mode 100644 index 00000000..b88a9622 --- /dev/null +++ b/src/test/java/io/numaproj/numaflow/sourcer/AckOutputStreamObserver.java @@ -0,0 +1,35 @@ +package io.numaproj.numaflow.sourcer; + + +import io.grpc.stub.StreamObserver; +import io.numaproj.numaflow.source.v1.SourceOuterClass; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +public class AckOutputStreamObserver implements StreamObserver<SourceOuterClass.AckResponse> { + private final List<SourceOuterClass.AckResponse> ackResponses = new ArrayList<>(); + public AtomicReference<Boolean> completed = new AtomicReference<>(false); + public Throwable t; + + public List<SourceOuterClass.AckResponse> getSinkResponse() { + return ackResponses; + } + + @Override + public void onNext(SourceOuterClass.AckResponse datum) { + ackResponses.add(datum); + } + + @Override + public void onError(Throwable throwable) { + t = throwable; + this.completed.set(true); + } + + @Override + public void onCompleted() { + this.completed.set(true); + } +} diff --git a/src/test/java/io/numaproj/numaflow/sourcer/ReadOutputStreamObserver.java b/src/test/java/io/numaproj/numaflow/sourcer/ReadOutputStreamObserver.java new file mode 100644 index 00000000..06342764 --- /dev/null +++ b/src/test/java/io/numaproj/numaflow/sourcer/ReadOutputStreamObserver.java @@ -0,0 +1,35 @@ +package io.numaproj.numaflow.sourcer; + + +import io.grpc.stub.StreamObserver; +import io.numaproj.numaflow.source.v1.SourceOuterClass; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +public class ReadOutputStreamObserver implements StreamObserver<SourceOuterClass.ReadResponse> { + private final List<SourceOuterClass.ReadResponse> readResponses = new ArrayList<>(); + public AtomicReference<Boolean> completed = new AtomicReference<>(false); + public Throwable t; + + public List<SourceOuterClass.ReadResponse> getSinkResponse() { + return readResponses; + } + + @Override + public void onNext(SourceOuterClass.ReadResponse datum) { + readResponses.add(datum); + } + + @Override + public void onError(Throwable throwable) { + t = throwable; + this.completed.set(true); + } + + @Override + public void onCompleted() { + this.completed.set(true); + } +} diff --git a/src/test/java/io/numaproj/numaflow/sourcer/ServerErrTest.java b/src/test/java/io/numaproj/numaflow/sourcer/ServerErrTest.java index 23281fb2..13b43c06 100644 --- a/src/test/java/io/numaproj/numaflow/sourcer/ServerErrTest.java +++ b/src/test/java/io/numaproj/numaflow/sourcer/ServerErrTest.java @@ -22,6 +22,8 @@ import java.util.List; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; public class ServerErrTest { @@ -136,6 +138,58 @@ public void onCompleted() { readRequestObserver.onCompleted(); } + @Test + public void sourceWithoutAckHandshake() { + // Create an output stream observer + AckOutputStreamObserver outputStreamObserver = new AckOutputStreamObserver(); + + StreamObserver<SourceOuterClass.AckRequest> inputStreamObserver = SourceGrpc + .newStub(inProcessChannel) + .ackFn(outputStreamObserver); + + // Send a request without sending a handshake request + SourceOuterClass.AckRequest request = SourceOuterClass.AckRequest.newBuilder() + .setRequest(SourceOuterClass.AckRequest.Request.newBuilder() + .build()) + .build(); + inputStreamObserver.onNext(request); + + // Wait for the server to process the request + while (!outputStreamObserver.completed.get()) ; + + // Check if an error was received + assertNotNull(outputStreamObserver.t); + assertEquals( + "INVALID_ARGUMENT: Handshake request not received", + outputStreamObserver.t.getMessage()); + } + + @Test + public void sourceWithoutReadHandshake() { + // Create an output stream observer + ReadOutputStreamObserver outputStreamObserver = new ReadOutputStreamObserver(); + + StreamObserver<SourceOuterClass.ReadRequest> inputStreamObserver = SourceGrpc + .newStub(inProcessChannel) + .readFn(outputStreamObserver); + + // Send a request without sending a handshake request + SourceOuterClass.ReadRequest request = SourceOuterClass.ReadRequest.newBuilder() + .setRequest(SourceOuterClass.ReadRequest.Request.newBuilder() + .build()) + .build(); + inputStreamObserver.onNext(request); + + // Wait for the server to process the request + while (!outputStreamObserver.completed.get()) ; + + // Check if an error was received + assertNotNull(outputStreamObserver.t); + assertEquals( + "INVALID_ARGUMENT: Handshake request not received", + outputStreamObserver.t.getMessage()); + } + private static class TestSourcerErr extends Sourcer { @Override diff --git a/src/test/java/io/numaproj/numaflow/sourcer/ServerTest.java b/src/test/java/io/numaproj/numaflow/sourcer/ServerTest.java index fb9139c3..1690d3af 100644 --- a/src/test/java/io/numaproj/numaflow/sourcer/ServerTest.java +++ b/src/test/java/io/numaproj/numaflow/sourcer/ServerTest.java @@ -86,6 +86,7 @@ public void TestSourcer() { int count = 0; boolean handshake = false; boolean eot = false; + @Override public void onNext(SourceOuterClass.ReadResponse readResponse) { // Handle handshake response @@ -135,6 +136,7 @@ public void onCompleted() { StreamObserver<SourceOuterClass.AckRequest> ackRequestObserver = stub.ackFn(new StreamObserver<>() { boolean handshake = false; int count = 0; + @Override public void onNext(SourceOuterClass.AckResponse ackResponse) { if (ackResponse.hasHandshake() && ackResponse.getHandshake().getSot()) {