|
6 | 6 | from eds_scikit.utils.checks import MissingConceptError, algo_checker, concept_checker |
7 | 7 | from eds_scikit.utils.datetime_helpers import substract_datetime |
8 | 8 | from eds_scikit.utils.framework import get_framework |
| 9 | +from eds_scikit.utils.sort_values_first import sort_values_first |
9 | 10 | from eds_scikit.utils.typing import DataFrame |
10 | 11 |
|
11 | 12 |
|
@@ -292,44 +293,11 @@ def get_first( |
292 | 293 | how="inner", |
293 | 294 | ) |
294 | 295 |
|
295 | | - # Getting the corresponding first visit |
296 | | - # Replacement for : |
297 | | - # first_visit = merged.sort_values(by=[flag_name, "visit_start_datetime_1"], |
298 | | - # ascending=[False, False]) |
299 | | - # .groupby(visit_occurrence_id_2).first()["visit_occurrence_id_1"] |
300 | | - # which is not deterministic in Koalas |
301 | | - |
302 | | - flagged = ( |
303 | | - merged[merged[flag_name]] |
304 | | - .groupby("visit_occurrence_id_2", as_index=False)[ |
305 | | - ["visit_start_datetime_1"] |
306 | | - ] |
307 | | - .max() |
| 296 | + first_visit = sort_values_first( |
| 297 | + merged, |
| 298 | + by_cols=["visit_occurrence_id_2"], |
| 299 | + cols=[flag_name, "visit_start_datetime_1", "visit_occurrence_id_1"], |
308 | 300 | ) |
309 | | - flagged = merged[merged[flag_name]].merge( |
310 | | - flagged, on=["visit_occurrence_id_2", "visit_start_datetime_1"], how="right" |
311 | | - ) |
312 | | - flagged["flagged"] = True |
313 | | - unflagged = ( |
314 | | - merged[~merged[flag_name]] |
315 | | - .groupby("visit_occurrence_id_2", as_index=False)[ |
316 | | - ["visit_start_datetime_1"] |
317 | | - ] |
318 | | - .max() |
319 | | - ) |
320 | | - unflagged = merged[~merged[flag_name]].merge( |
321 | | - unflagged, |
322 | | - on=["visit_occurrence_id_2", "visit_start_datetime_1"], |
323 | | - how="right", |
324 | | - ) |
325 | | - unflagged = unflagged.merge( |
326 | | - flagged[["visit_occurrence_id_2", "flagged"]], |
327 | | - on="visit_occurrence_id_2", |
328 | | - how="left", |
329 | | - ) |
330 | | - unflagged = unflagged[unflagged.flagged.isna()] |
331 | | - first_visit = fw.concat((flagged, unflagged), axis=0) |
332 | | - |
333 | 301 | first_visit = first_visit.rename( |
334 | 302 | columns={ |
335 | 303 | "visit_occurrence_id_1": f"{concept_prefix}STAY_ID", |
|
0 commit comments