Skip to content

Commit

Permalink
feat: support for multi keys (numaproj#30)
Browse files Browse the repository at this point in the history
Signed-off-by: Yashash H L <[email protected]>
  • Loading branch information
yhl25 authored Apr 5, 2023
1 parent 8b0a22c commit fd9b5ce
Show file tree
Hide file tree
Showing 29 changed files with 108 additions and 124 deletions.
2 changes: 1 addition & 1 deletion examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<dependency>
<groupId>io.numaproj.numaflow</groupId>
<artifactId>numaflow-java</artifactId>
<version>0.3.4</version>
<version>0.4.0</version>
</dependency>
</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@

/**
* This is a simple User Defined Function example which receives a message,
* and attaches a key to the message based on the value, if the value is even
* and attaches keys to the message based on the value, if the value is even
* the key will be set as "even" if the value is odd the key will be set as
* "odd"
*/

@Slf4j
public class EvenOddFunction extends MapHandler {

public Message[] processMessage(String key, Datum data) {
public Message[] processMessage(String[] keys, Datum data) {
int value = 0;
try {
value = Integer.parseInt(new String(data.getValue()));
Expand All @@ -27,9 +27,9 @@ public Message[] processMessage(String key, Datum data) {
return new Message[]{Message.toDrop()};
}
if (value % 2 == 0) {
return new Message[]{Message.to("even", data.getValue())};
return new Message[]{Message.to(new String[]{"even"}, data.getValue())};
}
return new Message[]{Message.to("odd", data.getValue())};
return new Message[]{Message.to(new String[]{"odd"}, data.getValue())};
}

public static void main(String[] args) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class EventTimeFilterFunction extends MapTHandler {
private static final Instant januaryFirst2022 = Instant.ofEpochMilli(1640995200000L);
private static final Instant januaryFirst2023 = Instant.ofEpochMilli(1672531200000L);

public MessageT[] processMessage(String key, Datum data) {
public MessageT[] processMessage(String[] keys, Datum data) {
Instant eventTime = data.getEventTime();

if (eventTime.isBefore(januaryFirst2022)) {
Expand All @@ -31,13 +31,13 @@ public MessageT[] processMessage(String key, Datum data) {
return new MessageT[]{
MessageT.to(
januaryFirst2022,
"within_year_2022",
new String[]{"within_year_2022"},
data.getValue())};
} else {
return new MessageT[]{
MessageT.to(
januaryFirst2023,
"after_year_2022",
new String[]{"after_year_2022"},
data.getValue())};
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

public class FlatMapFunction extends MapHandler {

public Message[] processMessage(String key, Datum data) {
public Message[] processMessage(String[] keys, Datum data) {
String msg = new String(data.getValue());
String[] strs = msg.split(",");
Message[] results = new Message[strs.length];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
*/

public class ForwardFunction extends MapHandler {
public Message[] processMessage(String key, Datum data) {
public Message[] processMessage(String[] keys, Datum data) {
return new Message[]{Message.toAll(data.getValue())};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

@Slf4j
Expand All @@ -33,7 +34,7 @@ public EvenOddCounter(Config config) {
}

@Override
public void addMessage(String key, Datum datum, Metadata md) {
public void addMessage(String[] keys, Datum datum, Metadata md) {
try {
int val = Integer.parseInt(new String(datum.getValue()));
// increment based on the value specified in the config
Expand All @@ -48,18 +49,18 @@ public void addMessage(String key, Datum datum, Metadata md) {
}

@Override
public Message[] getOutput(String key, Metadata md) {
public Message[] getOutput(String[] keys, Metadata md) {
log.info(
"even and odd count - {} {}, window - {} {}",
evenCount,
oddCount,
md.getIntervalWindow().getStartTime().toString(),
md.getIntervalWindow().getEndTime().toString());

if (Objects.equals(key, "even")) {
return new Message[]{Message.to(key, String.valueOf(evenCount).getBytes())};
if (Arrays.equals(keys, new String[]{"even"})) {
return new Message[]{Message.to(keys, String.valueOf(evenCount).getBytes())};
} else {
return new Message[]{Message.to(key, String.valueOf(oddCount).getBytes())};
return new Message[]{Message.to(keys, String.valueOf(oddCount).getBytes())};
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class SumFunction extends ReduceHandler {
private int sum = 0;

@Override
public void addMessage(String key, Datum datum, Metadata md) {
public void addMessage(String[] keys, Datum datum, Metadata md) {
try {
sum += Integer.parseInt(new String(datum.getValue()));
} catch (NumberFormatException e) {
Expand All @@ -22,7 +22,7 @@ public void addMessage(String key, Datum datum, Metadata md) {
}

@Override
public Message[] getOutput(String key, Metadata md) {
public Message[] getOutput(String[] keys, Metadata md) {
return new Message[]{Message.toAll(String.valueOf(sum).getBytes())};
}
}
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>io.numaproj.numaflow</groupId>
<artifactId>numaflow-java</artifactId>
<version>0.3.4</version>
<version>0.4.0</version>
<packaging>jar</packaging>

<name>numaflow-java</name>
Expand Down
10 changes: 1 addition & 9 deletions src/main/java/io/numaproj/numaflow/function/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ public class Function {

public static final int DEFAULT_MESSAGE_SIZE = 1024 * 1024 * 4;

public static final String DATUM_KEY = "x-numaflow-datum-key";

public static final String WIN_START_KEY = "x-numaflow-win-start-time";

public static final String WIN_END_KEY = "x-numaflow-win-end-time";
Expand All @@ -18,9 +16,7 @@ public class Function {

public static final String SUCCESS = "SUCCESS";

public static final Metadata.Key<String> DATUM_METADATA_KEY = Metadata.Key.of(
Function.DATUM_KEY,
Metadata.ASCII_STRING_MARSHALLER);
public static final String DELIMITTER = ":";

public static final Metadata.Key<String> DATUM_METADATA_WIN_START = Metadata.Key.of(
Function.WIN_START_KEY,
Expand All @@ -30,10 +26,6 @@ public class Function {
Function.WIN_END_KEY,
Metadata.ASCII_STRING_MARSHALLER);

public static final Context.Key<String> DATUM_CONTEXT_KEY = Context.keyWithDefault(
Function.DATUM_KEY,
"");

public static final Context.Key<String> WINDOW_START_TIME = Context.keyWithDefault(
Function.WIN_START_KEY,
"");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(

final var context =
Context.current().withValues(
Function.DATUM_CONTEXT_KEY,
headers.get(Function.DATUM_METADATA_KEY),
Function.WINDOW_START_TIME,
headers.get(Function.DATUM_METADATA_WIN_START),
Function.WINDOW_END_TIME,
Expand Down
18 changes: 6 additions & 12 deletions src/main/java/io/numaproj/numaflow/function/FunctionService.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.AllDeadLetters;
import akka.actor.DeadLetter;
import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
import io.grpc.stub.StreamObserver;
Expand All @@ -25,6 +24,7 @@

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;

import static io.numaproj.numaflow.function.Function.EOF;
Expand Down Expand Up @@ -67,9 +67,6 @@ public void mapFn(
return;
}

// get key from gPRC metadata
String key = Function.DATUM_CONTEXT_KEY.get();

// get Datum from request
HandlerDatum handlerDatum = new HandlerDatum(
request.getValue().toByteArray(),
Expand All @@ -82,7 +79,7 @@ public void mapFn(
);

// process Datum
Message[] messages = mapHandler.processMessage(key, handlerDatum);
Message[] messages = mapHandler.processMessage(request.getKeysList().toArray(new String[0]), handlerDatum);

// set response
responseObserver.onNext(buildDatumListResponse(messages));
Expand All @@ -101,9 +98,6 @@ public void mapTFn(
return;
}

// get key from gPRC metadata
String key = Function.DATUM_CONTEXT_KEY.get();

// get Datum from request
HandlerDatum handlerDatum = new HandlerDatum(
request.getValue().toByteArray(),
Expand All @@ -116,7 +110,7 @@ public void mapTFn(
);

// process Datum
MessageT[] messageTs = mapTHandler.processMessage(key, handlerDatum);
MessageT[] messageTs = mapTHandler.processMessage(request.getKeysList().toArray(new String[0]), handlerDatum);

// set response
responseObserver.onNext(buildDatumListResponse(messageTs));
Expand Down Expand Up @@ -159,7 +153,7 @@ public StreamObserver<Udfunction.Datum> reduceFn(final StreamObserver<Udfunction
handleFailure(failureFuture);
/*
create a supervisor actor which assign the tasks to child actors.
we create a child actor for every key in a window.
we create a child actor for every unique set of keys in a window
*/
ActorRef supervisorActor = actorSystem
.actorOf(ReduceSupervisorActor.props(reducerFactory, md, shutdownActorRef, responseObserver));
Expand Down Expand Up @@ -204,7 +198,7 @@ private Udfunction.DatumList buildDatumListResponse(Message[] messages) {
Udfunction.DatumList.Builder datumListBuilder = Udfunction.DatumList.newBuilder();
Arrays.stream(messages).forEach(message -> {
datumListBuilder.addElements(Udfunction.Datum.newBuilder()
.setKey(message.getKey())
.addAllKeys(List.of(message.getKeys()))
.setValue(ByteString.copyFrom(message.getValue()))
.build());
});
Expand All @@ -220,7 +214,7 @@ private Udfunction.DatumList buildDatumListResponse(MessageT[] messageTs) {
.setSeconds(messageT.getEventTime().getEpochSecond())
.setNanos(messageT.getEventTime().getNano()))
)
.setKey(messageT.getKey())
.addAllKeys(List.of(messageT.getKeys()))
.setValue(ByteString.copyFrom(messageT.getValue()))
.build());
});
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/io/numaproj/numaflow/function/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ public class Message {
public static final String ALL = "U+005C__ALL__";
public static final String DROP = "U+005C__DROP__";

private final String key;
private final String[] keys;
private final byte[] value;

// creates a Message to be dropped
public static Message toDrop() {
return new Message(DROP, new byte[0]);
return new Message(new String[]{DROP}, new byte[0]);
}

// creates a Message that will forward to all
public static Message toAll(byte[] value) {
return new Message(ALL, value);
return new Message(new String[]{ALL}, value);
}

// creates a Message that will forward to specified "to"
public static Message to(String to, byte[] value) {
public static Message to(String[] to, byte[] value) {
return new Message(to, value);
}
}
8 changes: 4 additions & 4 deletions src/main/java/io/numaproj/numaflow/function/MessageT.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
public class MessageT {

private Instant eventTime;
private final String key;
private final String[] keys;
private final byte[] value;

// creates a MessageT to be dropped
public static MessageT toDrop() {
return new MessageT(Instant.MIN, DROP, new byte[0]);
return new MessageT(Instant.MIN, new String[]{DROP}, new byte[0]);
}

// creates a MessageT that will forward to all
public static MessageT toAll(Instant eventTime, byte[] value) {
return new MessageT(eventTime, ALL, value);
return new MessageT(eventTime, new String[]{ALL}, value);
}

// creates a MessageT that will forward to specified "to"
public static MessageT to(Instant eventTime, String to, byte[] value) {
public static MessageT to(Instant eventTime, String[] to, byte[] value) {
return new MessageT(eventTime, to, value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ public abstract class MapHandler {
processMessage will be invoked for each input message.
this method will be used for processing messages
*/
public abstract Message[] processMessage(String key, Datum datum);
public abstract Message[] processMessage(String[] key, Datum datum);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ public abstract class MapTHandler {
this method will be used for processing and transforming
the messages
*/
public abstract MessageT[] processMessage(String key, Datum datum);
public abstract MessageT[] processMessage(String[] keys, Datum datum);

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@
@Getter
@AllArgsConstructor
public class ActorResponse {
String key;
String[] keys;
Udfunction.DatumList datumList;
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import lombok.extern.slf4j.Slf4j;

import java.util.Arrays;
import java.util.List;

/**
* Reduce actor invokes the user defined code and returns the result.
Expand All @@ -21,12 +22,12 @@
@AllArgsConstructor
public class ReduceActor extends AbstractActor {

private String key;
private String[] keys;
private Metadata md;
private ReduceHandler groupBy;

public static Props props(String key, Metadata md, ReduceHandler groupBy) {
return Props.create(ReduceActor.class, key, md, groupBy);
public static Props props(String[] keys, Metadata md, ReduceHandler groupBy) {
return Props.create(ReduceActor.class, keys, md, groupBy);
}

@Override
Expand All @@ -39,11 +40,11 @@ public Receive createReceive() {
}

private void invokeHandler(HandlerDatum handlerDatum) {
this.groupBy.addMessage(key, handlerDatum, md);
this.groupBy.addMessage(keys, handlerDatum, md);
}

private void getResult(String eof) {
Message[] resultMessages = this.groupBy.getOutput(key, md);
Message[] resultMessages = this.groupBy.getOutput(keys, md);
// send the result back to sender(parent actor)
getSender().tell(buildDatumListResponse(resultMessages), getSelf());
}
Expand All @@ -52,11 +53,11 @@ private ActorResponse buildDatumListResponse(Message[] messages) {
Udfunction.DatumList.Builder datumListBuilder = Udfunction.DatumList.newBuilder();
Arrays.stream(messages).forEach(message -> {
datumListBuilder.addElements(Udfunction.Datum.newBuilder()
.setKey(message.getKey())
.addAllKeys(List.of(message.getKeys()))
.setValue(ByteString.copyFrom(message.getValue()))
.build());
});
return new ActorResponse(this.key, datumListBuilder.build());
return new ActorResponse(this.keys, datumListBuilder.build());
}

}
Expand Down
Loading

0 comments on commit fd9b5ce

Please sign in to comment.