Skip to content

Commit 6f22311

Browse files
authored
Implement SetMembership row filter and unit tests for data explorer (#793)
Selecting or excluding NA values is not supported yet; adding that will require changes to the comm protocol.
1 parent 1d925d0 commit 6f22311

File tree

3 files changed

+183
-0
lines changed

3 files changed

+183
-0
lines changed

crates/ark/src/data_explorer/r_data_explorer.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -953,6 +953,7 @@ impl RDataExplorer {
953953
RowFilterType::NotEmpty,
954954
RowFilterType::NotNull,
955955
RowFilterType::Search,
956+
RowFilterType::SetMembership,
956957
]
957958
.iter()
958959
.map(|row_filter_type| RowFilterTypeSupportStatus {

crates/ark/src/modules/positron/r_data_explorer.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,26 @@ col_filter_indices <- function(col, idx = NULL) {
278278
!.ps.filter_col.between(col, params)
279279
}
280280

281+
.ps.filter_col.set_membership <- function(col, params) {
    # Test each element of `col` for membership in the filter's value set.
    #
    # Args:
    #   col: The column vector to filter.
    #   params: Filter parameters with two fields:
    #     values: The candidate values (sent as strings by the caller).
    #     inclusive: If TRUE, keep elements that are in the set; if FALSE,
    #       keep elements that are not in the set.
    #
    # Returns:
    #   A logical vector, one entry per element of `col`.

    values <- if (is.numeric(col)) {
        # Filter values must be coerced to compare against a numeric column.
        # A value that is not a valid number coerces to NA, and `%in%` treats
        # NA as matching NA, so leaving it in the set would silently select
        # (or exclude) the column's missing values. Drop such values instead;
        # the coercion warning is expected and therefore suppressed.
        coerced <- suppressWarnings(as.numeric(params$values))
        coerced[!is.na(coerced)]
    } else {
        params$values
    }

    # Logical vector indicating which elements are members of the set
    is_member <- col %in% values

    if (params$inclusive) {
        is_member
    } else {
        !is_member
    }
}
300+
281301
.ps.regex_escape <- function(x) {
282302
# Escape all regex magic characters in a string
283303
gsub("([][{}()+*^$|\\\\?.])", "\\\\\\1", x)

crates/ark/tests/data_explorer.rs

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,6 +1686,168 @@ fn test_update_data_filters_reapplied() {
16861686
]);
16871687
}
16881688

1689+
fn create_set_membership_filter(
1690+
column_schema: amalthea::comm::data_explorer_comm::ColumnSchema,
1691+
values: Vec<String>,
1692+
inclusive: bool,
1693+
filter_id: &str,
1694+
) -> RowFilter {
1695+
RowFilter {
1696+
column_schema,
1697+
filter_type: RowFilterType::SetMembership,
1698+
filter_id: filter_id.to_string(),
1699+
condition: RowFilterCondition::And,
1700+
is_valid: None,
1701+
params: Some(RowFilterParams::SetMembership(
1702+
amalthea::comm::data_explorer_comm::FilterSetMembership { values, inclusive },
1703+
)),
1704+
error_message: None,
1705+
}
1706+
}
1707+
1708+
/// Helper function to test set membership filters for both inclusive and exclusive modes
1709+
fn test_set_membership_helper(
1710+
data_frame_name: &str,
1711+
filter_values: Vec<&str>,
1712+
expected_inclusive_count: usize,
1713+
expected_exclusive_count: usize,
1714+
) {
1715+
let socket = open_data_explorer(String::from(data_frame_name));
1716+
1717+
let req = DataExplorerBackendRequest::GetSchema(GetSchemaParams {
1718+
column_indices: vec![0],
1719+
});
1720+
1721+
let schema_reply = socket_rpc(&socket, req);
1722+
let schema = match schema_reply {
1723+
DataExplorerBackendReply::GetSchemaReply(schema) => schema,
1724+
_ => panic!("Unexpected reply: {:?}", schema_reply),
1725+
};
1726+
1727+
let string_values: Vec<String> = filter_values.iter().map(|s| s.to_string()).collect();
1728+
1729+
let inclusive_filter = create_set_membership_filter(
1730+
schema.columns[0].clone(),
1731+
string_values.clone(),
1732+
true, // inclusive
1733+
"inclusive-filter-id",
1734+
);
1735+
1736+
let req = DataExplorerBackendRequest::SetRowFilters(SetRowFiltersParams {
1737+
filters: vec![inclusive_filter],
1738+
});
1739+
1740+
assert_match!(socket_rpc(&socket, req),
1741+
DataExplorerBackendReply::SetRowFiltersReply(
1742+
FilterResult { selected_num_rows: num_rows, had_errors: Some(false) }
1743+
) => {
1744+
assert_eq!(num_rows as usize, expected_inclusive_count,
1745+
"Inclusive filter for {} with values {:?} returned {} rows instead of expected {}",
1746+
data_frame_name, filter_values, num_rows, expected_inclusive_count);
1747+
});
1748+
1749+
let exclusive_filter = create_set_membership_filter(
1750+
schema.columns[0].clone(),
1751+
string_values,
1752+
false, // exclusive
1753+
"exclusive-filter-id",
1754+
);
1755+
1756+
let req = DataExplorerBackendRequest::SetRowFilters(SetRowFiltersParams {
1757+
filters: vec![exclusive_filter],
1758+
});
1759+
1760+
assert_match!(socket_rpc(&socket, req),
1761+
DataExplorerBackendReply::SetRowFiltersReply(
1762+
FilterResult { selected_num_rows: num_rows, had_errors: Some(false) }
1763+
) => {
1764+
assert_eq!(num_rows as usize, expected_exclusive_count,
1765+
"Exclusive filter for {} with values {:?} returned {} rows instead of expected {}",
1766+
data_frame_name, filter_values, num_rows, expected_exclusive_count);
1767+
});
1768+
}
1769+
1770+
/// End-to-end check of the set membership row filter over string and numeric
/// columns, with and without missing values.
#[test]
fn test_set_membership_filter() {
    /// Evaluates `code` in the R global environment, panicking on failure.
    fn define_in_r(code: &'static str) {
        r_task(move || {
            harp::parse_eval_global(code).unwrap();
        });
    }

    let _lock = r_test_lock();

    // String column without missing values
    define_in_r(
        r#"categories <- data.frame(
            fruit = c(
                "apple",
                "banana",
                "orange",
                "grape",
                "kiwi",
                "pear",
                "strawberry"
            )
        )"#,
    );
    test_set_membership_helper(
        "categories",                    // data frame name
        vec!["apple", "banana", "pear"], // filter values
        3,                               // expected inclusive match count
        4,                               // expected exclusive match count
    );

    // Numeric column without missing values
    define_in_r(
        r#"numeric_data <- data.frame(
            values = c(1, 2, 3, 4, 5, 6, 7)
        )"#,
    );
    test_set_membership_helper(
        "numeric_data",      // data frame name
        vec!["1", "2", "3"], // filter values (as strings, will be coerced)
        3,                   // expected inclusive match count
        4,                   // expected exclusive match count
    );

    // Test string data frame with NA values
    define_in_r(
        r#"categories_with_na <- data.frame(
            fruits = c(
                "apple",
                "banana",
                NA_character_,
                "orange",
                "grape",
                NA_character_,
                "pear"
            )
        )"#,
    );

    // Test with just regular values in the filter (NA values won't match)
    test_set_membership_helper("categories_with_na", vec!["apple", "banana"], 2, 5);

    // Test numeric data frame with NA values
    define_in_r(
        r#"numeric_with_na <- data.frame(
            values = c(1, 2, NA_real_, 3, NA_real_, 4, 5)
        )"#,
    );

    // Tests with just regular values in the filter (NA values won't match)
    test_set_membership_helper("numeric_with_na", vec!["1", "2"], 2, 5);
    test_set_membership_helper("numeric_with_na", vec![], 0, 7);
    test_set_membership_helper("numeric_with_na", vec!["3"], 1, 6);
}
1850+
16891851
#[test]
16901852
fn test_get_data_values_by_indices() {
16911853
let _lock = r_test_lock();

0 commit comments

Comments
 (0)