Skip to content

Commit bb06576

Browse files
committed
test
1 parent b8d0716 commit bb06576

File tree

3 files changed

+53
-37
lines changed

3 files changed

+53
-37
lines changed

eds_scikit/period/stays.py

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from eds_scikit.utils.checks import MissingConceptError, algo_checker, concept_checker
77
from eds_scikit.utils.datetime_helpers import substract_datetime
88
from eds_scikit.utils.framework import get_framework
9+
from eds_scikit.utils.sort_values_first import sort_values_first
910
from eds_scikit.utils.typing import DataFrame
1011

1112

@@ -292,44 +293,11 @@ def get_first(
292293
how="inner",
293294
)
294295

295-
# Getting the corresponding first visit
296-
# Replacement for :
297-
# first_visit = merged.sort_values(by=[flag_name, "visit_start_datetime_1"],
298-
# ascending=[False, False])
299-
# .groupby(visit_occurrence_id_2).first()["visit_occurrence_id_1"]
300-
# which is not deterministic in Koalas
301-
302-
flagged = (
303-
merged[merged[flag_name]]
304-
.groupby("visit_occurrence_id_2", as_index=False)[
305-
["visit_start_datetime_1"]
306-
]
307-
.max()
296+
first_visit = sort_values_first(
297+
merged,
298+
by_cols=["visit_occurrence_id_2"],
299+
cols=[flag_name, "visit_start_datetime_1", "visit_occurrence_id_1"],
308300
)
309-
flagged = merged[merged[flag_name]].merge(
310-
flagged, on=["visit_occurrence_id_2", "visit_start_datetime_1"], how="right"
311-
)
312-
flagged["flagged"] = True
313-
unflagged = (
314-
merged[~merged[flag_name]]
315-
.groupby("visit_occurrence_id_2", as_index=False)[
316-
["visit_start_datetime_1"]
317-
]
318-
.max()
319-
)
320-
unflagged = merged[~merged[flag_name]].merge(
321-
unflagged,
322-
on=["visit_occurrence_id_2", "visit_start_datetime_1"],
323-
how="right",
324-
)
325-
unflagged = unflagged.merge(
326-
flagged[["visit_occurrence_id_2", "flagged"]],
327-
on="visit_occurrence_id_2",
328-
how="left",
329-
)
330-
unflagged = unflagged[unflagged.flagged.isna()]
331-
first_visit = fw.concat((flagged, unflagged), axis=0)
332-
333301
first_visit = first_visit.rename(
334302
columns={
335303
"visit_occurrence_id_1": f"{concept_prefix}STAY_ID",
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from eds_scikit.utils.typing import DataFrame
2+
from typing import List
3+
4+
def sort_values_first(df : DataFrame,
5+
by_cols : List[str],
6+
cols : List[str],
7+
ascending : bool = False):
8+
return df.groupby(by_cols).apply(lambda group: group.sort_values(by=cols, ascending=[ascending for i in cols]).head(1)).reset_index(drop=True)

tests/test_sort_values_first.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pandas as pd
2+
import pytest
3+
from eds_scikit.utils import framework
4+
from eds_scikit.utils.sort_values_first import sort_values_first
5+
from eds_scikit.utils.test_utils import assert_equal_no_order
6+
7+
from databricks import koalas as ks
8+
import numpy as np
9+
10+
# Create a DataFrame
11+
np.random.seed(0)
12+
size=10000
13+
data = {
14+
'A': np.random.choice(['X', 'Y', 'Z'], size),
15+
'B': np.random.randint(1, 5, size),
16+
'C': np.random.randint(1, 5, size),
17+
'D': np.random.randint(1, 5, size),
18+
'E': np.random.randint(1, 5, size)
19+
}
20+
21+
inputs = pd.DataFrame(data)
22+
inputs.loc[0, 'B'] = 0
23+
inputs.loc[0, 'C'] = 4
24+
25+
@pytest.mark.parametrize(
26+
"module",
27+
["pandas", "koalas"],
28+
)
29+
def test_sort_values_first(module):
30+
31+
inputs_fr = framework.to(module, inputs)
32+
33+
computed = framework.pandas(sort_values_first(inputs_fr, ["A"], ["B", "C"], ascending=True))
34+
expected = inputs.sort_values(["B", "C"], ascending=True).groupby("A", as_index=False).first()
35+
assert_equal_no_order(computed, expected)
36+
37+
computed = framework.pandas(sort_values_first(inputs_fr, ["A"], ["B", "C"], ascending=False))
38+
expected = inputs.sort_values(["B", "C"], ascending=False).groupby("A", as_index=False).first()
39+
assert_equal_no_order(computed, expected)
40+

0 commit comments

Comments
 (0)