@@ -152,13 +152,13 @@ def select_columns(
152
152
153
153
Exclude columns with the `DropLabel` class:
154
154
>>> from janitor import DropLabel
155
- >>> df.select_columns(DropLabel(slice("name", "awake")), "conservation" )
156
- brainwt bodywt conservation
157
- 0 NaN 50.000 lc
158
- 1 0.01550 0.480 NaN
159
- 2 NaN 1.350 nt
160
- 3 0.00029 0.019 lc
161
- 4 0.42300 600.000 domesticated
155
+ >>> df.select_columns(DropLabel(slice("name", "awake")))
156
+ brainwt bodywt
157
+ 0 NaN 50.000
158
+ 1 0.01550 0.480
159
+ 2 NaN 1.350
160
+ 3 0.00029 0.019
161
+ 4 0.42300 600.000
162
162
163
163
Selection on MultiIndex columns:
164
164
>>> d = {'num_legs': [4, 4, 2, 2],
@@ -673,7 +673,7 @@ def _index_dispatch(arg, df, axis): # noqa: F811
673
673
674
674
Returns an array of integers.
675
675
"""
676
- level_label = {}
676
+
677
677
index = getattr (df , axis )
678
678
if not isinstance (index , pd .MultiIndex ):
679
679
return _select_index (list (arg ), df , axis )
@@ -687,6 +687,7 @@ def _index_dispatch(arg, df, axis): # noqa: F811
687
687
"in the MultiIndex, and should either be all "
688
688
"strings or integers."
689
689
)
690
+ level_label = {}
690
691
for key , value in arg .items ():
691
692
if isinstance (value , dispatch_callable ):
692
693
indexer = index .get_level_values (key )
@@ -757,14 +758,14 @@ def _index_dispatch(arg, df, axis): # noqa: F811
757
758
def _column_sel_dispatch (cols , df , axis ): # noqa: F811
758
759
"""
759
760
Base function for selection on a Pandas Index object.
760
- Returns the inverse of the passed label(s).
761
+ Returns the positions of the label(s) wrapped in the DropLabel;
762
+ the list dispatcher inverts them or removes them from the other selections.
761
763
762
764
Returns an array of integers.
763
765
"""
764
766
arr = _select_index (cols .label , df , axis )
765
767
index = np .arange (getattr (df , axis ).size )
766
- arr = _index_converter (arr , index )
767
- return np .delete (index , arr )
768
+ return _index_converter (arr , index )
768
769
769
770
770
771
@_select_index .register (set )
@@ -797,27 +798,43 @@ def _index_dispatch(arg, df, axis): # noqa: F811
797
798
indices = index .get_indexer_for (list (arg ))
798
799
if (indices != - 1 ).all ():
799
800
return indices
800
- # treat multiple DropLabel instances as a single unit
801
- checks = (isinstance (entry , DropLabel ) for entry in arg )
802
- if sum (checks ) > 1 :
803
- drop_labels = (entry for entry in arg if isinstance (entry , DropLabel ))
804
- drop_labels = [entry .label for entry in drop_labels ]
805
- drop_labels = DropLabel (drop_labels )
806
- arg = [entry for entry in arg if not isinstance (entry , DropLabel )]
807
- arg .append (drop_labels )
808
- indices = [_select_index (entry , df , axis ) for entry in arg ]
801
+
802
+ include = []
803
+ exclude = []
804
+ for entry in arg :
805
+ if isinstance (entry , DropLabel ):
806
+ exclude .append (entry )
807
+ else :
808
+ outcome = _select_index (entry , df , axis )
809
+ include .append (outcome )
810
+ if exclude :
811
+ if len (exclude ) > 1 :
812
+ exclude = [entry .label for entry in exclude ]
813
+ exclude = DropLabel (exclude )
814
+ else :
815
+ exclude = exclude [0 ]
816
+ exclude = _select_index (exclude , df , axis )
817
+ len_exclude = len (exclude )
818
+ if len_exclude and not include :
819
+ index_arr = np .arange (getattr (df , axis ).size )
820
+ return np .delete (index_arr , exclude )
821
+ if include and len_exclude :
822
+ include = [_index_converter (arr , index ) for arr in include ]
823
+ include = np .concatenate (include )
824
+ mask = np .isin (include , exclude )
825
+ return include [~ mask ]
809
826
# a single entry does not need to be combined,
810
827
# or materialized where that can be avoided;
811
828
# skipping that work improves performance
812
- if len (indices ) == 1 :
813
- if is_scalar (indices [0 ]):
814
- return indices
815
- indices = indices [0 ]
816
- if is_list_like (indices ):
817
- indices = np .asanyarray (indices )
818
- return indices
819
- indices = [_index_converter (arr , index ) for arr in indices ]
820
- return np .concatenate (indices )
829
+ if len (include ) == 1 :
830
+ if is_scalar (include [0 ]):
831
+ return include
832
+ include = include [0 ]
833
+ if is_list_like (include ):
834
+ include = np .asanyarray (include )
835
+ return include
836
+ include = [_index_converter (arr , index ) for arr in include ]
837
+ return np .concatenate (include )
821
838
822
839
823
840
def _index_converter (arr , index ):
0 commit comments