Skip to content

Commit

Permalink
Fix dataflow test
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasvanmol committed Jan 21, 2025
1 parent 438f1da commit 610104c
Showing 1 changed file with 38 additions and 30 deletions.
68 changes: 38 additions & 30 deletions src/cascade/dataflow/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ def buy_item(self, item: 'DummyItem') -> bool:
self.balance -= item_price
return self.balance >= 0

def buy_item_0_compiled(variable_map: dict[str, Any], state: DummyUser, key_stack: list[str]) -> dict[str, Any]:
key_stack.append(variable_map["item_key"])
def buy_item_0_compiled(variable_map: dict[str, Any], state: DummyUser):
return

def buy_item_1_compiled(variable_map: dict[str, Any], state: DummyUser, key_stack: list[str]) -> dict[str, Any]:
key_stack.pop()
def buy_item_1_compiled(variable_map: dict[str, Any], state: DummyUser):
state.balance -= variable_map["item_price"]
return {"user_postive_balance": state.balance >= 0}
return state.balance >= 0

class DummyItem:
def __init__(self, key: str, price: int):
Expand All @@ -29,10 +27,8 @@ def __init__(self, key: str, price: int):
def get_price(self) -> int:
return self.price

def get_price_compiled(variable_map: dict[str, Any], state: DummyItem, key_stack: list[str]) -> dict[str, Any]:
key_stack.pop() # final function
variable_map["item_price"] = state.price
# return {"item_price": state.price}
def get_price_compiled(variable_map: dict[str, Any], state: DummyItem):
return state.price

################## TESTS #######################

Expand All @@ -46,79 +42,91 @@ def get_price_compiled(variable_map: dict[str, Any], state: DummyItem, key_stack

def test_simple_df_propogation():
df = DataFlow("user.buy_item")
n1 = OpNode(DummyUser, InvokeMethod("buy_item_0_compiled"))
n2 = OpNode(DummyItem, InvokeMethod("get_price"))
n3 = OpNode(DummyUser, InvokeMethod("buy_item_1"))
n1 = OpNode(DummyUser, InvokeMethod("buy_item_0_compiled"), read_key_from="user_key")
n2 = OpNode(DummyItem, InvokeMethod("get_price"), read_key_from="item_key", assign_result_to="item_price")
n3 = OpNode(DummyUser, InvokeMethod("buy_item_1"), read_key_from="user_key")
df.add_edge(Edge(n1, n2))
df.add_edge(Edge(n2, n3))

user.buy_item(item)
event = Event(n1, ["user"], {"item_key":"fork"}, df)
event = Event(n1, {"user_key": "user", "item_key":"fork"}, df)

# Manually propogate
item_key = buy_item_0_compiled(event.variable_map, state=user, key_stack=event.key_stack)
next_event = event.propogate(event.key_stack, item_key)
item_key = buy_item_0_compiled(event.variable_map, state=user)
next_event = event.propogate(event, item_key)

assert isinstance(next_event, list)
assert len(next_event) == 1
assert next_event[0].target == n2
assert next_event[0].key_stack == ["user", "fork"]
event = next_event[0]

item_price = get_price_compiled(event.variable_map, state=item, key_stack=event.key_stack)
next_event = event.propogate(event.key_stack, item_price)
# manually add the price to the variable map
item_price = get_price_compiled(event.variable_map, state=item)
assert n2.assign_result_to
event.variable_map[n2.assign_result_to] = item_price

next_event = event.propogate(item_price)

assert isinstance(next_event, list)
assert len(next_event) == 1
assert next_event[0].target == n3
event = next_event[0]

positive_balance = buy_item_1_compiled(event.variable_map, state=user, key_stack=event.key_stack)
next_event = event.propogate(event.key_stack, None)
positive_balance = buy_item_1_compiled(event.variable_map, state=user)
next_event = event.propogate(None)
assert isinstance(next_event, EventResult)


def test_merge_df_propogation():
df = DataFlow("user.buy_2_items")
n0 = OpNode(DummyUser, InvokeMethod("buy_2_items_0"))
n0 = OpNode(DummyUser, InvokeMethod("buy_2_items_0"), read_key_from="user_key")
n3 = CollectNode(assign_result_to="item_prices", read_results_from="item_price")
n1 = OpNode(
DummyItem,
InvokeMethod("get_price"),
assign_result_to="item_price",
collect_target=CollectTarget(n3, 2, 0)
collect_target=CollectTarget(n3, 2, 0),
read_key_from="item_1_key"
)
n2 = OpNode(
DummyItem,
InvokeMethod("get_price"),
assign_result_to="item_price",
collect_target=CollectTarget(n3, 2, 1)
collect_target=CollectTarget(n3, 2, 1),
read_key_from="item_2_key"
)
n4 = OpNode(DummyUser, InvokeMethod("buy_2_items_1"))
n4 = OpNode(DummyUser, InvokeMethod("buy_2_items_1"), read_key_from="user_key")
df.add_edge(Edge(n0, n1))
df.add_edge(Edge(n0, n2))
df.add_edge(Edge(n1, n3))
df.add_edge(Edge(n2, n3))
df.add_edge(Edge(n3, n4))

# User with key "foo" buys items with keys "fork" and "spoon"
event = Event(n0, ["foo"], {"item_1_key": "fork", "item_2_key": "spoon"}, df)
event = Event(n0, {"user_key": "foo", "item_1_key": "fork", "item_2_key": "spoon"}, df)

# Propogate the event (without actually doing any calculation)
# Normally, the key_stack should've been updated by the runtime here:
key_stack = ["foo", ["fork", "spoon"]]
next_event = event.propogate(key_stack, None)
next_event = event.propogate(None)

assert isinstance(next_event, list)
assert len(next_event) == 2
assert next_event[0].target == n1
assert next_event[1].target == n2

event1, event2 = next_event
next_event = event1.propogate(event1.key_stack, None)
next_event = event1.propogate(None)

assert isinstance(next_event, list)
assert len(next_event) == 1
assert next_event[0].target == n3

next_event = event2.propogate(event2.key_stack, None)
next_event = event2.propogate(None)

assert isinstance(next_event, list)
assert len(next_event) == 1
assert next_event[0].target == n3

final_event = next_event[0].propogate(next_event[0].key_stack, None)
final_event = next_event[0].propogate(None)
assert isinstance(final_event, list)
assert final_event[0].target == n4

0 comments on commit 610104c

Please sign in to comment.