Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Unify MapStream and Unary Map Operations Using a Shared gRPC Protocol #146

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.google.protobuf.ByteString;
import io.grpc.stub.StreamObserver;
import io.numaproj.numaflow.mapstream.v1.Mapstream;
import io.numaproj.numaflow.map.v1.MapOuterClass;
import lombok.AllArgsConstructor;

import java.util.ArrayList;
Expand All @@ -14,17 +14,17 @@
*/
@AllArgsConstructor
class OutputObserverImpl implements OutputObserver {
StreamObserver<Mapstream.MapStreamResponse> responseObserver;
StreamObserver<MapOuterClass.MapResponse> responseObserver;

@Override
public void send(Message message) {
Mapstream.MapStreamResponse response = buildResponse(message);
MapOuterClass.MapResponse response = buildResponse(message);
responseObserver.onNext(response);
}

private Mapstream.MapStreamResponse buildResponse(Message message) {
return Mapstream.MapStreamResponse.newBuilder()
.setResult(Mapstream.MapStreamResponse.Result.newBuilder()
private MapOuterClass.MapResponse buildResponse(Message message) {
return MapOuterClass.MapResponse.newBuilder()
.addResults(MapOuterClass.MapResponse.Result.newBuilder()
.setValue(
message.getValue() == null ? ByteString.EMPTY : ByteString.copyFrom(
message.getValue()))
Expand All @@ -33,6 +33,5 @@ private Mapstream.MapStreamResponse buildResponse(Message message) {
.addAllTags(message.getTags()
== null ? new ArrayList<>() : List.of(message.getTags()))
.build()).build();

}
}
114 changes: 76 additions & 38 deletions src/main/java/io/numaproj/numaflow/mapstreamer/Service.java
Original file line number Diff line number Diff line change
@@ -1,64 +1,102 @@
package io.numaproj.numaflow.mapstreamer;

import com.google.protobuf.Empty;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import io.numaproj.numaflow.mapstream.v1.MapStreamGrpc;
import io.numaproj.numaflow.mapstream.v1.Mapstream;
import io.numaproj.numaflow.map.v1.MapGrpc;
import io.numaproj.numaflow.map.v1.MapOuterClass;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.time.Instant;

import static io.numaproj.numaflow.mapstream.v1.MapStreamGrpc.getMapStreamFnMethod;

import static io.numaproj.numaflow.map.v1.MapGrpc.getMapFnMethod;

@Slf4j
@AllArgsConstructor
class Service extends MapStreamGrpc.MapStreamImplBase {
class Service extends MapGrpc.MapImplBase {

private final MapStreamer mapStreamer;

/**
* Applies a map stream function to each request.
*/
@Override
public void mapStreamFn(
Mapstream.MapStreamRequest request,
StreamObserver<Mapstream.MapStreamResponse> responseObserver) {
public StreamObserver<MapOuterClass.MapRequest> mapFn(StreamObserver<MapOuterClass.MapResponse> responseObserver) {

if (this.mapStreamer == null) {
io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(
getMapStreamFnMethod(),
return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(
getMapFnMethod(),

Check warning on line 25 in src/main/java/io/numaproj/numaflow/mapstreamer/Service.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/io/numaproj/numaflow/mapstreamer/Service.java#L24-L25

Added lines #L24 - L25 were not covered by tests
responseObserver);
return;
}

HandlerDatum handlerDatum = new HandlerDatum(
request.getValue().toByteArray(),
Instant.ofEpochSecond(
request.getWatermark().getSeconds(),
request.getWatermark().getNanos()),
Instant.ofEpochSecond(
request.getEventTime().getSeconds(),
request.getEventTime().getNanos()),
request.getHeadersMap()
);
return new StreamObserver<>() {
private boolean handshakeDone = false;

// process Datum
this.mapStreamer.processMessage(request
.getKeysList()
.toArray(new String[0]), handlerDatum, new OutputObserverImpl(responseObserver));
@Override
public void onNext(MapOuterClass.MapRequest request) {
// make sure the handshake is done before processing the messages
if (!handshakeDone) {
if (!request.hasHandshake() || !request.getHandshake().getSot()) {
responseObserver.onError(Status.INVALID_ARGUMENT
.withDescription("Handshake request not received")
.asException());
return;

Check warning on line 40 in src/main/java/io/numaproj/numaflow/mapstreamer/Service.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/io/numaproj/numaflow/mapstreamer/Service.java#L37-L40

Added lines #L37 - L40 were not covered by tests
}
responseObserver.onNext(MapOuterClass.MapResponse.newBuilder()
.setHandshake(request.getHandshake())
.build());
handshakeDone = true;
return;
}

responseObserver.onCompleted();
}
try {
// process the message
mapStreamer.processMessage(
request
.getRequest()
.getKeysList()
.toArray(new String[0]),
constructHandlerDatum(request),
new OutputObserverImpl(responseObserver));
} catch (Exception e) {
log.error("Error processing message", e);
responseObserver.onError(Status.UNKNOWN
.withDescription(e.getMessage())
.asException());
return;
}

/**
* IsReady is the heartbeat endpoint for gRPC.
*/
@Override
public void isReady(Empty request, StreamObserver<Mapstream.ReadyResponse> responseObserver) {
responseObserver.onNext(Mapstream.ReadyResponse.newBuilder().setReady(true).build());
responseObserver.onCompleted();
// Send an EOT message to indicate the end of the transmission for the batch.
MapOuterClass.MapResponse eotResponse = MapOuterClass.MapResponse
.newBuilder()
.setStatus(MapOuterClass.Status.newBuilder().setEot(true).build()).build();
responseObserver.onNext(eotResponse);
}

@Override
public void onError(Throwable throwable) {
log.error("Error Encountered in mapStream Stream", throwable);
var status = Status.UNKNOWN
.withDescription(throwable.getMessage())
.withCause(throwable);
responseObserver.onError(status.asException());
}

Check warning on line 80 in src/main/java/io/numaproj/numaflow/mapstreamer/Service.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/io/numaproj/numaflow/mapstreamer/Service.java#L75-L80

Added lines #L75 - L80 were not covered by tests

@Override
public void onCompleted() {
responseObserver.onCompleted();
}
};
}

// Construct a HandlerDatum from a MapRequest
private HandlerDatum constructHandlerDatum(MapOuterClass.MapRequest d) {
return new HandlerDatum(
d.getRequest().getValue().toByteArray(),
Instant.ofEpochSecond(
d.getRequest().getWatermark().getSeconds(),
d.getRequest().getWatermark().getNanos()),
Instant.ofEpochSecond(
d.getRequest().getEventTime().getSeconds(),
d.getRequest().getEventTime().getNanos()),
d.getRequest().getHeadersMap()
);
}
}
46 changes: 0 additions & 46 deletions src/main/proto/mapstream/v1/mapstream.proto

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.numaproj.numaflow.mapstreamer;

import io.grpc.stub.StreamObserver;
import io.numaproj.numaflow.map.v1.MapOuterClass;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;

public class MapStreamOutputStreamObserver implements StreamObserver<MapOuterClass.MapResponse> {
List<MapOuterClass.MapResponse> mapResponses = new ArrayList<>();
CompletableFuture<Void> done = new CompletableFuture<>();
Integer responseCount;

public MapStreamOutputStreamObserver(Integer responseCount) {
this.responseCount = responseCount;
}

@Override
public void onNext(MapOuterClass.MapResponse mapResponse) {
System.out.println("Received response: " + mapResponse);
mapResponses.add(mapResponse);
if (mapResponses.size() == responseCount) {
done.complete(null);
}
}

@Override
public void onError(Throwable throwable) {
done.completeExceptionally(throwable);
}

@Override
public void onCompleted() {
done.complete(null);
}

public List<MapOuterClass.MapResponse> getMapResponses() {
return mapResponses;
}
}
37 changes: 26 additions & 11 deletions src/test/java/io/numaproj/numaflow/mapstreamer/ServerErrTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.testing.GrpcCleanupRule;
import io.numaproj.numaflow.mapstream.v1.MapStreamGrpc;
import io.numaproj.numaflow.mapstream.v1.Mapstream;
import io.numaproj.numaflow.map.v1.MapGrpc;
import io.numaproj.numaflow.map.v1.MapOuterClass;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;

import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

public class ServerErrTest {

Expand Down Expand Up @@ -103,18 +102,34 @@ public void tearDown() throws Exception {

@Test
public void TestMapStreamerErr() {
ByteString inValue = ByteString.copyFromUtf8("invalue");
Mapstream.MapStreamRequest request = Mapstream.MapStreamRequest
MapOuterClass.MapRequest handshakeRequest = MapOuterClass.MapRequest
.newBuilder()
.addAllKeys(List.of("test-map-stream-key"))
.setValue(inValue)
.setHandshake(MapOuterClass.Handshake.newBuilder().setSot(true))
.build();

var stub = MapStreamGrpc.newBlockingStub(inProcessChannel);
ByteString inValue = ByteString.copyFromUtf8("invalue");
MapOuterClass.MapRequest request = MapOuterClass.MapRequest
.newBuilder()
.setRequest(MapOuterClass.MapRequest.Request.newBuilder()
.setValue(inValue)
.addKeys("test-map-stream-key")).build();

MapStreamOutputStreamObserver mapStreamOutputStreamObserver = new MapStreamOutputStreamObserver(
2);
var stub = MapGrpc.newStub(inProcessChannel);

var requestStreamObserver = stub
.mapFn(mapStreamOutputStreamObserver);
requestStreamObserver.onNext(handshakeRequest);
requestStreamObserver.onNext(request);

try {
stub.mapStreamFn(request);
mapStreamOutputStreamObserver.done.get();
fail("Should have thrown an exception");
} catch (Exception e) {
assertEquals("UNKNOWN: unknown exception", e.getMessage());
assertEquals(
"io.grpc.StatusRuntimeException: UNKNOWN: unknown exception",
e.getMessage());
}
}

Expand Down
Loading
Loading