@@ -2712,3 +2712,88 @@ def revoke_partition(store, partition):
2712
2712
consumer_group ,
2713
2713
state_dir ,
2714
2714
)
2715
+
2716
+ @pytest .mark .parametrize ("store_type" , SUPPORTED_STORES , indirect = True )
2717
+ def test_concatenated_sdfs_stateful (
2718
+ self ,
2719
+ app_factory ,
2720
+ executor ,
2721
+ state_manager_factory ,
2722
+ tmp_path ,
2723
+ ):
2724
+ def on_message_processed (* _ ):
2725
+ # Set the callback to track total messages processed
2726
+ # The callback is not triggered if processing fails
2727
+ nonlocal processed_count
2728
+
2729
+ processed_count += 1
2730
+ # Stop processing after consuming all the messages
2731
+ if processed_count == total_messages :
2732
+ done .set_result (True )
2733
+
2734
+ processed_count = 0
2735
+
2736
+ consumer_group = str (uuid .uuid4 ())
2737
+ state_dir = (tmp_path / "state" ).absolute ()
2738
+ partition_num = 0
2739
+ app = app_factory (
2740
+ commit_interval = 0 ,
2741
+ consumer_group = consumer_group ,
2742
+ auto_offset_reset = "earliest" ,
2743
+ state_dir = state_dir ,
2744
+ on_message_processed = on_message_processed ,
2745
+ use_changelog_topics = True ,
2746
+ )
2747
+ input_topic_a = app .topic (
2748
+ str (uuid .uuid4 ()), value_deserializer = JSONDeserializer ()
2749
+ )
2750
+ input_topic_b = app .topic (
2751
+ str (uuid .uuid4 ()), value_deserializer = JSONDeserializer ()
2752
+ )
2753
+ input_topics = [input_topic_a , input_topic_b ]
2754
+ messages_per_topic = 3
2755
+ total_messages = messages_per_topic * len (input_topics )
2756
+
2757
+ # Define a function that counts incoming Rows using state
2758
+ def count (_ , state : State ):
2759
+ total = state .get ("total" , 0 )
2760
+ total += 1
2761
+ state .set ("total" , total )
2762
+
2763
+ sdf_a = app .dataframe (input_topic_a )
2764
+ sdf_b = app .dataframe (input_topic_b )
2765
+
2766
+ sdf_concat = sdf_a .concat (sdf_b )
2767
+ sdf_concat .update (count , stateful = True )
2768
+
2769
+ # Produce messages to the topic and flush
2770
+ message_key = b"key"
2771
+ data = {
2772
+ "key" : message_key ,
2773
+ "value" : dumps ({"key" : "value" }),
2774
+ "partition" : partition_num ,
2775
+ }
2776
+ with app .get_producer () as producer :
2777
+ for topic in input_topics :
2778
+ for _ in range (messages_per_topic ):
2779
+ producer .produce (topic .name , ** data )
2780
+
2781
+ done = Future ()
2782
+ # Stop app when the future is resolved
2783
+ executor .submit (_stop_app_on_future , app , done , 10.0 )
2784
+
2785
+ stores = {}
2786
+
2787
+ def revoke_partition (store_ , partition ):
2788
+ stores [store_ .stream_id ] = store_
2789
+
2790
+ with patch ("quixstreams.state.base.Store.revoke_partition" , revoke_partition ):
2791
+ app .run ()
2792
+
2793
+ assert processed_count == total_messages
2794
+
2795
+ store = stores [sdf_concat .stream_id ]
2796
+ partition = store .partitions [partition_num ]
2797
+ with partition .begin () as tx :
2798
+ assert tx .get ("total" , prefix = message_key ) == total_messages
2799
+ store .revoke_partition (partition_num )
0 commit comments