diff --git a/cf_pandas/accessor.py b/cf_pandas/accessor.py index 1dabf01..d2cfdd8 100644 --- a/cf_pandas/accessor.py +++ b/cf_pandas/accessor.py @@ -58,17 +58,20 @@ class CFAccessor: """Dataframe accessor analogous to cf-xarray accessor.""" def __init__(self, pandas_obj): - self._validate(pandas_obj) + # don't automatically validate but can when needed + # self._validate(pandas_obj) self._obj = pandas_obj - @staticmethod - def _validate(obj): + # @staticmethod + def _validate(self): """what is necessary for basic use.""" # verify that necessary keys are present. Z would also be nice but might be missing. # but don't use the accessor to check keys = ["T", "longitude", "latitude"] - missing_keys = [key for key in keys if len(_get_axis_coord(obj, key)) == 0] + missing_keys = [ + key for key in keys if len(_get_axis_coord(self._obj, key)) == 0 + ] if len(missing_keys) > 0: raise AttributeError( f'{"longitude", "latitude", "time"} must be identifiable in DataFrame but {missing_keys} are missing.' @@ -110,9 +113,12 @@ def __getitem__(self, key: str) -> Union[pd.Series, pd.DataFrame]: else: col_names = _get_custom_criteria(self._obj, key) - # return series - if len(col_names) == 1: + # return series for column + if len(col_names) == 1 and col_names[0] in self._obj.columns: return self._obj[col_names[0]] + # return index + elif len(col_names) == 1 and col_names[0] in self._obj.index.names: + return self._obj.index.get_level_values(col_names[0]) # return DataFrame elif len(col_names) > 1: return self._obj[col_names] @@ -248,6 +254,32 @@ def custom_keys(self): return vardict + @property + def axes_cols(self) -> List[str]: + """ + Property that returns a list of column names from the axes mapping. + + Returns + ------- + list + Variable names that are the column names which represent axes. + """ + + return list(itertools.chain(*[*self.axes.values()])) + + @property + def coordinates_cols(self) -> List[str]: + """ + Property that returns a list of column names from the coordinates mapping. + + Returns + ------- + list + Variable names that are the column names which represent coordinates. + """ + + return list(itertools.chain(*[*self.coordinates.values()])) + @property def standard_names(self): """ @@ -313,26 +345,14 @@ def _get_axis_coord(obj: Union[DataFrame, Series], key: str) -> list: f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}" ) - # search_in = set() - # attrs_or_encoding = ChainMap(obj.attrs, obj.encoding) - # coordinates = attrs_or_encoding.get("coordinates", None) - - # # Handles case where the coordinates attribute is None - # # This is used to tell xarray to not write a coordinates attribute - # if coordinates: - # search_in.update(coordinates.split(" ")) - # if not search_in: - # search_in = set(obj.coords) - - # # maybe only do this for key in _AXIS_NAMES? - # search_in.update(obj.indexes) - - # search_in = search_in & set(obj.coords) + # loop over column names and index names results: set = set() - for col in obj.columns: - # var = obj.coords[coord] + cols_and_indices = list(obj.columns) + cols_and_indices += obj.index.names + # remove None if in names from index + cols_and_indices = [name for name in cols_and_indices if name is not None] + for col in cols_and_indices: if key in coordinate_criteria: - # import pdb; pdb.set_trace() for criterion, expected in coordinate_criteria[key].items(): # allow for the column header having a space in it that separate # the name from the units, for example @@ -350,14 +370,19 @@ def _get_axis_coord(obj: Union[DataFrame, Series], key: str) -> list: # units = getattr(col.data, "units", None) # if units in expected: # results.update((col,)) - # also use the guess_regex approach by default, but only if no results so far # this takes the logic from cf-xarray guess_coord_axis if len(results) == 0: - if key in ("T", "time") and _is_datetime_like(obj[col]): - results.update((col,)) - continue # prevent second detection - + if col in obj.columns: + if key in ("T", "time") and _is_datetime_like(obj[col]): + results.update((col,)) + continue # prevent second detection + elif col in obj.index.names: + if key in ("T", "time") and _is_datetime_like( + obj.index.get_level_values(col) + ): + results.update((col,)) + continue # prevent second detection pattern = guess_regex[key] if pattern.match(col.lower()): results.update((col,)) diff --git a/cf_pandas/criteria.py b/cf_pandas/criteria.py index 5b30f35..bf8226a 100644 --- a/cf_pandas/criteria.py +++ b/cf_pandas/criteria.py @@ -103,15 +103,22 @@ coordinate_criteria["X"]["long_name"] += ("cell index along first dimension",) coordinate_criteria["Y"]["long_name"] += ("cell index along second dimension",) +# changes allow for the pattern string to not be at the start of the comparison string +# like (?=.*lon) guess_regex = { - "time": re.compile("\\bt\\b|(time|min|hour|day|week|month|year)[0-9]*"), + "time": re.compile("\\bt\\b|(?=.*time|min|hour|day|week|month|year)[0-9]*"), + # "time": re.compile("\\bt\\b|(time|min|hour|day|week|month|year)[0-9]*"), "Z": re.compile( - "(z|nav_lev|gdep|lv_|[o]*lev|bottom_top|sigma|h(ei)?ght|altitude|depth|" + "(z|nav_lev|gdep|lv_|[o]*lev|bottom_top|sigma|(?=.*dbars)|h(ei)?ght|altitude|depth|" "isobaric|pres|isotherm)[a-z_]*[0-9]*" ), + # "Z": re.compile( + # "(z|nav_lev|gdep|lv_|[o]*lev|bottom_top|sigma|h(ei)?ght|altitude|depth|" + # "isobaric|pres|isotherm)[a-z_]*[0-9]*" + # ), "Y": re.compile("y|j|nlat|nj"), - "latitude": re.compile("y?(nav_lat|lat|gphi)[a-z0-9]*"), + "latitude": re.compile("y?(nav_lat|(?=.*lat)|gphi)[a-z0-9]*"), "X": re.compile("x|i|nlon|ni"), - "longitude": re.compile("x?(nav_lon|lon|glam)[a-z0-9]*"), + "longitude": re.compile("x?(nav_lon|(?=.*lon)|glam)[a-z0-9]*"), } guess_regex["T"] = guess_regex["time"] diff --git a/tests/test_accessor.py b/tests/test_accessor.py index 6ce4f67..990db14 100644 --- a/tests/test_accessor.py +++ b/tests/test_accessor.py @@ -33,7 +33,7 @@ def test_validate(): ] ) with pytest.raises(AttributeError): - df.cf.keys() + df.cf._validate() def test_match_criteria_key_accessor(): @@ -128,3 +128,20 @@ def test_get_by_guess_regex(): assert df.cf["longitude"].name == "lon" assert df.cf["latitude"].name == "lat" assert df.cf["time"].name == "min" + + df = pd.DataFrame(columns=["blah_lon", "table_lat"]) + assert df.cf["longitude"].name == "blah_lon" + assert df.cf["latitude"].name == "table_lat" + + +def test_index(): + """Test when time is in index.""" + df = pd.DataFrame(index=["m_time"]) + df.index.rename("m_time", inplace=True) + assert df.cf["T"].name == "m_time" + + +def test_cols(): + df = pd.DataFrame(columns=["m_time", "lon", "lat", "temp"]) + assert df.cf.axes_cols == ["m_time"] + assert sorted(df.cf.coordinates_cols) == ["lat", "lon", "m_time"]