From e4e9ef2a0829d51adc675301929f955ed8acc483 Mon Sep 17 00:00:00 2001 From: Keran Yang Date: Fri, 3 Feb 2023 10:48:17 -0500 Subject: [PATCH] feat: add source transformer SDK support (#17) Signed-off-by: Keran Yang --- .../function/evenodd/EvenOddFunction.java | 1 - .../EventTimeFilterFunction.java | 52 +++++++++++++++ .../function/flatmap/FlatMapFunction.java | 1 - .../numaflow/function/FunctionServer.java | 6 ++ .../numaflow/function/FunctionService.java | 59 ++++++++++++++++- .../numaproj/numaflow/function/MessageT.java | 36 +++++++++++ .../numaflow/function/mapt/MapTFunc.java | 23 +++++++ .../numaflow/function/mapt/MapTHandler.java | 13 ++++ src/main/proto/function/v1/udfunction.proto | 11 +++- .../numaflow/function/FunctionServerTest.java | 63 ++++++++++++++++--- 10 files changed, 249 insertions(+), 16 deletions(-) create mode 100644 examples/src/main/java/io/numaproj/numaflow/examples/function/eventtimefilter/EventTimeFilterFunction.java create mode 100644 src/main/java/io/numaproj/numaflow/function/MessageT.java create mode 100644 src/main/java/io/numaproj/numaflow/function/mapt/MapTFunc.java create mode 100644 src/main/java/io/numaproj/numaflow/function/mapt/MapTHandler.java diff --git a/examples/src/main/java/io/numaproj/numaflow/examples/function/evenodd/EvenOddFunction.java b/examples/src/main/java/io/numaproj/numaflow/examples/function/evenodd/EvenOddFunction.java index e1406d45..303d8fd1 100644 --- a/examples/src/main/java/io/numaproj/numaflow/examples/function/evenodd/EvenOddFunction.java +++ b/examples/src/main/java/io/numaproj/numaflow/examples/function/evenodd/EvenOddFunction.java @@ -26,7 +26,6 @@ private static Message[] process(String key, Datum data) { } public static void main(String[] args) throws IOException { - logger.info("Forward invoked"); new FunctionServer().registerMapper(new MapFunc(EvenOddFunction::process)).start(); } } diff --git a/examples/src/main/java/io/numaproj/numaflow/examples/function/eventtimefilter/EventTimeFilterFunction.java b/examples/src/main/java/io/numaproj/numaflow/examples/function/eventtimefilter/EventTimeFilterFunction.java new file mode 100644 index 00000000..eac89927 --- /dev/null +++ b/examples/src/main/java/io/numaproj/numaflow/examples/function/eventtimefilter/EventTimeFilterFunction.java @@ -0,0 +1,52 @@ +package io.numaproj.numaflow.examples.function.eventtimefilter; + +import io.numaproj.numaflow.function.Datum; +import io.numaproj.numaflow.function.FunctionServer; +import io.numaproj.numaflow.function.MessageT; +import io.numaproj.numaflow.function.mapt.MapTFunc; + +import java.io.IOException; +import java.time.Instant; +import java.util.logging.Logger; + +/** + * This is a simple User Defined Function example which receives a message, applies the following + * data transformation, and returns the message. + *

+ * If the message event time is before year 2022, drop the message. If it's within year 2022, update + * the key to "within_year_2022" and update the message event time to Jan 1st 2022. + * Otherwise, (exclusively after year 2022), update the key to "after_year_2022" and update the + * message event time to Jan 1st 2023. + */ +public class EventTimeFilterFunction { + + private static final Logger logger = Logger.getLogger(EventTimeFilterFunction.class.getName()); + private static final Instant januaryFirst2022 = Instant.ofEpochMilli(1640995200000L); + private static final Instant januaryFirst2023 = Instant.ofEpochMilli(1672531200000L); + + private static MessageT[] process(String key, Datum data) { + Instant eventTime = data.getEventTime(); + + if (eventTime.isBefore(januaryFirst2022)) { + return new MessageT[]{MessageT.toDrop()}; + } else if (eventTime.isBefore(januaryFirst2023)) { + return new MessageT[]{ + MessageT.to( + januaryFirst2022, + "within_year_2022", + data.getValue())}; + } else { + return new MessageT[]{ + MessageT.to( + januaryFirst2023, + "after_year_2022", + data.getValue())}; + } + } + + public static void main(String[] args) throws IOException { + new FunctionServer() + .registerMapperT(new MapTFunc(EventTimeFilterFunction::process)) + .start(); + } +} diff --git a/examples/src/main/java/io/numaproj/numaflow/examples/function/flatmap/FlatMapFunction.java b/examples/src/main/java/io/numaproj/numaflow/examples/function/flatmap/FlatMapFunction.java index ab186b2c..ef3285ba 100644 --- a/examples/src/main/java/io/numaproj/numaflow/examples/function/flatmap/FlatMapFunction.java +++ b/examples/src/main/java/io/numaproj/numaflow/examples/function/flatmap/FlatMapFunction.java @@ -23,7 +23,6 @@ private static Message[] process(String key, Datum data) { } public static void main(String[] args) throws IOException { - logger.info("Flatmap invoked"); new FunctionServer().registerMapper(new MapFunc(FlatMapFunction::process)).start(); } } diff --git a/src/main/java/io/numaproj/numaflow/function/FunctionServer.java b/src/main/java/io/numaproj/numaflow/function/FunctionServer.java index f35dff28..9eb6ae6a 100644 --- a/src/main/java/io/numaproj/numaflow/function/FunctionServer.java +++ b/src/main/java/io/numaproj/numaflow/function/FunctionServer.java @@ -14,6 +14,7 @@ import io.netty.channel.unix.DomainSocketAddress; import io.numaproj.numaflow.common.GrpcServerConfig; import io.numaproj.numaflow.function.map.MapHandler; +import io.numaproj.numaflow.function.mapt.MapTHandler; import io.numaproj.numaflow.function.reduce.ReduceHandler; import java.io.IOException; @@ -63,6 +64,11 @@ public FunctionServer registerMapper(MapHandler mapHandler) { return this; } + public FunctionServer registerMapperT(MapTHandler mapTHandler) { + this.functionService.setMapTHandler(mapTHandler); + return this; + } + public FunctionServer registerReducer(ReduceHandler reduceHandler) { this.functionService.setReduceHandler(reduceHandler); return this; diff --git a/src/main/java/io/numaproj/numaflow/function/FunctionService.java b/src/main/java/io/numaproj/numaflow/function/FunctionService.java index 03f159c2..2a7c093d 100644 --- a/src/main/java/io/numaproj/numaflow/function/FunctionService.java +++ b/src/main/java/io/numaproj/numaflow/function/FunctionService.java @@ -4,6 +4,7 @@ import com.google.protobuf.Empty; import io.grpc.stub.StreamObserver; import io.numaproj.numaflow.function.map.MapHandler; +import io.numaproj.numaflow.function.mapt.MapTHandler; import io.numaproj.numaflow.function.metadata.IntervalWindow; import io.numaproj.numaflow.function.metadata.IntervalWindowImpl; import io.numaproj.numaflow.function.metadata.Metadata; @@ -12,6 +13,7 @@ import io.numaproj.numaflow.function.reduce.ReduceDatumStreamImpl; import io.numaproj.numaflow.function.reduce.ReduceHandler; import io.numaproj.numaflow.function.v1.Udfunction; +import io.numaproj.numaflow.function.v1.Udfunction.EventTime; import io.numaproj.numaflow.function.v1.UserDefinedFunctionGrpc; import java.time.Instant; @@ -40,6 +42,7 @@ class FunctionService extends UserDefinedFunctionGrpc.UserDefinedFunctionImplBas private final long SHUTDOWN_TIME = 30; private MapHandler mapHandler; + private MapTHandler mapTHandler; private ReduceHandler reduceHandler; public FunctionService() { @@ -49,6 +52,10 @@ public void setMapHandler(MapHandler mapHandler) { this.mapHandler = mapHandler; } + public void setMapTHandler(MapTHandler mapTHandler) { + this.mapTHandler = mapTHandler; + } + public void setReduceHandler(ReduceHandler reduceHandler) { this.reduceHandler = reduceHandler; } @@ -78,7 +85,8 @@ public void mapFn( request.getWatermark().getWatermark().getNanos()), Instant.ofEpochSecond( request.getEventTime().getEventTime().getSeconds(), - request.getEventTime().getEventTime().getNanos()), false); + request.getEventTime().getEventTime().getNanos()), + false); // process Datum Message[] messages = mapHandler.HandleDo(key, handlerDatum); @@ -88,6 +96,39 @@ public void mapFn( responseObserver.onCompleted(); } + @Override + public void mapTFn( + Udfunction.Datum request, + StreamObserver responseObserver) { + if (this.mapTHandler == null) { + io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall( + getMapFnMethod(), + responseObserver); + return; + } + + // get key from gPRC metadata + String key = Function.DATUM_CONTEXT_KEY.get(); + + // get Datum from request + HandlerDatum handlerDatum = new HandlerDatum( + request.getValue().toByteArray(), + Instant.ofEpochSecond( + request.getWatermark().getWatermark().getSeconds(), + request.getWatermark().getWatermark().getNanos()), + Instant.ofEpochSecond( + request.getEventTime().getEventTime().getSeconds(), + request.getEventTime().getEventTime().getNanos()), + false); + + // process Datum + MessageT[] messageTs = mapTHandler.HandleDo(key, handlerDatum); + + // set response + responseObserver.onNext(buildDatumListResponse(messageTs)); + responseObserver.onCompleted(); + } + /** * Streams input data to reduceFn and returns the result. */ @@ -227,4 +268,20 @@ private Udfunction.DatumList buildDatumListResponse(Message[] messages) { }); return datumListBuilder.build(); } + + private Udfunction.DatumList buildDatumListResponse(MessageT[] messageTs) { + Udfunction.DatumList.Builder datumListBuilder = Udfunction.DatumList.newBuilder(); + Arrays.stream(messageTs).forEach(messageT -> { + datumListBuilder.addElements(Udfunction.Datum.newBuilder() + .setEventTime(EventTime.newBuilder().setEventTime + (com.google.protobuf.Timestamp.newBuilder() + .setSeconds(messageT.getEventTime().getEpochSecond()) + .setNanos(messageT.getEventTime().getNano())) + ) + .setKey(messageT.getKey()) + .setValue(ByteString.copyFrom(messageT.getValue())) + .build()); + }); + return datumListBuilder.build(); + } } diff --git a/src/main/java/io/numaproj/numaflow/function/MessageT.java b/src/main/java/io/numaproj/numaflow/function/MessageT.java new file mode 100644 index 00000000..50d642c3 --- /dev/null +++ b/src/main/java/io/numaproj/numaflow/function/MessageT.java @@ -0,0 +1,36 @@ +package io.numaproj.numaflow.function; + +import static io.numaproj.numaflow.function.Message.ALL; +import static io.numaproj.numaflow.function.Message.DROP; + +import java.time.Instant; +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * MessageT is used to wrap the data return by UDF functions. Compared with Message, MessageT + * contains one more field, the event time, usually extracted from the payload. + */ +@AllArgsConstructor +@Getter +public class MessageT { + + private Instant eventTime; + private final String key; + private final byte[] value; + + // creates a MessageT to be dropped + public static MessageT toDrop() { + return new MessageT(Instant.MIN, DROP, new byte[0]); + } + + // creates a MessageT that will forward to all + public static MessageT toAll(Instant eventTime, byte[] value) { + return new MessageT(eventTime, ALL, value); + } + + // creates a MessageT that will forward to specified "to" + public static MessageT to(Instant eventTime, String to, byte[] value) { + return new MessageT(eventTime, to, value); + } +} diff --git a/src/main/java/io/numaproj/numaflow/function/mapt/MapTFunc.java b/src/main/java/io/numaproj/numaflow/function/mapt/MapTFunc.java new file mode 100644 index 00000000..61fb8b83 --- /dev/null +++ b/src/main/java/io/numaproj/numaflow/function/mapt/MapTFunc.java @@ -0,0 +1,23 @@ +package io.numaproj.numaflow.function.mapt; + +import io.numaproj.numaflow.function.Datum; +import io.numaproj.numaflow.function.MessageT; + +import java.util.function.BiFunction; + +/** + * Implementation of MapTHandler instantiated from a function + */ +public class MapTFunc implements MapTHandler { + + private final BiFunction mapTFn; + + public MapTFunc(BiFunction mapTFn) { + this.mapTFn = mapTFn; + } + + @Override + public MessageT[] HandleDo(String key, Datum datum) { + return mapTFn.apply(key, datum); + } +} diff --git a/src/main/java/io/numaproj/numaflow/function/mapt/MapTHandler.java b/src/main/java/io/numaproj/numaflow/function/mapt/MapTHandler.java new file mode 100644 index 00000000..d5dbe30d --- /dev/null +++ b/src/main/java/io/numaproj/numaflow/function/mapt/MapTHandler.java @@ -0,0 +1,13 @@ +package io.numaproj.numaflow.function.mapt; + +import io.numaproj.numaflow.function.Datum; +import io.numaproj.numaflow.function.MessageT; + +/** + * Interface of mapT function implementation. + */ +public interface MapTHandler { + + // Function to process each coming message + MessageT[] HandleDo(String key, Datum datum); +} diff --git a/src/main/proto/function/v1/udfunction.proto b/src/main/proto/function/v1/udfunction.proto index 6b524291..48c059ef 100644 --- a/src/main/proto/function/v1/udfunction.proto +++ b/src/main/proto/function/v1/udfunction.proto @@ -8,10 +8,15 @@ import "google/protobuf/empty.proto"; package function.v1; service UserDefinedFunction { - // Applies a function to each datum element. + // MapFn applies a function to each datum element. rpc MapFn(Datum) returns (DatumList); - // Applies a reduce function to a datum stream. + // MapTFn applies a function to each datum element. + // In addition to map function, MapTFn also supports assigning a new event time to datum. + // MapTFn can be used only at source vertex by source data transformer. + rpc MapTFn(Datum) returns (DatumList); + + // ReduceFn applies a reduce function to a datum stream. rpc ReduceFn(stream Datum) returns (DatumList); // IsReady is the heartbeat endpoint for gRPC. @@ -53,4 +58,4 @@ message DatumList { */ message ReadyResponse { bool ready = 1; -} \ No newline at end of file +} diff --git a/src/test/java/io/numaproj/numaflow/function/FunctionServerTest.java b/src/test/java/io/numaproj/numaflow/function/FunctionServerTest.java index 747d5630..a6721b74 100644 --- a/src/test/java/io/numaproj/numaflow/function/FunctionServerTest.java +++ b/src/test/java/io/numaproj/numaflow/function/FunctionServerTest.java @@ -10,9 +10,11 @@ import io.grpc.testing.GrpcCleanupRule; import io.numaproj.numaflow.common.GrpcServerConfig; import io.numaproj.numaflow.function.map.MapFunc; +import io.numaproj.numaflow.function.mapt.MapTFunc; import io.numaproj.numaflow.function.reduce.ReduceDatumStream; import io.numaproj.numaflow.function.reduce.ReduceFunc; import io.numaproj.numaflow.function.v1.Udfunction; +import io.numaproj.numaflow.function.v1.Udfunction.EventTime; import io.numaproj.numaflow.function.v1.UserDefinedFunctionGrpc; import io.numaproj.numaflow.utils.TriFunction; import org.junit.After; @@ -22,6 +24,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import java.time.Instant; import java.util.function.BiFunction; import java.util.logging.Logger; @@ -33,17 +36,25 @@ @RunWith(JUnit4.class) public class FunctionServerTest { private static final Logger logger = Logger.getLogger(FunctionServerTest.class.getName()); - private final static String processedKeySuffix = "-key-processed"; + private final static String PROCESSED_KEY_SUFFIX = "-key-processed"; + private final static String REDUCE_PROCESSED_KEY_SUFFIX = "-processed-sum"; + private final static String PROCESSED_VALUE_SUFFIX = "-value-processed"; + private final static Instant TEST_EVENT_TIME = Instant.MIN; - private final static String reduceProcessedKeySuffix = "-processed-sum"; - private final static String processedValueSuffix = "-value-processed"; @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); private final BiFunction testMapFn = (key, datum) -> new Message[]{new Message( - key + processedKeySuffix, + key + PROCESSED_KEY_SUFFIX, (new String(datum.getValue()) - + processedValueSuffix).getBytes())}; + + PROCESSED_VALUE_SUFFIX).getBytes())}; + + private final BiFunction testMapTFn = + (key, datum) -> new MessageT[]{new MessageT( + TEST_EVENT_TIME, + key + PROCESSED_KEY_SUFFIX, + (new String(datum.getValue()) + + PROCESSED_VALUE_SUFFIX).getBytes())}; private final TriFunction testReduceFn = ((key, reduceChannel, md) -> { @@ -61,7 +72,7 @@ public class FunctionServerTest { } } return new Message[]{Message.to( - key + reduceProcessedKeySuffix, + key + REDUCE_PROCESSED_KEY_SUFFIX, String.valueOf(sum).getBytes())}; }); @@ -76,6 +87,7 @@ public void setUp() throws Exception { new GrpcServerConfig(Function.SOCKET_PATH, Function.DEFAULT_MESSAGE_SIZE)); server .registerMapper(new MapFunc(testMapFn)) + .registerMapperT(new MapTFunc(testMapTFn)) .registerReducer(new ReduceFunc(testReduceFn)) .start(); inProcessChannel = grpcCleanup.register(InProcessChannelBuilder @@ -98,8 +110,8 @@ public void mapper() { .setValue(inValue) .build(); - String expectedKey = "inkey" + processedKeySuffix; - ByteString expectedValue = ByteString.copyFromUtf8("invalue" + processedValueSuffix); + String expectedKey = "inkey" + PROCESSED_KEY_SUFFIX; + ByteString expectedValue = ByteString.copyFromUtf8("invalue" + PROCESSED_VALUE_SUFFIX); Metadata metadata = new Metadata(); metadata.put(Metadata.Key.of(DATUM_KEY, Metadata.ASCII_STRING_MARSHALLER), "inkey"); @@ -114,6 +126,37 @@ public void mapper() { assertEquals(expectedValue, actualDatumList.getElements(0).getValue()); } + @Test + public void mapperT() { + ByteString inValue = ByteString.copyFromUtf8("invalue"); + Udfunction.Datum inDatum = Udfunction.Datum + .newBuilder() + .setKey("not-my-key") + .setValue(inValue) + .build(); + + String expectedKey = "inkey" + PROCESSED_KEY_SUFFIX; + ByteString expectedValue = ByteString.copyFromUtf8("invalue" + PROCESSED_VALUE_SUFFIX); + + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of(DATUM_KEY, Metadata.ASCII_STRING_MARSHALLER), "inkey"); + + var stub = UserDefinedFunctionGrpc.newBlockingStub(inProcessChannel); + var actualDatumList = stub + .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata)) + .mapTFn(inDatum); + + assertEquals(1, actualDatumList.getElementsCount()); + assertEquals( + EventTime.newBuilder().setEventTime( + com.google.protobuf.Timestamp.newBuilder() + .setSeconds(TEST_EVENT_TIME.getEpochSecond()) + .setNanos(TEST_EVENT_TIME.getNano())).build(), + actualDatumList.getElements(0).getEventTime()); + assertEquals(expectedKey, actualDatumList.getElements(0).getKey()); + assertEquals(expectedValue, actualDatumList.getElements(0).getValue()); + } + @Test public void reducerWithOneKey() { String reduceKey = "reduce-key"; @@ -141,7 +184,7 @@ public void reducerWithOneKey() { inputStreamObserver.onCompleted(); - String expectedKey = reduceKey + reduceProcessedKeySuffix; + String expectedKey = reduceKey + REDUCE_PROCESSED_KEY_SUFFIX; // sum of first 10 numbers 1 to 10 -> 55 ByteString expectedValue = ByteString.copyFromUtf8(String.valueOf(55)); @@ -183,7 +226,7 @@ public void reducerWithMultipleKey() { inputStreamObserver.onCompleted(); - String expectedKey = reduceKey + reduceProcessedKeySuffix; + String expectedKey = reduceKey + REDUCE_PROCESSED_KEY_SUFFIX; // sum of first 10 numbers 1 to 10 -> 55 ByteString expectedValue = ByteString.copyFromUtf8(String.valueOf(55));