Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Unify Batch Map and Unary Map Operations Using a Shared gRPC Protocol #144

Merged
merged 1 commit into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 88 additions & 64 deletions src/main/java/io/numaproj/numaflow/batchmapper/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,105 +4,130 @@
import com.google.protobuf.Empty;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import io.numaproj.numaflow.batchmap.v1.BatchMapGrpc;
import io.numaproj.numaflow.batchmap.v1.Batchmap;
import io.numaproj.numaflow.map.v1.MapGrpc;
import io.numaproj.numaflow.map.v1.MapOuterClass;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;


@Slf4j
@AllArgsConstructor
class Service extends BatchMapGrpc.BatchMapImplBase {
class Service extends MapGrpc.MapImplBase {

// batchMapTaskExecutor is the executor for the batchMap. It is a fixed size thread pool
// Executor service for the batch mapper. It is a fixed size thread pool
// with the number of threads equal to the number of cores on the machine times 2.
private final ExecutorService batchMapTaskExecutor = Executors
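// We use 2 times the number of cores because the batch mapper is a CPU intensive task.
// NOTE(review): CPU-bound work alone does not benefit from more threads than cores; the extra
// threads presumably help because each task also waits on incoming datums — TODO confirm.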
private final ExecutorService mapTaskExecutor = Executors
.newFixedThreadPool(Runtime.getRuntime().availableProcessors() * 2);

// SHUTDOWN_TIME is the time to wait for the sinker to shut down, in seconds.
// Time to wait for the batch mapper to shut down, in seconds.
// We use 30 seconds as the default value because it provides a balance between giving tasks enough time to complete
// and not delaying program termination unduly.
private final long SHUTDOWN_TIME = 30;

// BatchMapper instance to process the messages
private final BatchMapper batchMapper;

// Applies a map function to each datum element in the stream.
@Override
public StreamObserver<Batchmap.BatchMapRequest> batchMapFn(StreamObserver<Batchmap.BatchMapResponse> responseObserver) {
public StreamObserver<MapOuterClass.MapRequest> mapFn(StreamObserver<MapOuterClass.MapResponse> responseObserver) {

// If the batchMapper is null, return an unimplemented call
if (this.batchMapper == null) {
return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(
BatchMapGrpc.getBatchMapFnMethod(),
MapGrpc.getMapFnMethod(),
responseObserver);
}

DatumIteratorImpl datumStream = new DatumIteratorImpl();

Future<BatchResponses> result = batchMapTaskExecutor.submit(() -> this.batchMapper.processMessage(
datumStream));

// Return a new StreamObserver to handle the incoming requests
return new StreamObserver<>() {
private boolean startOfStream = true;
private boolean handshakeDone = false;
private DatumIteratorImpl datumStream;
private CompletableFuture<BatchResponses> result;

// Called for each incoming request
@Override
public void onNext(Batchmap.BatchMapRequest mapRequest) {
public void onNext(MapOuterClass.MapRequest mapRequest) {
try {
datumStream.writeMessage(constructHandlerDatum(mapRequest));
} catch (InterruptedException e) {
Thread.interrupted();
onError(e);
// Make sure the handshake is done before processing the messages
if (!handshakeDone) {
if (!mapRequest.hasHandshake() || !mapRequest.getHandshake().getSot()) {
responseObserver.onError(Status.INVALID_ARGUMENT
.withDescription("Handshake request not received")
.asException());
return;
}
responseObserver.onNext(MapOuterClass.MapResponse.newBuilder()
.setHandshake(mapRequest.getHandshake())
.build());
handshakeDone = true;
return;
}

// Create a DatumIterator to write the messages to the batch mapper
// and start the batch mapper if it is the start of the stream
if (startOfStream) {
datumStream = new DatumIteratorImpl();
result = CompletableFuture.supplyAsync(
() -> batchMapper.processMessage(datumStream),
mapTaskExecutor);
startOfStream = false;
}

// If end of transmission, write EOF datum to the stream
// Wait for the result and send the response back to the client
if (mapRequest.hasStatus() && mapRequest.getStatus().getEot()) {
datumStream.writeMessage(HandlerDatum.EOF_DATUM);
BatchResponses responses = result.join();
buildAndStreamResponse(responses, responseObserver);
startOfStream = true;
} else {
datumStream.writeMessage(constructHandlerDatum(mapRequest));
}
} catch (Exception e) {
log.error("Encountered an error in batch map", e);
responseObserver.onError(Status.UNKNOWN
.withDescription(e.getMessage())
.withCause(e)
.asException());
}
}

// Called when an error occurs
@Override
public void onError(Throwable throwable) {
// We close the stream and let the sender retry the messages
log.error("Error Encountered in batchMap Stream", throwable);
var status = Status.UNKNOWN
.withDescription(throwable.getMessage())
.withCause(throwable);
responseObserver.onError(status.asException());
}

// Called when the client has finished sending requests
@Override
public void onCompleted() {
try {
// We Fire off the call to the client from here and stream the response back
datumStream.writeMessage(HandlerDatum.EOF_DATUM);
BatchResponses responses = result.get();
log.debug(
"Finished the call Result size is :{} and iterator count is :{}",
responses.getItems().size(),
datumStream.getCount());
// Crash if the number of responses from the users don't match the input requests ignoring the EOF message
if (responses.getItems().size() != datumStream.getCount() - 1) {
throw new RuntimeException("Number of results did not match expected " + (
datumStream.getCount() - 1) + " but got " + responses
.getItems()
.size());
}
buildAndStreamResponse(responses, responseObserver);
} catch (Exception e) {
log.error("Error Encountered in batchMap Stream onCompleted", e);
onError(e);
}
responseObserver.onCompleted();
}
};
}

// Build and stream the response back to the client
private void buildAndStreamResponse(
BatchResponses responses,
StreamObserver<Batchmap.BatchMapResponse> responseObserver) {
StreamObserver<MapOuterClass.MapResponse> responseObserver) {
responses.getItems().forEach(message -> {
List<Batchmap.BatchMapResponse.Result> batchMapResponseResult = new ArrayList<>();
List<MapOuterClass.MapResponse.Result> mapResponseResult = new ArrayList<>();
message.getItems().forEach(res -> {
batchMapResponseResult.add(
Batchmap.BatchMapResponse.Result
mapResponseResult.add(
MapOuterClass.MapResponse.Result
.newBuilder()
.setValue(res.getValue()
== null ? ByteString.EMPTY : ByteString.copyFrom(
Expand All @@ -114,48 +139,48 @@ private void buildAndStreamResponse(
.build()
);
});
Batchmap.BatchMapResponse singleRequestResponse = Batchmap.BatchMapResponse
MapOuterClass.MapResponse singleRequestResponse = MapOuterClass.MapResponse
.newBuilder()
.setId(message.getId())
.addAllResults(batchMapResponseResult)
.addAllResults(mapResponseResult)
.build();
// Stream the response back to the sender
responseObserver.onNext(singleRequestResponse);
});
responseObserver.onCompleted();
}


// IsReady is the heartbeat endpoint for gRPC.
@Override
public void isReady(
Empty request,
StreamObserver<Batchmap.ReadyResponse> responseObserver) {
responseObserver.onNext(Batchmap.ReadyResponse.newBuilder().setReady(true).build());
StreamObserver<MapOuterClass.ReadyResponse> responseObserver) {
responseObserver.onNext(MapOuterClass.ReadyResponse.newBuilder().setReady(true).build());
responseObserver.onCompleted();
}

private HandlerDatum constructHandlerDatum(Batchmap.BatchMapRequest d) {
// Construct a HandlerDatum from a MapRequest
private HandlerDatum constructHandlerDatum(MapOuterClass.MapRequest d) {
return new HandlerDatum(
d.getKeysList().toArray(new String[0]),
d.getValue().toByteArray(),
d.getRequest().getKeysList().toArray(new String[0]),
d.getRequest().getValue().toByteArray(),
Instant.ofEpochSecond(
d.getWatermark().getSeconds(),
d.getWatermark().getNanos()),
d.getRequest().getWatermark().getSeconds(),
d.getRequest().getWatermark().getNanos()),
Instant.ofEpochSecond(
d.getEventTime().getSeconds(),
d.getEventTime().getNanos()),
d.getRequest().getEventTime().getSeconds(),
d.getRequest().getEventTime().getNanos()),
d.getId(),
d.getHeadersMap()
d.getRequest().getHeadersMap()
);
}

// shuts down the executor service which is used for reduce
// Shuts down the executor service which is used for batch map
public void shutDown() {
this.batchMapTaskExecutor.shutdown();
this.mapTaskExecutor.shutdown();
try {
if (!batchMapTaskExecutor.awaitTermination(SHUTDOWN_TIME, TimeUnit.SECONDS)) {
if (!mapTaskExecutor.awaitTermination(SHUTDOWN_TIME, TimeUnit.SECONDS)) {
log.error("BatchMap executor did not terminate in the specified time.");
List<Runnable> droppedTasks = batchMapTaskExecutor.shutdownNow();
List<Runnable> droppedTasks = mapTaskExecutor.shutdownNow();
log.error("BatchMap executor was abruptly shut down. " + droppedTasks.size()
+ " tasks will not be executed.");
} else {
Expand All @@ -166,5 +191,4 @@ public void shutDown() {
e.printStackTrace();
}
}

}
2 changes: 1 addition & 1 deletion src/main/java/io/numaproj/numaflow/sinker/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public void onNext(SinkOuterClass.SinkRequest request) {
}

try {
if (request.getStatus().getEot()) {
if (request.hasStatus() && request.getStatus().getEot()) {
// End of transmission, write EOF datum to the stream
// Wait for the result and send the response back to the client
datumStream.writeMessage(HandlerDatum.EOF_DATUM);
Expand Down
52 changes: 0 additions & 52 deletions src/main/proto/batchmap/v1/batchmap.proto

This file was deleted.

4 changes: 4 additions & 0 deletions src/main/proto/map/v1/map.proto
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ message MapRequest {
google.protobuf.Timestamp watermark = 4;
map<string, string> headers = 5;
}
// Status carries stream-level signalling for a map request.
message Status {
// eot (end of transmission) is set to true to mark the end of the current batch of requests.
bool eot = 1;
}
Request request = 1;
// This ID is used to uniquely identify a map request
string id = 2;
optional Handshake handshake = 3;
optional Status status = 4;
}

/*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,40 @@
package io.numaproj.numaflow.batchmapper;

import io.grpc.stub.StreamObserver;
import io.numaproj.numaflow.batchmap.v1.Batchmap;
import lombok.extern.slf4j.Slf4j;
import io.numaproj.numaflow.map.v1.MapOuterClass;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.CompletableFuture;

@Slf4j
public class BatchMapOutputStreamObserver implements StreamObserver<Batchmap.BatchMapResponse> {
public AtomicReference<Boolean> completed = new AtomicReference<>(false);
public AtomicReference<List<Batchmap.BatchMapResponse>> resultDatum = new AtomicReference<>(
new ArrayList<>());
public Throwable t;
public class BatchMapOutputStreamObserver implements StreamObserver<MapOuterClass.MapResponse> {
List<MapOuterClass.MapResponse> mapResponses = new ArrayList<>();
CompletableFuture<Void> done = new CompletableFuture<>();
Integer responseCount;

public BatchMapOutputStreamObserver(Integer responseCount) {
this.responseCount = responseCount;
}

@Override
public void onNext(Batchmap.BatchMapResponse batchMapResponse) {
List<Batchmap.BatchMapResponse> receivedResponses = resultDatum.get();
receivedResponses.add(batchMapResponse);
resultDatum.set(receivedResponses);
log.info(
"Received BatchMapResponse with id {} and message count {}",
batchMapResponse.getId(),
batchMapResponse.getResultsCount());
public void onNext(MapOuterClass.MapResponse mapResponse) {
mapResponses.add(mapResponse);
if (mapResponses.size() == responseCount) {
done.complete(null);
}
}

@Override
public void onError(Throwable throwable) {
t = throwable;
done.completeExceptionally(throwable);
}

@Override
public void onCompleted() {
this.completed.set(true);
done.complete(null);
}

public List<MapOuterClass.MapResponse> getMapResponses() {
return mapResponses;
}
}
Loading
Loading