Skip to content

Commit

Permalink
feat: Unify Batch Map and Unary Map Operations Using a Shared gRPC Protocol (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
yhl25 authored Oct 12, 2024
1 parent 2895ffd commit fa2f746
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 239 deletions.
152 changes: 88 additions & 64 deletions src/main/java/io/numaproj/numaflow/batchmapper/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,105 +4,130 @@
import com.google.protobuf.Empty;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import io.numaproj.numaflow.batchmap.v1.BatchMapGrpc;
import io.numaproj.numaflow.batchmap.v1.Batchmap;
import io.numaproj.numaflow.map.v1.MapGrpc;
import io.numaproj.numaflow.map.v1.MapOuterClass;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;


@Slf4j
@AllArgsConstructor
class Service extends BatchMapGrpc.BatchMapImplBase {
class Service extends MapGrpc.MapImplBase {

// batchMapTaskExecutor is the executor for the batchMap. It is a fixed size thread pool
// Executor service for the batch mapper. It is a fixed size thread pool
// with the number of threads equal to the number of cores on the machine times 2.
private final ExecutorService batchMapTaskExecutor = Executors
// We use 2 times the number of cores because the batch mapper is a CPU intensive task.
private final ExecutorService mapTaskExecutor = Executors
.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);

// SHUTDOWN_TIME is the time to wait for the sinker to shut down, in seconds.
// Time to wait for the batch mapper to shut down, in seconds.
// We use 30 seconds as the default value because it provides a balance between giving tasks enough time to complete
// and not delaying program termination unduly.
private final long SHUTDOWN_TIME = 30;

// BatchMapper instance to process the messages
private final BatchMapper batchMapper;

// Applies a map function to each datum element in the stream.
@Override
public StreamObserver<Batchmap.BatchMapRequest> batchMapFn(StreamObserver<Batchmap.BatchMapResponse> responseObserver) {
public StreamObserver<MapOuterClass.MapRequest> mapFn(StreamObserver<MapOuterClass.MapResponse> responseObserver) {

// If the batchMapper is null, return an unimplemented call
if (this.batchMapper == null) {
return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(
BatchMapGrpc.getBatchMapFnMethod(),
MapGrpc.getMapFnMethod(),
responseObserver);
}

DatumIteratorImpl datumStream = new DatumIteratorImpl();

Future<BatchResponses> result = batchMapTaskExecutor.submit(() -> this.batchMapper.processMessage(
datumStream));

// Return a new StreamObserver to handle the incoming requests
return new StreamObserver<>() {
private boolean startOfStream = true;
private boolean handshakeDone = false;
private DatumIteratorImpl datumStream;
private CompletableFuture<BatchResponses> result;

// Called for each incoming request
@Override
public void onNext(Batchmap.BatchMapRequest mapRequest) {
public void onNext(MapOuterClass.MapRequest mapRequest) {
try {
datumStream.writeMessage(constructHandlerDatum(mapRequest));
} catch (InterruptedException e) {
Thread.interrupted();
onError(e);
// Make sure the handshake is done before processing the messages
if (!handshakeDone) {
if (!mapRequest.hasHandshake() || !mapRequest.getHandshake().getSot()) {
responseObserver.onError(Status.INVALID_ARGUMENT
.withDescription("Handshake request not received")
.asException());
return;
}
responseObserver.onNext(MapOuterClass.MapResponse.newBuilder()
.setHandshake(mapRequest.getHandshake())
.build());
handshakeDone = true;
return;
}

// Create a DatumIterator to write the messages to the batch mapper
// and start the batch mapper if it is the start of the stream
if (startOfStream) {
datumStream = new DatumIteratorImpl();
result = CompletableFuture.supplyAsync(
() -> batchMapper.processMessage(datumStream),
mapTaskExecutor);
startOfStream = false;
}

// If end of transmission, write EOF datum to the stream
// Wait for the result and send the response back to the client
if (mapRequest.hasStatus() && mapRequest.getStatus().getEot()) {
datumStream.writeMessage(HandlerDatum.EOF_DATUM);
BatchResponses responses = result.join();
buildAndStreamResponse(responses, responseObserver);
startOfStream = true;
} else {
datumStream.writeMessage(constructHandlerDatum(mapRequest));
}
} catch (Exception e) {
log.error("Encountered an error in batch map", e);
responseObserver.onError(Status.UNKNOWN
.withDescription(e.getMessage())
.withCause(e)
.asException());
}
}

// Called when an error occurs
@Override
public void onError(Throwable throwable) {
// We close the stream and let the sender retry the messages
log.error("Error Encountered in batchMap Stream", throwable);
var status = Status.UNKNOWN
.withDescription(throwable.getMessage())
.withCause(throwable);
responseObserver.onError(status.asException());
}

// Called when the client has finished sending requests
@Override
public void onCompleted() {
try {
// We fire off the call to the client from here and stream the response back
datumStream.writeMessage(HandlerDatum.EOF_DATUM);
BatchResponses responses = result.get();
log.debug(
"Finished the call Result size is :{} and iterator count is :{}",
responses.getItems().size(),
datumStream.getCount());
// Crash if the number of responses from the user doesn't match the number of input requests, ignoring the EOF message
if (responses.getItems().size() != datumStream.getCount() - 1) {
throw new RuntimeException("Number of results did not match expected " + (
datumStream.getCount() - 1) + " but got " + responses
.getItems()
.size());
}
buildAndStreamResponse(responses, responseObserver);
} catch (Exception e) {
log.error("Error Encountered in batchMap Stream onCompleted", e);
onError(e);
}
responseObserver.onCompleted();
}
};
}

// Build and stream the response back to the client
private void buildAndStreamResponse(
BatchResponses responses,
StreamObserver<Batchmap.BatchMapResponse> responseObserver) {
StreamObserver<MapOuterClass.MapResponse> responseObserver) {
responses.getItems().forEach(message -> {
List<Batchmap.BatchMapResponse.Result> batchMapResponseResult = new ArrayList<>();
List<MapOuterClass.MapResponse.Result> mapResponseResult = new ArrayList<>();
message.getItems().forEach(res -> {
batchMapResponseResult.add(
Batchmap.BatchMapResponse.Result
mapResponseResult.add(
MapOuterClass.MapResponse.Result
.newBuilder()
.setValue(res.getValue()
== null ? ByteString.EMPTY : ByteString.copyFrom(
Expand All @@ -114,48 +139,48 @@ private void buildAndStreamResponse(
.build()
);
});
Batchmap.BatchMapResponse singleRequestResponse = Batchmap.BatchMapResponse
MapOuterClass.MapResponse singleRequestResponse = MapOuterClass.MapResponse
.newBuilder()
.setId(message.getId())
.addAllResults(batchMapResponseResult)
.addAllResults(mapResponseResult)
.build();
// Stream the response back to the sender
responseObserver.onNext(singleRequestResponse);
});
responseObserver.onCompleted();
}


// IsReady is the heartbeat endpoint for gRPC.
@Override
public void isReady(
Empty request,
StreamObserver<Batchmap.ReadyResponse> responseObserver) {
responseObserver.onNext(Batchmap.ReadyResponse.newBuilder().setReady(true).build());
StreamObserver<MapOuterClass.ReadyResponse> responseObserver) {
responseObserver.onNext(MapOuterClass.ReadyResponse.newBuilder().setReady(true).build());
responseObserver.onCompleted();
}

private HandlerDatum constructHandlerDatum(Batchmap.BatchMapRequest d) {
// Construct a HandlerDatum from a MapRequest
private HandlerDatum constructHandlerDatum(MapOuterClass.MapRequest d) {
return new HandlerDatum(
d.getKeysList().toArray(new String[0]),
d.getValue().toByteArray(),
d.getRequest().getKeysList().toArray(new String[0]),
d.getRequest().getValue().toByteArray(),
Instant.ofEpochSecond(
d.getWatermark().getSeconds(),
d.getWatermark().getNanos()),
d.getRequest().getWatermark().getSeconds(),
d.getRequest().getWatermark().getNanos()),
Instant.ofEpochSecond(
d.getEventTime().getSeconds(),
d.getEventTime().getNanos()),
d.getRequest().getEventTime().getSeconds(),
d.getRequest().getEventTime().getNanos()),
d.getId(),
d.getHeadersMap()
d.getRequest().getHeadersMap()
);
}

// shuts down the executor service which is used for reduce
// Shuts down the executor service which is used for batch map
public void shutDown() {
this.batchMapTaskExecutor.shutdown();
this.mapTaskExecutor.shutdown();
try {
if (!batchMapTaskExecutor.awaitTermination(SHUTDOWN_TIME, TimeUnit.SECONDS)) {
if (!mapTaskExecutor.awaitTermination(SHUTDOWN_TIME, TimeUnit.SECONDS)) {
log.error("BatchMap executor did not terminate in the specified time.");
List<Runnable> droppedTasks = batchMapTaskExecutor.shutdownNow();
List<Runnable> droppedTasks = mapTaskExecutor.shutdownNow();
log.error("BatchMap executor was abruptly shut down. " + droppedTasks.size()
+ " tasks will not be executed.");
} else {
Expand All @@ -166,5 +191,4 @@ public void shutDown() {
e.printStackTrace();
}
}

}
2 changes: 1 addition & 1 deletion src/main/java/io/numaproj/numaflow/sinker/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public void onNext(SinkOuterClass.SinkRequest request) {
}

try {
if (request.getStatus().getEot()) {
if (request.hasStatus() && request.getStatus().getEot()) {
// End of transmission, write EOF datum to the stream
// Wait for the result and send the response back to the client
datumStream.writeMessage(HandlerDatum.EOF_DATUM);
Expand Down
52 changes: 0 additions & 52 deletions src/main/proto/batchmap/v1/batchmap.proto

This file was deleted.

4 changes: 4 additions & 0 deletions src/main/proto/map/v1/map.proto
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ message MapRequest {
google.protobuf.Timestamp watermark = 4;
map<string, string> headers = 5;
}
message Status {
bool eot = 1;
}
Request request = 1;
// This ID is used to uniquely identify a map request
string id = 2;
optional Handshake handshake = 3;
optional Status status = 4;
}

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,40 @@
package io.numaproj.numaflow.batchmapper;

import io.grpc.stub.StreamObserver;
import io.numaproj.numaflow.batchmap.v1.Batchmap;
import lombok.extern.slf4j.Slf4j;
import io.numaproj.numaflow.map.v1.MapOuterClass;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.CompletableFuture;

@Slf4j
public class BatchMapOutputStreamObserver implements StreamObserver<Batchmap.BatchMapResponse> {
public AtomicReference<Boolean> completed = new AtomicReference<>(false);
public AtomicReference<List<Batchmap.BatchMapResponse>> resultDatum = new AtomicReference<>(
new ArrayList<>());
public Throwable t;
public class BatchMapOutputStreamObserver implements StreamObserver<MapOuterClass.MapResponse> {
List<MapOuterClass.MapResponse> mapResponses = new ArrayList<>();
CompletableFuture<Void> done = new CompletableFuture<>();
Integer responseCount;

public BatchMapOutputStreamObserver(Integer responseCount) {
this.responseCount = responseCount;
}

@Override
public void onNext(Batchmap.BatchMapResponse batchMapResponse) {
List<Batchmap.BatchMapResponse> receivedResponses = resultDatum.get();
receivedResponses.add(batchMapResponse);
resultDatum.set(receivedResponses);
log.info(
"Received BatchMapResponse with id {} and message count {}",
batchMapResponse.getId(),
batchMapResponse.getResultsCount());
public void onNext(MapOuterClass.MapResponse mapResponse) {
mapResponses.add(mapResponse);
if (mapResponses.size() == responseCount) {
done.complete(null);
}
}

@Override
public void onError(Throwable throwable) {
t = throwable;
done.completeExceptionally(throwable);
}

@Override
public void onCompleted() {
this.completed.set(true);
done.complete(null);
}

public List<MapOuterClass.MapResponse> getMapResponses() {
return mapResponses;
}
}
Loading

0 comments on commit fa2f746

Please sign in to comment.