@@ -1686,6 +1686,168 @@ fn test_update_data_filters_reapplied() {
1686
1686
] ) ;
1687
1687
}
1688
1688
1689
+ fn create_set_membership_filter (
1690
+ column_schema : amalthea:: comm:: data_explorer_comm:: ColumnSchema ,
1691
+ values : Vec < String > ,
1692
+ inclusive : bool ,
1693
+ filter_id : & str ,
1694
+ ) -> RowFilter {
1695
+ RowFilter {
1696
+ column_schema,
1697
+ filter_type : RowFilterType :: SetMembership ,
1698
+ filter_id : filter_id. to_string ( ) ,
1699
+ condition : RowFilterCondition :: And ,
1700
+ is_valid : None ,
1701
+ params : Some ( RowFilterParams :: SetMembership (
1702
+ amalthea:: comm:: data_explorer_comm:: FilterSetMembership { values, inclusive } ,
1703
+ ) ) ,
1704
+ error_message : None ,
1705
+ }
1706
+ }
1707
+
1708
+ /// Helper function to test set membership filters for both inclusive and exclusive modes
1709
+ fn test_set_membership_helper (
1710
+ data_frame_name : & str ,
1711
+ filter_values : Vec < & str > ,
1712
+ expected_inclusive_count : usize ,
1713
+ expected_exclusive_count : usize ,
1714
+ ) {
1715
+ let socket = open_data_explorer ( String :: from ( data_frame_name) ) ;
1716
+
1717
+ let req = DataExplorerBackendRequest :: GetSchema ( GetSchemaParams {
1718
+ column_indices : vec ! [ 0 ] ,
1719
+ } ) ;
1720
+
1721
+ let schema_reply = socket_rpc ( & socket, req) ;
1722
+ let schema = match schema_reply {
1723
+ DataExplorerBackendReply :: GetSchemaReply ( schema) => schema,
1724
+ _ => panic ! ( "Unexpected reply: {:?}" , schema_reply) ,
1725
+ } ;
1726
+
1727
+ let string_values: Vec < String > = filter_values. iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ;
1728
+
1729
+ let inclusive_filter = create_set_membership_filter (
1730
+ schema. columns [ 0 ] . clone ( ) ,
1731
+ string_values. clone ( ) ,
1732
+ true , // inclusive
1733
+ "inclusive-filter-id" ,
1734
+ ) ;
1735
+
1736
+ let req = DataExplorerBackendRequest :: SetRowFilters ( SetRowFiltersParams {
1737
+ filters : vec ! [ inclusive_filter] ,
1738
+ } ) ;
1739
+
1740
+ assert_match ! ( socket_rpc( & socket, req) ,
1741
+ DataExplorerBackendReply :: SetRowFiltersReply (
1742
+ FilterResult { selected_num_rows: num_rows, had_errors: Some ( false ) }
1743
+ ) => {
1744
+ assert_eq!( num_rows as usize , expected_inclusive_count,
1745
+ "Inclusive filter for {} with values {:?} returned {} rows instead of expected {}" ,
1746
+ data_frame_name, filter_values, num_rows, expected_inclusive_count) ;
1747
+ } ) ;
1748
+
1749
+ let exclusive_filter = create_set_membership_filter (
1750
+ schema. columns [ 0 ] . clone ( ) ,
1751
+ string_values,
1752
+ false , // exclusive
1753
+ "exclusive-filter-id" ,
1754
+ ) ;
1755
+
1756
+ let req = DataExplorerBackendRequest :: SetRowFilters ( SetRowFiltersParams {
1757
+ filters : vec ! [ exclusive_filter] ,
1758
+ } ) ;
1759
+
1760
+ assert_match ! ( socket_rpc( & socket, req) ,
1761
+ DataExplorerBackendReply :: SetRowFiltersReply (
1762
+ FilterResult { selected_num_rows: num_rows, had_errors: Some ( false ) }
1763
+ ) => {
1764
+ assert_eq!( num_rows as usize , expected_exclusive_count,
1765
+ "Exclusive filter for {} with values {:?} returned {} rows instead of expected {}" ,
1766
+ data_frame_name, filter_values, num_rows, expected_exclusive_count) ;
1767
+ } ) ;
1768
+ }
1769
+
1770
/// Exercises set-membership row filters on character and numeric columns,
/// with and without `NA` entries, in both inclusive and exclusive modes.
#[test]
fn test_set_membership_filter() {
    /// Evaluates `code` in the R global environment on the R task thread,
    /// panicking if evaluation fails.
    fn eval_in_r(code: &'static str) {
        r_task(move || {
            harp::parse_eval_global(code).unwrap();
        });
    }

    let _lock = r_test_lock();

    // Character column, no missing values: 7 rows, 3 of them in the set.
    eval_in_r(
        r#"categories <- data.frame(
            fruit = c(
                "apple",
                "banana",
                "orange",
                "grape",
                "kiwi",
                "pear",
                "strawberry"
            )
        )"#,
    );
    test_set_membership_helper("categories", vec!["apple", "banana", "pear"], 3, 4);

    // Numeric column: the filter values are passed as strings and are
    // expected to be coerced before comparison.
    eval_in_r(
        r#"numeric_data <- data.frame(
            values = c(1, 2, 3, 4, 5, 6, 7)
        )"#,
    );
    test_set_membership_helper("numeric_data", vec!["1", "2", "3"], 3, 4);

    // Character column containing NA entries. The filter set holds only
    // regular values, so the NA rows never match inclusively and are counted
    // among the rows kept by the exclusive filter (2 + 5 = 7 rows).
    eval_in_r(
        r#"categories_with_na <- data.frame(
            fruits = c(
                "apple",
                "banana",
                NA_character_,
                "orange",
                "grape",
                NA_character_,
                "pear"
            )
        )"#,
    );
    test_set_membership_helper("categories_with_na", vec!["apple", "banana"], 2, 5);

    // Numeric column containing NA entries, checked against a two-value set,
    // an empty set, and a single-value set.
    eval_in_r(
        r#"numeric_with_na <- data.frame(
            values = c(1, 2, NA_real_, 3, NA_real_, 4, 5)
        )"#,
    );
    test_set_membership_helper("numeric_with_na", vec!["1", "2"], 2, 5);
    test_set_membership_helper("numeric_with_na", Vec::new(), 0, 7);
    test_set_membership_helper("numeric_with_na", vec!["3"], 1, 6);
}
1850
+
1689
1851
#[ test]
1690
1852
fn test_get_data_values_by_indices ( ) {
1691
1853
let _lock = r_test_lock ( ) ;
0 commit comments