Add FlinkRuntime test
lucasvanmol committed Dec 18, 2024
1 parent c23a33b commit d398b17
Showing 2 changed files with 121 additions and 123 deletions.
102 changes: 87 additions & 15 deletions src/cascade/runtime/flink_runtime.py
@@ -1,3 +1,4 @@
from dataclasses import dataclass
import os
from typing import Optional, Type, Union
from pyflink.common.typeinfo import Types, get_gateway
@@ -8,7 +9,7 @@
from pyflink.datastream.connectors.kafka import KafkaOffsetsInitializer, KafkaRecordSerializationSchema, KafkaSource, KafkaSink
from pyflink.datastream import ProcessFunction, StreamExecutionEnvironment
import pickle
from cascade.dataflow.dataflow import Event, EventResult, Filter, InitClass, InvokeMethod, MergeNode, OpNode, SelectAllNode
from cascade.dataflow.dataflow import Event, EventResult, Filter, InitClass, InvokeMethod, MergeNode, Node, OpNode, SelectAllNode
from cascade.dataflow.operator import StatefulOperator
from confluent_kafka import Producer
import logging
@@ -20,6 +21,18 @@
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

@dataclass
class FlinkRegisterKeyNode(Node):
"""A node that will register a key with the SelectAll operator.
This node is specific to Flink, and will be automatically generated.
It should not be used in a `DataFlow`.
@private
"""
key: str
cls: Type

class FlinkOperator(KeyedProcessFunction):
"""Wraps an `cascade.dataflow.datflow.StatefulOperator` in a KeyedProcessFunction so that it can run in Flink.
"""
@@ -41,6 +54,17 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context):
# otherwise, order of variable_map matters for variable assignment
result = self.operator.handle_init_class(*event.variable_map.values())

# Register the created key in FlinkSelectAllOperator
register_key_event = Event(
FlinkRegisterKeyNode(key_stack[-1], self.operator._cls), # problem: this id goes up even when we don't really want it to
[],
{},
None,
_id = event._id
)
logger.debug(f"FlinkOperator {event.target.cls.__name__}[{ctx.get_current_key()}]: Registering key: {register_key_event}")
yield register_key_event

# Pop this key from the key stack so that we exit
key_stack.pop()
self.state.update(pickle.dumps(result))
@@ -69,20 +93,36 @@ def process_element(self, event: Event, ctx: KeyedProcessFunction.Context):
logger.debug(f"FlinkOperator {event.target.cls.__name__}[{ctx.get_current_key()}]: Propogated {len(new_events)} new Events")
yield from new_events

class SelectAllOperator(ProcessFunction):
class FlinkSelectAllOperator(KeyedProcessFunction):
"""A process function that yields all keys of a certain class"""
def __init__(self, ids: dict[Type, list[str]]):
self.ids = ids
def __init__(self):
self.state: ValueState = None # type: ignore (expect state to be initialised on .open())

def open(self, runtime_context: RuntimeContext):
descriptor = ValueStateDescriptor("entity-keys", Types.PICKLED_BYTE_ARRAY()) #,Types.OBJECT_ARRAY(Types.STRING()))
self.state: ValueState = runtime_context.get_state(descriptor)

def process_element(self, event: Event, ctx: 'ProcessFunction.Context'):
assert isinstance(event.target, SelectAllNode)
logger.debug(f"SelectAllOperator {event.target.cls.__name__}: Processing: {event}")
state: list[str] = self.state.value()
if state is None:
state = []

if isinstance(event.target, FlinkRegisterKeyNode):
logger.debug(f"SelectAllOperator [{event.target.cls.__name__}]: Processing: {event}")

state.append(event.target.key)
self.state.update(state)

# yield all the hotel_ids we know about
event.key_stack.append(self.ids[event.target.cls])
new_events = event.propogate(event.key_stack, None)
logger.debug(f"SelectAll [{event.target.cls}]: Propogated {len(new_events)} events")
yield from new_events
elif isinstance(event.target, SelectAllNode):
logger.debug(f"SelectAllOperator [{event.target.cls.__name__}]: Processing: {event}")

# Yield all the keys we know about
event.key_stack.append(state)
new_events = event.propogate(event.key_stack, None)
logger.debug(f"SelectAllOperator [{event.target.cls.__name__}]: Propogated {len(new_events)} events")
yield from new_events
else:
raise Exception(f"Unexpected target for SelectAllOperator: {event.target}")

class FlinkMergeOperator(KeyedProcessFunction):
"""Flink implementation of a merge operator."""
@@ -235,7 +275,7 @@ def init(self, kafka_broker="localhost:9092", bundle_time=1, bundle_size=5):
)
"""Kafka sink that will be ingested again by the Flink runtime."""

stream = (
event_stream = (
self.env.from_source(
kafka_source,
WatermarkStrategy.no_watermarks(),
@@ -245,11 +285,29 @@ def init(self, kafka_broker="localhost:9092", bundle_time=1, bundle_size=5):
# .filter(lambda e: isinstance(e, Event)) # Enforced by `add_operator` type safety
)

self.stateful_op_stream = stream.filter(lambda e: isinstance(e.target, OpNode))
# Events with a `SelectAllNode` will first be processed by the select
# all operator, which will send out multiple other Events that can
# then be processed by operators in the same stream.
select_all_stream = (
event_stream.filter(lambda e:
isinstance(e.target, SelectAllNode) or isinstance(e.target, FlinkRegisterKeyNode))
.key_by(lambda e: e.target.cls)
.process(FlinkSelectAllOperator())
)
"""Stream that ingests events with an `SelectAllNode` or `FlinkRegisterKeyNode`"""
not_select_all_stream = (
event_stream.filter(lambda e:
not (isinstance(e.target, SelectAllNode) or isinstance(e.target, FlinkRegisterKeyNode)))
)

event_stream = select_all_stream.union(not_select_all_stream)


self.stateful_op_stream = event_stream.filter(lambda e: isinstance(e.target, OpNode))
"""Stream that ingests events with an `cascade.dataflow.dataflow.OpNode` target"""

self.merge_op_stream = (
stream.filter(lambda e: isinstance(e.target, MergeNode))
event_stream.filter(lambda e: isinstance(e.target, MergeNode))
.key_by(lambda e: e._id) # might not work in the future if we have multiple merges in one dataflow?
.process(FlinkMergeOperator())
)
@@ -289,7 +347,21 @@ def run(self, run_async=False, collect=False) -> Union[CloseableIterator, None]:
assert self.env is not None, "FlinkRuntime must first be initialised with `init()`."

# Combine all the operator streams
ds = self.merge_op_stream.union(*self.stateful_op_streams)
operator_streams = self.merge_op_stream.union(*self.stateful_op_streams)

# Add filtering for nodes with a `Filter` target
full_stream_filtered = (
operator_streams
.filter(lambda e: isinstance(e, Event) and isinstance(e.target, Filter))
.filter(lambda e: e.target.filter_fn())
)
full_stream_unfiltered = (
operator_streams
.filter(lambda e: not (isinstance(e, Event) and isinstance(e.target, Filter)))
)
ds = full_stream_filtered.union(full_stream_unfiltered)

# Output the stream
if collect:
ds_external = ds.filter(lambda e: isinstance(e, EventResult)).execute_and_collect()
else:
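The core of this change is the keyed bookkeeping in FlinkSelectAllOperator: creating an entity now also emits a FlinkRegisterKeyNode event, and the operator appends that key to per-class state so that a later SelectAllNode event can fan out to every key it has seen. Below is a minimal plain-Python sketch of that bookkeeping pattern, outside Flink and Kafka; the names here are illustrative and not part of the codebase.

from collections import defaultdict
from typing import Type


class KeyRegistry:
    """Simplified stand-in for the keyed state held by FlinkSelectAllOperator."""

    def __init__(self) -> None:
        # One list of keys per entity class, like the per-key ValueState in Flink.
        self._keys_by_class: dict[Type, list[str]] = defaultdict(list)

    def register(self, cls: Type, key: str) -> None:
        # Analogous to handling a FlinkRegisterKeyNode event.
        self._keys_by_class[cls].append(key)

    def select_all(self, cls: Type) -> list[str]:
        # Analogous to handling a SelectAllNode event: return every key
        # registered for this class so the event can fan out to each one.
        return list(self._keys_by_class[cls])


class Hotel:  # illustrative placeholder for an entity class
    pass


registry = KeyRegistry()
registry.register(Hotel, "h_0")
registry.register(Hotel, "h_1")
assert registry.select_all(Hotel) == ["h_0", "h_1"]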
142 changes: 34 additions & 108 deletions src/cascade/runtime/test_global_state.py
@@ -9,13 +9,14 @@
from pyflink.common import Configuration
from pyflink.datastream import ProcessFunction, StreamExecutionEnvironment
from pyflink.datastream.connectors.kafka import FlinkKafkaConsumer, FlinkKafkaProducer
from pyflink.datastream.data_stream import CloseableIterator

from cascade.dataflow.dataflow import DataFlow, Edge, Event, EventResult, Filter, InitClass, OpNode, SelectAllNode
from cascade.dataflow.operator import StatefulOperator
from cascade.runtime.flink_runtime import ByteSerializer, FlinkOperator, SelectAllOperator
from cascade.runtime.flink_runtime import ByteSerializer, FlinkOperator, FlinkRegisterKeyNode, FlinkRuntime, FlinkSelectAllOperator
from confluent_kafka import Producer
import os
import pickle # problems with pickling functions (e.g. lambdas)? use cloudpickle
import cloudpickle # problems with pickling functions (e.g. lambdas)? use cloudpickle
import logging
import time

@@ -38,8 +39,8 @@ def add_kafka_source(env: StreamExecutionEnvironment, topics, broker="localhost:
kafka_consumer = FlinkKafkaConsumer(topics, deserialization_schema, properties)
return env.add_source(kafka_consumer)

def dbg(e):
# print(e)
def dbg(e, msg=""):
print(msg + str(e))
return e

@dataclass
@@ -58,6 +59,7 @@ def distance(self, loc: Geo) -> float:
def __repr__(self) -> str:
return f"Hotel({self.name}, {self.loc})"


def distance_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any:
loc = variable_map["loc"]
return math.sqrt((state.loc.x - loc.x) ** 2 + (state.loc.y - loc.y) ** 2)
@@ -68,105 +70,36 @@ def distance_compiled(variable_map: dict[str, Any], state: Hotel, key_stack: list[str]) -> Any:
def get_nearby_predicate_compiled(variable_map: dict[str, Any], state: Hotel) -> bool:
return state.distance(variable_map["loc"]) < variable_map["dist"]

hotel_op = StatefulOperator(Hotel, {"distance": distance_compiled}, {})

def test_nearby_hotels():
runtime = FlinkRuntime("test_nearby_hotels")
runtime.init()
runtime.add_operator(FlinkOperator(hotel_op))


def test_yeeter():

hotel_op = StatefulOperator(Hotel, {"distance": distance_compiled}, {})
hotel_op = FlinkOperator(hotel_op)

# Create Hotels
hotels = []
init_hotel = OpNode(Hotel, InitClass())
random.seed(42)
for i in range(100):
for i in range(50):
coord_x = random.randint(-10, 10)
coord_y = random.randint(-10, 10)
hotels.append(Hotel(f"h_{i}", Geo(coord_x, coord_y)))

def get_nearby(loc: Geo, dist: int) -> list[Hotel]:
return [hotel for hotel in hotels if hotel.distance(loc) < dist]

# Configure the local Flink instance with the ui at http://localhost:8081
config = Configuration() # type: ignore
config.set_string("rest.port", "8081")
env = StreamExecutionEnvironment.get_execution_environment(config)

# Add the kafka producer and consumers
topic = "input-topic"
broker = "localhost:9092"
ds = add_kafka_source(env, topic)
producer = Producer({"bootstrap.servers": 'localhost:9092'})
deserialization_schema = ByteSerializer()
properties: dict = {
"bootstrap.servers": broker,
"group.id": "test_group_1",
}
kafka_external_sink = FlinkKafkaProducer("out-topic", deserialization_schema, properties)
kafka_internal_sink = FlinkKafkaProducer(topic, deserialization_schema, properties)

# Create the datastream that will handle
# - simple (single node) dataflows and,
# - init classes
stream = (
ds.map(lambda x: pickle.loads(x))
)


select_all_op = SelectAllOperator({Hotel: [hotel.name for hotel in hotels]})

select_all_stream = (
stream.filter(lambda e: isinstance(e.target, SelectAllNode))
.process(select_all_op) # yield all the hotel_ids
)

op_stream = (
stream.union(select_all_stream).filter(lambda e: isinstance(e.target, OpNode))
)


hotel_stream = (
op_stream
.filter(lambda e: e.target.cls == Hotel)
.key_by(lambda e: e.key_stack[-1])
.process(hotel_op)
)

full_stream = hotel_stream #.union...

full_stream_filtered = (
full_stream
.filter(lambda e: isinstance(e, Event))
.filter(lambda e: isinstance(e.target, Filter))
.filter(lambda e: e.target.filter_fn())
)

full_stream_unfiltered = (
full_stream
.filter(lambda e: not isinstance(e, Event) or not isinstance(e.target, Filter))
)

# have to remove items from full_stream as well??
ds = full_stream_unfiltered.union(full_stream_filtered)

# INIT HOTELS
init_hotel = OpNode(Hotel, InitClass())
for hotel in hotels:
hotel = Hotel(f"h_{i}", Geo(coord_x, coord_y))
event = Event(init_hotel, [hotel.name], {"name": hotel.name, "loc": hotel.loc}, None)
producer.produce(
topic,
value=pickle.dumps(event),
)
runtime.send(event)
hotels.append(hotel)



ds_external = ds.map(lambda e: dbg(e)).filter(lambda e: isinstance(e, EventResult)).filter(lambda e: e.event_id > 99).print() #.add_sink(kafka_external_sink)
ds_internal = ds.map(lambda e: dbg(e)).filter(lambda e: isinstance(e, Event)).map(lambda e: pickle.dumps(e)).add_sink(kafka_internal_sink)
producer.flush()
collected_iterator: CloseableIterator = runtime.run(run_async=True, collect=True)
records = []
def wait_for_event_id(id: int) -> EventResult:
for record in collected_iterator:
records.append(record)
print(f"Collected record: {record}")
if record.event_id == id:
return record

env.execute_async()

print("sleepin")
time.sleep(2)
# Wait for hotels to be created
wait_for_event_id(event._id)

# GET NEARBY
# dataflow for getting all hotels within region
@@ -178,22 +111,15 @@ def get_nearby(loc: Geo, dist: int) -> list[Hotel]:
dist = 5
loc = Geo(0, 0)
event = Event(n0, [], {"loc": loc, "dist": dist}, df)
producer.produce(
topic,
value=pickle.dumps(event),
)

runtime.send(event, flush=True)

nearby = []
for hotel in hotels:
if hotel.distance(loc) < dist:
nearby.append(hotel.name)
print(nearby)
# OK, that's pretty good. But now we need to solve the problem of merging
# an arbitrary number of nodes. Naturally, we want to merge as late
# as possible; ideally we want to process results in a streaming
# fashion

# I want another example that does something after filtering,
# for example, buying all items priced under 10
input()

sol = wait_for_event_id(event._id)
print(nearby)
print(sol)
print(records)
assert sol.result in nearby
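
For reference, here is the end-to-end pattern the new test exercises, condensed into a short sketch. It reuses Hotel, Geo, and hotel_op as defined in the test above and uses only the FlinkRuntime calls that appear in this diff (init, add_operator, send, run); the topic name and key are illustrative.

from cascade.dataflow.dataflow import Event, InitClass, OpNode
from cascade.runtime.flink_runtime import FlinkOperator, FlinkRuntime

runtime = FlinkRuntime("test_nearby_hotels_sketch")   # illustrative topic name
runtime.init()                                        # sets up the Flink environment plus Kafka source/sink
runtime.add_operator(FlinkOperator(hotel_op))         # register the Hotel stateful operator

# Create one entity by sending an InitClass event, then flush it to Kafka.
init_event = Event(OpNode(Hotel, InitClass()), ["h_0"], {"name": "h_0", "loc": Geo(0, 0)}, None)
runtime.send(init_event, flush=True)

# Run asynchronously and collect EventResults until ours comes back.
results = runtime.run(run_async=True, collect=True)
for result in results:
    if result.event_id == init_event._id:
        break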
