diff --git a/src/main/java/io/numaproj/numaflow/batchmapper/Service.java b/src/main/java/io/numaproj/numaflow/batchmapper/Service.java index 3f436941..1774e218 100644 --- a/src/main/java/io/numaproj/numaflow/batchmapper/Service.java +++ b/src/main/java/io/numaproj/numaflow/batchmapper/Service.java @@ -48,7 +48,7 @@ public StreamObserver batchMapFn(StreamObserver result = batchMapTaskExecutor.submit(() -> this.batchMapper.processMessage( datumStream)); - return new StreamObserver() { + return new StreamObserver<>() { @Override public void onNext(Batchmap.BatchMapRequest mapRequest) { try { diff --git a/src/main/java/io/numaproj/numaflow/mapper/Constants.java b/src/main/java/io/numaproj/numaflow/mapper/Constants.java index 811cab32..3eb538e7 100644 --- a/src/main/java/io/numaproj/numaflow/mapper/Constants.java +++ b/src/main/java/io/numaproj/numaflow/mapper/Constants.java @@ -14,4 +14,6 @@ class Constants { public static final String MAP_MODE_KEY = "MAP_MODE"; public static final String MAP_MODE = "unary-map"; + + public static final String EOF = "EOF"; } diff --git a/src/main/java/io/numaproj/numaflow/mapper/MapSupervisorActor.java b/src/main/java/io/numaproj/numaflow/mapper/MapSupervisorActor.java new file mode 100644 index 00000000..6a1649eb --- /dev/null +++ b/src/main/java/io/numaproj/numaflow/mapper/MapSupervisorActor.java @@ -0,0 +1,142 @@ +package io.numaproj.numaflow.mapper; + +import akka.actor.AbstractActor; +import akka.actor.ActorRef; +import akka.actor.AllDeadLetters; +import akka.actor.AllForOneStrategy; +import akka.actor.Props; +import akka.actor.SupervisorStrategy; +import akka.japi.pf.DeciderBuilder; +import akka.japi.pf.ReceiveBuilder; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import io.numaproj.numaflow.map.v1.MapOuterClass; +import lombok.extern.slf4j.Slf4j; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +/** + * MapSupervisorActor actor is responsible for distributing the messages to actors and handling failure. + * It creates a MapperActor for each incoming request and listens to the responses from the MapperActor. + *

+ * MapSupervisorActor + * │ + * ├── Creates MapperActor instances for each incoming MapRequest + * │ │ + * │ ├── MapperActor 1 + * │ │ ├── Processes MapRequest + * │ │ ├── Sends MapResponse to MapSupervisorActor + * │ │ └── Stops itself after processing + * │ │ + * │ ├── MapperActor 2 + * │ │ ├── Processes MapRequest + * │ │ ├── Sends MapResponse to MapSupervisorActor + * │ │ └── Stops itself after processing + * │ │ + * ├── Listens to the responses from the MapperActor instances + * │ ├── On receiving a MapResponse, writes the response back to the client + * │ + * ├── If any MapperActor fails (throws an exception): + * │ ├── Sends the exception back to the client + * │ ├── Initiates a shutdown by completing the CompletableFuture exceptionally + * │ └── Stops all child actors (AllForOneStrategy) + */ +@Slf4j +class MapSupervisorActor extends AbstractActor { + private final Mapper mapper; + private final StreamObserver responseObserver; + private final CompletableFuture failureFuture; + + public MapSupervisorActor( + Mapper mapper, + StreamObserver responseObserver, + CompletableFuture failureFuture) { + this.mapper = mapper; + this.responseObserver = responseObserver; + this.failureFuture = failureFuture; + } + + public static Props props( + Mapper mapper, + StreamObserver responseObserver, + CompletableFuture failureFuture) { + return Props.create(MapSupervisorActor.class, mapper, responseObserver, failureFuture); + } + + @Override + public void preRestart(Throwable reason, Optional message) { + log.debug("supervisor pre restart was executed"); + failureFuture.completeExceptionally(reason); + responseObserver.onError(Status.UNKNOWN + .withDescription(reason.getMessage()) + .withCause(reason) + .asException()); + Service.mapperActorSystem.stop(getSelf()); + } + + @Override + public void postStop() { + log.debug("post stop of supervisor executed - {}", getSelf().toString()); + } + + @Override + public Receive createReceive() { + return ReceiveBuilder + .create() + .match(MapOuterClass.MapRequest.class, this::processRequest) + .match(MapOuterClass.MapResponse.class, this::sendResponse) + .match(Exception.class, this::handleFailure) + .match(AllDeadLetters.class, this::handleDeadLetters) + .match(String.class, eof -> responseObserver.onCompleted()) + .build(); + } + + private void handleFailure(Exception e) { + responseObserver.onError(Status.UNKNOWN + .withDescription(e.getMessage()) + .withCause(e) + .asException()); + failureFuture.completeExceptionally(e); + } + + private void sendResponse(MapOuterClass.MapResponse mapResponse) { + responseObserver.onNext(mapResponse); + } + + private void processRequest(MapOuterClass.MapRequest mapRequest) { + // Create a MapperActor for each incoming request. + ActorRef mapperActor = getContext() + .actorOf(MapperActor.props( + mapper)); + + // Send the message to the MapperActor. + mapperActor.tell(mapRequest, getSelf()); + } + + // if we see dead letters, we need to stop the execution and exit + // to make sure no messages are lost + private void handleDeadLetters(AllDeadLetters deadLetter) { + log.debug("got a dead letter, stopping the execution"); + responseObserver.onError(Status.UNKNOWN.withDescription("dead letters").asException()); + failureFuture.completeExceptionally(new Throwable("dead letters")); + getContext().getSystem().stop(getSelf()); + } + + @Override + public SupervisorStrategy supervisorStrategy() { + // we want to stop all child actors in case of any exception + return new AllForOneStrategy( + DeciderBuilder + .match(Exception.class, e -> { + failureFuture.completeExceptionally(e); + responseObserver.onError(Status.UNKNOWN + .withDescription(e.getMessage()) + .withCause(e) + .asException()); + return SupervisorStrategy.stop(); + }) + .build() + ); + } +} diff --git a/src/main/java/io/numaproj/numaflow/mapper/MapperActor.java b/src/main/java/io/numaproj/numaflow/mapper/MapperActor.java new file mode 100644 index 00000000..ef2b6f10 --- /dev/null +++ b/src/main/java/io/numaproj/numaflow/mapper/MapperActor.java @@ -0,0 +1,87 @@ +package io.numaproj.numaflow.mapper; + +import akka.actor.AbstractActor; +import akka.actor.Props; +import akka.japi.pf.ReceiveBuilder; +import com.google.protobuf.ByteString; +import io.numaproj.numaflow.map.v1.MapOuterClass; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; + +/** + * Mapper actor that processes the map request. It invokes the mapper to process the request and + * sends the response back to the sender actor(MapSupervisorActor). In case of any exception, it + * sends the exception back to the sender actor. It stops itself after processing the request. + */ +class MapperActor extends AbstractActor { + private final Mapper mapper; + + public MapperActor(Mapper mapper) { + this.mapper = mapper; + } + + public static Props props(Mapper mapper) { + return Props.create(MapperActor.class, mapper); + } + + @Override + public Receive createReceive() { + return ReceiveBuilder.create() + .match(MapOuterClass.MapRequest.class, this::processRequest) + .build(); + } + + /** + * Process the map request and send the response back to the sender actor. + * + * @param mapRequest map request + */ + private void processRequest(MapOuterClass.MapRequest mapRequest) { + Datum handlerDatum = new HandlerDatum( + mapRequest.getRequest().getValue().toByteArray(), + Instant.ofEpochSecond( + mapRequest.getRequest().getWatermark().getSeconds(), + mapRequest.getRequest().getWatermark().getNanos()), + Instant.ofEpochSecond( + mapRequest.getRequest().getEventTime().getSeconds(), + mapRequest.getRequest().getEventTime().getNanos()), + mapRequest.getRequest().getHeadersMap() + ); + String[] keys = mapRequest.getRequest().getKeysList().toArray(new String[0]); + try { + MessageList resultMessages = this.mapper.processMessage(keys, handlerDatum); + MapOuterClass.MapResponse response = buildResponse(resultMessages, mapRequest.getId()); + getSender().tell(response, getSelf()); + } catch (Exception e) { + getSender().tell(e, getSelf()); + } + context().stop(getSelf()); + } + + /** + * Build the response from the message list. + * + * @param messageList message list + * + * @return map response + */ + private MapOuterClass.MapResponse buildResponse(MessageList messageList, String ID) { + MapOuterClass.MapResponse.Builder responseBuilder = MapOuterClass + .MapResponse + .newBuilder(); + + messageList.getMessages().forEach(message -> { + responseBuilder.addResults(MapOuterClass.MapResponse.Result.newBuilder() + .setValue(message.getValue() == null ? ByteString.EMPTY : ByteString.copyFrom( + message.getValue())) + .addAllKeys(message.getKeys() + == null ? new ArrayList<>() : List.of(message.getKeys())) + .addAllTags(message.getTags() + == null ? new ArrayList<>() : List.of(message.getTags())) + .build()); + }); + return responseBuilder.setId(ID).build(); + } +} diff --git a/src/main/java/io/numaproj/numaflow/mapper/MapperTestKit.java b/src/main/java/io/numaproj/numaflow/mapper/MapperTestKit.java index 51e88034..e2612dc3 100644 --- a/src/main/java/io/numaproj/numaflow/mapper/MapperTestKit.java +++ b/src/main/java/io/numaproj/numaflow/mapper/MapperTestKit.java @@ -16,7 +16,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.UUID; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -76,7 +78,21 @@ public void stopServer() throws Exception { */ public static class Client { private final ManagedChannel channel; - private final MapGrpc.MapStub mapStub; + private final StreamObserver requestStreamObserver; + /* + * A ConcurrentHashMap that stores CompletableFuture instances for each request sent to the server. + * Each CompletableFuture corresponds to a unique request and is used to handle the response from the server. + * + * Key: A unique identifier (UUID) for each request sent to the server. + * Value: A CompletableFuture that will be completed when the server sends a response for the corresponding request. + * + * We use a concurrent map so that the user can send multiple requests concurrently. + * + * When a request is sent to the server, a new CompletableFuture is created and stored in the map with its unique identifier. + * When a response is received from the server, the corresponding CompletableFuture is retrieved from the map using the unique identifier, + * and then completed with the response data. If an error occurs, we complete all remaining futures exceptionally. + */ + private final ConcurrentHashMap> responseFutures = new ConcurrentHashMap<>(); /** * empty constructor for Client. @@ -94,35 +110,11 @@ public Client() { */ public Client(String host, int port) { this.channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); - this.mapStub = MapGrpc.newStub(channel); - } - - private CompletableFuture sendGrpcRequest(MapOuterClass.MapRequest request) { - CompletableFuture future = new CompletableFuture<>(); - StreamObserver responseObserver = new StreamObserver<>() { - @Override - public void onNext(MapOuterClass.MapResponse response) { - future.complete(response); - } - - @Override - public void onError(Throwable t) { - future.completeExceptionally(t); - } - - @Override - public void onCompleted() { - if (!future.isDone()) { - future.completeExceptionally(new RuntimeException( - "Server completed without a response")); - } - } - }; - - mapStub.mapFn( - request, responseObserver); - - return future; + MapGrpc.MapStub stub = MapGrpc.newStub(channel); + this.requestStreamObserver = stub.mapFn(new ResponseObserver()); + this.requestStreamObserver.onNext(MapOuterClass.MapRequest.newBuilder() + .setHandshake(MapOuterClass.Handshake.newBuilder().setSot(true)) + .build()); } /** @@ -134,35 +126,19 @@ public void onCompleted() { * @return response from the server as a MessageList */ public MessageList sendRequest(String[] keys, Datum data) { - MapOuterClass.MapRequest request = MapOuterClass.MapRequest.newBuilder() - .addAllKeys(keys == null ? new ArrayList<>() : List.of(keys)) - .setValue(data.getValue() - == null ? ByteString.EMPTY : ByteString.copyFrom(data.getValue())) - .setEventTime( - data.getEventTime() == null ? Timestamp.newBuilder().build() : Timestamp - .newBuilder() - .setSeconds(data.getEventTime().getEpochSecond()) - .setNanos(data.getEventTime().getNano()) - .build()) - .setWatermark( - data.getWatermark() == null ? Timestamp.newBuilder().build() : Timestamp - .newBuilder() - .setSeconds(data.getWatermark().getEpochSecond()) - .setNanos(data.getWatermark().getNano()) - .build()) - .putAllHeaders(data.getHeaders() == null ? new HashMap<>() : data.getHeaders()) - .build(); + String requestId = UUID + .randomUUID() + .toString(); + CompletableFuture responseFuture = new CompletableFuture<>(); + responseFutures.put(requestId, responseFuture); + MapOuterClass.MapRequest request = createRequest(keys, data, requestId); try { - MapOuterClass.MapResponse response = this.sendGrpcRequest(request).get(); - List messages = response.getResultsList().stream() - .map(result -> new Message( - result.getValue().toByteArray(), - result.getKeysList().toArray(new String[0]), - result.getTagsList().toArray(new String[0]))) - .collect(Collectors.toList()); - - return new MessageList(messages); + this.requestStreamObserver.onNext(request); + MapOuterClass.MapResponse response = responseFuture.get(); + MessageList messageList = createResponse(response); + responseFutures.remove(requestId); + return messageList; } catch (Exception e) { throw new RuntimeException(e); } @@ -174,8 +150,76 @@ public MessageList sendRequest(String[] keys, Datum data) { * @throws InterruptedException if the client fails to close */ public void close() throws InterruptedException { + this.requestStreamObserver.onCompleted(); channel.shutdown().awaitTermination(5, TimeUnit.SECONDS); } + + private MapOuterClass.MapRequest createRequest( + String[] keys, + Datum data, + String requestId) { + return MapOuterClass.MapRequest.newBuilder().setRequest( + MapOuterClass.MapRequest.Request.newBuilder() + .addAllKeys(keys == null ? new ArrayList<>() : List.of(keys)) + .setValue(data.getValue() + == null ? ByteString.EMPTY : ByteString.copyFrom(data.getValue())) + .setEventTime( + data.getEventTime() == null ? Timestamp + .newBuilder() + .build() : Timestamp.newBuilder() + .setSeconds(data.getEventTime().getEpochSecond()) + .setNanos(data.getEventTime().getNano()) + .build()) + .setWatermark( + data.getWatermark() == null ? Timestamp + .newBuilder() + .build() : Timestamp.newBuilder() + .setSeconds(data.getWatermark().getEpochSecond()) + .setNanos(data.getWatermark().getNano()) + .build()) + .putAllHeaders( + data.getHeaders() == null ? new HashMap<>() : data.getHeaders()) + .build() + ).setId(requestId).build(); + } + + private MessageList createResponse(MapOuterClass.MapResponse response) { + List messages = response.getResultsList().stream() + .map(result -> new Message( + result.getValue().toByteArray(), + result.getKeysList().toArray(new String[0]), + result.getTagsList().toArray(new String[0]))) + .collect(Collectors.toList()); + return new MessageList(messages); + } + + private class ResponseObserver implements StreamObserver { + @Override + public void onNext(MapOuterClass.MapResponse mapResponse) { + if (mapResponse.hasHandshake()) { + return; + } + CompletableFuture responseFuture = responseFutures.get( + mapResponse.getId()); + if (responseFuture != null) { + responseFuture.complete(mapResponse); + } + } + + @Override + public void onError(Throwable throwable) { + // complete all remaining futures exceptionally + for (CompletableFuture future : responseFutures.values()) { + future.completeExceptionally(throwable); + } + } + + @Override + public void onCompleted() { + // remove all completed futures + responseFutures.values().removeIf(CompletableFuture::isDone); + } + } } /** diff --git a/src/main/java/io/numaproj/numaflow/mapper/Server.java b/src/main/java/io/numaproj/numaflow/mapper/Server.java index be82cc73..dfc10214 100644 --- a/src/main/java/io/numaproj/numaflow/mapper/Server.java +++ b/src/main/java/io/numaproj/numaflow/mapper/Server.java @@ -81,7 +81,7 @@ public void start() throws Exception { log.info( "Server started, listening on {}", grpcConfig.isLocal() ? - "localhost:" + grpcConfig.getPort():grpcConfig.getSocketPath()); + "localhost:" + grpcConfig.getPort() : grpcConfig.getSocketPath()); // register shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(() -> { diff --git a/src/main/java/io/numaproj/numaflow/mapper/Service.java b/src/main/java/io/numaproj/numaflow/mapper/Service.java index c0260433..150dee57 100644 --- a/src/main/java/io/numaproj/numaflow/mapper/Service.java +++ b/src/main/java/io/numaproj/numaflow/mapper/Service.java @@ -1,16 +1,16 @@ package io.numaproj.numaflow.mapper; -import com.google.protobuf.ByteString; +import akka.actor.ActorRef; +import akka.actor.ActorSystem; import com.google.protobuf.Empty; +import io.grpc.Status; import io.grpc.stub.StreamObserver; import io.numaproj.numaflow.map.v1.MapGrpc; import io.numaproj.numaflow.map.v1.MapOuterClass; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; +import java.util.concurrent.CompletableFuture; import static io.numaproj.numaflow.map.v1.MapGrpc.getMapFnMethod; @@ -18,42 +18,73 @@ @AllArgsConstructor class Service extends MapGrpc.MapImplBase { + public static final ActorSystem mapperActorSystem = ActorSystem.create("mapper"); + private final Mapper mapper; - /** - * Applies a function to each datum element. - */ + // TODO we need to propagate the exception all the way up and shutdown the server. + static void handleFailure( + CompletableFuture failureFuture) { + new Thread(() -> { + try { + failureFuture.get(); + } catch (Exception e) { + e.printStackTrace(); + } + }).start(); + } + @Override - public void mapFn( - MapOuterClass.MapRequest request, - StreamObserver responseObserver) { + public StreamObserver mapFn(final StreamObserver responseObserver) { if (this.mapper == null) { - io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall( + return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall( getMapFnMethod(), responseObserver); - return; } - Datum handlerDatum = new HandlerDatum( - request.getValue().toByteArray(), - Instant.ofEpochSecond( - request.getWatermark().getSeconds(), - request.getWatermark().getNanos()), - Instant.ofEpochSecond( - request.getEventTime().getSeconds(), - request.getEventTime().getNanos()), - request.getHeadersMap() - ); - - // process request - MessageList messageList = mapper.processMessage(request - .getKeysList() - .toArray(new String[0]), handlerDatum); - - // set response - responseObserver.onNext(buildResponse(messageList)); - responseObserver.onCompleted(); + CompletableFuture failureFuture = new CompletableFuture<>(); + + handleFailure(failureFuture); + + // create a MapSupervisorActor that processes the map requests. + ActorRef mapSupervisorActor = mapperActorSystem + .actorOf(MapSupervisorActor.props(mapper, responseObserver, failureFuture)); + + return new StreamObserver<>() { + private boolean handshakeDone = false; + + @Override + public void onNext(MapOuterClass.MapRequest request) { + // make sure the handshake is done before processing the messages + if (!handshakeDone) { + if (!request.hasHandshake() || !request.getHandshake().getSot()) { + responseObserver.onError(Status.INVALID_ARGUMENT + .withDescription("Handshake request not received") + .asException()); + return; + } + responseObserver.onNext(MapOuterClass.MapResponse.newBuilder() + .setHandshake(request.getHandshake()) + .build()); + handshakeDone = true; + return; + } + // send the message to the MapSupervisorActor. + mapSupervisorActor.tell(request, ActorRef.noSender()); + } + + @Override + public void onError(Throwable throwable) { + mapSupervisorActor.tell(new Exception(throwable), ActorRef.noSender()); + } + + @Override + public void onCompleted() { + // indicate the end of input to the MapSupervisorActor. + mapSupervisorActor.tell(Constants.EOF, ActorRef.noSender()); + } + }; } /** @@ -66,23 +97,4 @@ public void isReady( responseObserver.onNext(MapOuterClass.ReadyResponse.newBuilder().setReady(true).build()); responseObserver.onCompleted(); } - - private MapOuterClass.MapResponse buildResponse(MessageList messageList) { - MapOuterClass.MapResponse.Builder responseBuilder = MapOuterClass - .MapResponse - .newBuilder(); - - messageList.getMessages().forEach(message -> { - responseBuilder.addResults(MapOuterClass.MapResponse.Result.newBuilder() - .setValue(message.getValue() == null ? ByteString.EMPTY : ByteString.copyFrom( - message.getValue())) - .addAllKeys(message.getKeys() - == null ? new ArrayList<>() : List.of(message.getKeys())) - .addAllTags(message.getTags() - == null ? new ArrayList<>() : List.of(message.getTags())) - .build()); - }); - return responseBuilder.build(); - } - } diff --git a/src/main/java/io/numaproj/numaflow/reducer/Server.java b/src/main/java/io/numaproj/numaflow/reducer/Server.java index 2fe9095f..92295c67 100644 --- a/src/main/java/io/numaproj/numaflow/reducer/Server.java +++ b/src/main/java/io/numaproj/numaflow/reducer/Server.java @@ -76,7 +76,7 @@ public void start() throws Exception { log.info( "Server started, listening on {}", grpcConfig.isLocal() ? - "localhost:" + grpcConfig.getPort():grpcConfig.getSocketPath()); + "localhost:" + grpcConfig.getPort() : grpcConfig.getSocketPath()); // register shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(() -> { diff --git a/src/main/java/io/numaproj/numaflow/reducestreamer/Server.java b/src/main/java/io/numaproj/numaflow/reducestreamer/Server.java index f97d4258..bb1c2d76 100644 --- a/src/main/java/io/numaproj/numaflow/reducestreamer/Server.java +++ b/src/main/java/io/numaproj/numaflow/reducestreamer/Server.java @@ -79,7 +79,7 @@ public void start() throws Exception { log.info( "Server started, listening on {}", grpcConfig.isLocal() ? - "localhost:" + grpcConfig.getPort():grpcConfig.getSocketPath()); + "localhost:" + grpcConfig.getPort() : grpcConfig.getSocketPath()); // register shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(() -> { diff --git a/src/main/java/io/numaproj/numaflow/sessionreducer/Server.java b/src/main/java/io/numaproj/numaflow/sessionreducer/Server.java index 5fbdb6a3..7509a5a4 100644 --- a/src/main/java/io/numaproj/numaflow/sessionreducer/Server.java +++ b/src/main/java/io/numaproj/numaflow/sessionreducer/Server.java @@ -79,7 +79,7 @@ public void start() throws Exception { log.info( "Server started, listening on {}", grpcConfig.isLocal() ? - "localhost:" + grpcConfig.getPort():grpcConfig.getSocketPath()); + "localhost:" + grpcConfig.getPort() : grpcConfig.getSocketPath()); // register shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(() -> { diff --git a/src/main/java/io/numaproj/numaflow/sideinput/Server.java b/src/main/java/io/numaproj/numaflow/sideinput/Server.java index 580855c8..d0c7af0f 100644 --- a/src/main/java/io/numaproj/numaflow/sideinput/Server.java +++ b/src/main/java/io/numaproj/numaflow/sideinput/Server.java @@ -76,7 +76,7 @@ public void start() throws Exception { log.info( "Server started, listening on {}", grpcConfig.isLocal() ? - "localhost:" + grpcConfig.getPort():grpcConfig.getSocketPath()); + "localhost:" + grpcConfig.getPort() : grpcConfig.getSocketPath()); // register shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(() -> { diff --git a/src/main/java/io/numaproj/numaflow/sourcer/Server.java b/src/main/java/io/numaproj/numaflow/sourcer/Server.java index 83961dc5..d4600a00 100644 --- a/src/main/java/io/numaproj/numaflow/sourcer/Server.java +++ b/src/main/java/io/numaproj/numaflow/sourcer/Server.java @@ -76,7 +76,7 @@ public void start() throws Exception { log.info( "Server started, listening on {}", grpcConfig.isLocal() ? - "localhost:" + grpcConfig.getPort():grpcConfig.getSocketPath()); + "localhost:" + grpcConfig.getPort() : grpcConfig.getSocketPath()); // register shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(() -> { diff --git a/src/main/java/io/numaproj/numaflow/sourcetransformer/Server.java b/src/main/java/io/numaproj/numaflow/sourcetransformer/Server.java index 3dd29dc3..2e5e85fd 100644 --- a/src/main/java/io/numaproj/numaflow/sourcetransformer/Server.java +++ b/src/main/java/io/numaproj/numaflow/sourcetransformer/Server.java @@ -76,7 +76,7 @@ public void start() throws Exception { log.info( "Server started, listening on {}", grpcConfig.isLocal() ? - "localhost:" + grpcConfig.getPort():grpcConfig.getSocketPath()); + "localhost:" + grpcConfig.getPort() : grpcConfig.getSocketPath()); // register shutdown hook Runtime.getRuntime().addShutdownHook(new Thread(() -> { diff --git a/src/main/proto/map/v1/map.proto b/src/main/proto/map/v1/map.proto index 09842985..f256c157 100644 --- a/src/main/proto/map/v1/map.proto +++ b/src/main/proto/map/v1/map.proto @@ -9,7 +9,7 @@ package map.v1; service Map { // MapFn applies a function to each map request element. - rpc MapFn(MapRequest) returns (MapResponse); + rpc MapFn(stream MapRequest) returns (stream MapResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); @@ -19,11 +19,25 @@ service Map { * MapRequest represents a request element. */ message MapRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + } + Request request = 1; + // This ID is used to uniquely identify a map request + string id = 2; + optional Handshake handshake = 3; +} + +/* + * Handshake message between client and server to indicate the start of transmission. + */ +message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; } /** @@ -36,6 +50,9 @@ message MapResponse { repeated string tags = 3; } repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; + optional Handshake handshake = 3; } /** diff --git a/src/test/java/io/numaproj/numaflow/batchmapper/BatchMapOutputStreamObserver.java b/src/test/java/io/numaproj/numaflow/batchmapper/BatchMapOutputStreamObserver.java index d7b87c7c..4d8da634 100644 --- a/src/test/java/io/numaproj/numaflow/batchmapper/BatchMapOutputStreamObserver.java +++ b/src/test/java/io/numaproj/numaflow/batchmapper/BatchMapOutputStreamObserver.java @@ -20,7 +20,10 @@ public void onNext(Batchmap.BatchMapResponse batchMapResponse) { List receivedResponses = resultDatum.get(); receivedResponses.add(batchMapResponse); resultDatum.set(receivedResponses); - log.info("Received BatchMapResponse with id {} and message count {}", batchMapResponse.getId(), batchMapResponse.getResultsCount()); + log.info( + "Received BatchMapResponse with id {} and message count {}", + batchMapResponse.getId(), + batchMapResponse.getResultsCount()); } @Override diff --git a/src/test/java/io/numaproj/numaflow/batchmapper/ServerErrTest.java b/src/test/java/io/numaproj/numaflow/batchmapper/ServerErrTest.java index 046515be..93af2400 100644 --- a/src/test/java/io/numaproj/numaflow/batchmapper/ServerErrTest.java +++ b/src/test/java/io/numaproj/numaflow/batchmapper/ServerErrTest.java @@ -226,13 +226,13 @@ public BatchResponses processMessage(DatumIterator datumStream) { if (datum.getId().equals("exception")) { throw new RuntimeException("unknown exception"); } else if (!datum.getId().equals("drop")) { - String msg = new String(datum.getValue()); - String[] strs = msg.split(","); - BatchResponse batchResponse = new BatchResponse(datum.getId()); - for (String str : strs) { - batchResponse.append(new Message(str.getBytes())); - } - batchResponses.append(batchResponse); + String msg = new String(datum.getValue()); + String[] strs = msg.split(","); + BatchResponse batchResponse = new BatchResponse(datum.getId()); + for (String str : strs) { + batchResponse.append(new Message(str.getBytes())); + } + batchResponses.append(batchResponse); } } return batchResponses; diff --git a/src/test/java/io/numaproj/numaflow/mapper/MapOutputStreamObserver.java b/src/test/java/io/numaproj/numaflow/mapper/MapOutputStreamObserver.java new file mode 100644 index 00000000..bb1d8ee7 --- /dev/null +++ b/src/test/java/io/numaproj/numaflow/mapper/MapOutputStreamObserver.java @@ -0,0 +1,40 @@ +package io.numaproj.numaflow.mapper; + +import io.grpc.stub.StreamObserver; +import io.numaproj.numaflow.map.v1.MapOuterClass; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +public class MapOutputStreamObserver implements StreamObserver { + List mapResponses = new ArrayList<>(); + CompletableFuture done = new CompletableFuture<>(); + Integer responseCount; + + public MapOutputStreamObserver(Integer responseCount) { + this.responseCount = responseCount; + } + + @Override + public void onNext(MapOuterClass.MapResponse mapResponse) { + mapResponses.add(mapResponse); + if (mapResponses.size() == responseCount) { + done.complete(null); + } + } + + @Override + public void onError(Throwable throwable) { + done.completeExceptionally(throwable); + } + + @Override + public void onCompleted() { + done.complete(null); + } + + public List getMapResponses() { + return mapResponses; + } +} diff --git a/src/test/java/io/numaproj/numaflow/mapper/ServerErrTest.java b/src/test/java/io/numaproj/numaflow/mapper/ServerErrTest.java index 9224347e..b8fe977b 100644 --- a/src/test/java/io/numaproj/numaflow/mapper/ServerErrTest.java +++ b/src/test/java/io/numaproj/numaflow/mapper/ServerErrTest.java @@ -22,7 +22,6 @@ import java.util.List; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; public class ServerErrTest { @@ -103,20 +102,33 @@ public void tearDown() throws Exception { } @Test - public void TestMapperErr() { + public void testMapperFailure() { + MapOuterClass.MapRequest handshakeRequest = MapOuterClass.MapRequest + .newBuilder() + .setHandshake(MapOuterClass.Handshake.newBuilder().setSot(true)) + .build(); + ByteString inValue = ByteString.copyFromUtf8("invalue"); MapOuterClass.MapRequest mapRequest = MapOuterClass.MapRequest .newBuilder() - .addAllKeys(List.of("test-map-key")) - .setValue(inValue) + .setRequest(MapOuterClass.MapRequest.Request.newBuilder() + .addAllKeys(List.of("test-map-key")).setValue(inValue).build()) .build(); - var stub = MapGrpc.newBlockingStub(inProcessChannel); + MapOutputStreamObserver responseObserver = new MapOutputStreamObserver(1); + + var stub = MapGrpc.newStub(inProcessChannel); + var requestObserver = stub.mapFn(responseObserver); + + requestObserver.onNext(handshakeRequest); + requestObserver.onNext(mapRequest); + try { - stub.mapFn(mapRequest); - fail("Expected the mapperErr to complete with exception"); + responseObserver.done.get(); } catch (Exception e) { - assertEquals("UNKNOWN: unknown exception", e.getMessage()); + assertEquals( + "io.grpc.StatusRuntimeException: UNKNOWN: unknown exception", + e.getMessage()); } } diff --git a/src/test/java/io/numaproj/numaflow/mapper/ServerTest.java b/src/test/java/io/numaproj/numaflow/mapper/ServerTest.java index 0bfc3630..0069f78c 100644 --- a/src/test/java/io/numaproj/numaflow/mapper/ServerTest.java +++ b/src/test/java/io/numaproj/numaflow/mapper/ServerTest.java @@ -14,8 +14,10 @@ import java.util.Arrays; import java.util.List; +import java.util.concurrent.ExecutionException; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; public class ServerTest { private final static String PROCESSED_KEY_SUFFIX = "-key-processed"; @@ -57,31 +59,86 @@ public void tearDown() throws Exception { } @Test - public void TestMapper() { + public void testMapperSuccess() { + MapOuterClass.MapRequest handshakeRequest = MapOuterClass.MapRequest + .newBuilder() + .setHandshake(MapOuterClass.Handshake.newBuilder().setSot(true)) + .build(); + ByteString inValue = ByteString.copyFromUtf8("invalue"); MapOuterClass.MapRequest inDatum = MapOuterClass.MapRequest .newBuilder() - .addAllKeys(List.of("test-map-key")) - .setValue(inValue) - .build(); + .setRequest(MapOuterClass.MapRequest.Request + .newBuilder() + .setValue(inValue) + .addAllKeys(List.of("test-map-key")) + .build()).build(); String[] expectedKeys = new String[]{"test-map-key" + PROCESSED_KEY_SUFFIX}; String[] expectedTags = new String[]{"test-tag"}; ByteString expectedValue = ByteString.copyFromUtf8("invalue" + PROCESSED_VALUE_SUFFIX); + MapOutputStreamObserver responseObserver = new MapOutputStreamObserver(4); + + var stub = MapGrpc.newStub(inProcessChannel); + var requestStreamObserver = stub + .mapFn(responseObserver); + + requestStreamObserver.onNext(handshakeRequest); + requestStreamObserver.onNext(inDatum); + requestStreamObserver.onNext(inDatum); + requestStreamObserver.onNext(inDatum); + + try { + responseObserver.done.get(); + } catch (InterruptedException | ExecutionException e) { + fail("Error while waiting for response" + e.getMessage()); + } + + List responses = responseObserver.getMapResponses(); + assertEquals(4, responses.size()); - var stub = MapGrpc.newBlockingStub(inProcessChannel); - var mapResponse = stub - .mapFn(inDatum); + // first response is the handshake response + assertEquals(handshakeRequest.getHandshake(), responses.get(0).getHandshake()); - assertEquals(1, mapResponse.getResultsCount()); - assertEquals( - expectedKeys, - mapResponse.getResults(0).getKeysList().toArray(new String[0])); - assertEquals(expectedValue, mapResponse.getResults(0).getValue()); - assertEquals( - expectedTags, - mapResponse.getResults(0).getTagsList().toArray(new String[0])); + responses = responses.subList(1, responses.size()); + for (MapOuterClass.MapResponse response : responses) { + assertEquals(expectedValue, response.getResults(0).getValue()); + assertEquals(Arrays.asList(expectedKeys), response.getResults(0).getKeysList()); + assertEquals(Arrays.asList(expectedTags), response.getResults(0).getTagsList()); + assertEquals(1, response.getResultsCount()); + } + + requestStreamObserver.onCompleted(); + } + + @Test + public void testMapperWithoutHandshake() { + ByteString inValue = ByteString.copyFromUtf8("invalue"); + MapOuterClass.MapRequest inDatum = MapOuterClass.MapRequest + .newBuilder() + .setRequest(MapOuterClass.MapRequest.Request + .newBuilder() + .setValue(inValue) + .addAllKeys(List.of("test-map-key")) + .build()).build(); + + MapOutputStreamObserver responseObserver = new MapOutputStreamObserver(1); + + var stub = MapGrpc.newStub(inProcessChannel); + var requestStreamObserver = stub + .mapFn(responseObserver); + + requestStreamObserver.onNext(inDatum); + + try { + responseObserver.done.get(); + } catch (InterruptedException | ExecutionException e) { + assertEquals( + "io.grpc.StatusRuntimeException: INVALID_ARGUMENT: Handshake request not received", + e.getMessage()); + } + requestStreamObserver.onCompleted(); } private static class TestMapFn extends Mapper { diff --git a/src/test/java/io/numaproj/numaflow/sinker/ServerErrTest.java b/src/test/java/io/numaproj/numaflow/sinker/ServerErrTest.java index b8e484fb..fe944ff8 100644 --- a/src/test/java/io/numaproj/numaflow/sinker/ServerErrTest.java +++ b/src/test/java/io/numaproj/numaflow/sinker/ServerErrTest.java @@ -20,7 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @Slf4j diff --git a/src/test/java/io/numaproj/numaflow/sourcer/ServerErrTest.java b/src/test/java/io/numaproj/numaflow/sourcer/ServerErrTest.java index 13b43c06..4a59006e 100644 --- a/src/test/java/io/numaproj/numaflow/sourcer/ServerErrTest.java +++ b/src/test/java/io/numaproj/numaflow/sourcer/ServerErrTest.java @@ -23,7 +23,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; public class ServerErrTest {