Skip to content

Commit

Permalink
Merge pull request #26 from kthyng/improvements
Browse files Browse the repository at this point in the history
Accessor can return indices, some guess_regex terms can be in middle of string now.
  • Loading branch information
kthyng authored Apr 28, 2023
2 parents 83208a6 + 0b991b4 commit 614f6de
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 34 deletions.
83 changes: 54 additions & 29 deletions cf_pandas/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,20 @@ class CFAccessor:
"""Dataframe accessor analogous to cf-xarray accessor."""

def __init__(self, pandas_obj):
self._validate(pandas_obj)
# don't automatically validate but can when needed
# self._validate(pandas_obj)
self._obj = pandas_obj

@staticmethod
def _validate(obj):
# @staticmethod
def _validate(self):
"""what is necessary for basic use."""

# verify that necessary keys are present. Z would also be nice but might be missing.
# but don't use the accessor to check
keys = ["T", "longitude", "latitude"]
missing_keys = [key for key in keys if len(_get_axis_coord(obj, key)) == 0]
missing_keys = [
key for key in keys if len(_get_axis_coord(self._obj, key)) == 0
]
if len(missing_keys) > 0:
raise AttributeError(
f'{"longitude", "latitude", "time"} must be identifiable in DataFrame but {missing_keys} are missing.'
Expand Down Expand Up @@ -110,9 +113,12 @@ def __getitem__(self, key: str) -> Union[pd.Series, pd.DataFrame]:
else:
col_names = _get_custom_criteria(self._obj, key)

# return series
if len(col_names) == 1:
# return series for column
if len(col_names) == 1 and col_names[0] in self._obj.columns:
return self._obj[col_names[0]]
# return index
elif len(col_names) == 1 and col_names[0] in self._obj.index.names:
return self._obj.index.get_level_values(col_names[0])
# return DataFrame
elif len(col_names) > 1:
return self._obj[col_names]
Expand Down Expand Up @@ -248,6 +254,32 @@ def custom_keys(self):

return vardict

@property
def axes_cols(self) -> List[str]:
"""
Property that returns a list of column names from the axes mapping.
Returns
-------
list
Variable names that are the column names which represent axes.
"""

return list(itertools.chain(*[*self.axes.values()]))

@property
def coordinates_cols(self) -> List[str]:
"""
Property that returns a list of column names from the coordinates mapping.
Returns
-------
list
Variable names that are the column names which represent coordinates.
"""

return list(itertools.chain(*[*self.coordinates.values()]))

@property
def standard_names(self):
"""
Expand Down Expand Up @@ -313,26 +345,14 @@ def _get_axis_coord(obj: Union[DataFrame, Series], key: str) -> list:
f"cf_xarray did not understand key {key!r}. Expected one of {valid_keys!r}"
)

# search_in = set()
# attrs_or_encoding = ChainMap(obj.attrs, obj.encoding)
# coordinates = attrs_or_encoding.get("coordinates", None)

# # Handles case where the coordinates attribute is None
# # This is used to tell xarray to not write a coordinates attribute
# if coordinates:
# search_in.update(coordinates.split(" "))
# if not search_in:
# search_in = set(obj.coords)

# # maybe only do this for key in _AXIS_NAMES?
# search_in.update(obj.indexes)

# search_in = search_in & set(obj.coords)
# loop over column names and index names
results: set = set()
for col in obj.columns:
# var = obj.coords[coord]
cols_and_indices = list(obj.columns)
cols_and_indices += obj.index.names
# remove None if in names from index
cols_and_indices = [name for name in cols_and_indices if name is not None]
for col in cols_and_indices:
if key in coordinate_criteria:
# import pdb; pdb.set_trace()
for criterion, expected in coordinate_criteria[key].items():
# allow for the column header having a space in it that separate
# the name from the units, for example
Expand All @@ -350,14 +370,19 @@ def _get_axis_coord(obj: Union[DataFrame, Series], key: str) -> list:
# units = getattr(col.data, "units", None)
# if units in expected:
# results.update((col,))

# also use the guess_regex approach by default, but only if no results so far
# this takes the logic from cf-xarray guess_coord_axis
if len(results) == 0:
if key in ("T", "time") and _is_datetime_like(obj[col]):
results.update((col,))
continue # prevent second detection

if col in obj.columns:
if key in ("T", "time") and _is_datetime_like(obj[col]):
results.update((col,))
continue # prevent second detection
elif col in obj.index.names:
if key in ("T", "time") and _is_datetime_like(
obj.index.get_level_values(col)
):
results.update((col,))
continue # prevent second detection
pattern = guess_regex[key]
if pattern.match(col.lower()):
results.update((col,))
Expand Down
15 changes: 11 additions & 4 deletions cf_pandas/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,22 @@
coordinate_criteria["X"]["long_name"] += ("cell index along first dimension",)
coordinate_criteria["Y"]["long_name"] += ("cell index along second dimension",)

# changes allow for the pattern string to not be at the start of the comparison string
# like (?=.*lon)
guess_regex = {
"time": re.compile("\\bt\\b|(time|min|hour|day|week|month|year)[0-9]*"),
"time": re.compile("\\bt\\b|(?=.*time|min|hour|day|week|month|year)[0-9]*"),
# "time": re.compile("\\bt\\b|(time|min|hour|day|week|month|year)[0-9]*"),
"Z": re.compile(
"(z|nav_lev|gdep|lv_|[o]*lev|bottom_top|sigma|h(ei)?ght|altitude|depth|"
"(z|nav_lev|gdep|lv_|[o]*lev|bottom_top|sigma|(?=.*dbars)|h(ei)?ght|altitude|depth|"
"isobaric|pres|isotherm)[a-z_]*[0-9]*"
),
# "Z": re.compile(
# "(z|nav_lev|gdep|lv_|[o]*lev|bottom_top|sigma|h(ei)?ght|altitude|depth|"
# "isobaric|pres|isotherm)[a-z_]*[0-9]*"
# ),
"Y": re.compile("y|j|nlat|nj"),
"latitude": re.compile("y?(nav_lat|lat|gphi)[a-z0-9]*"),
"latitude": re.compile("y?(nav_lat|(?=.*lat)|gphi)[a-z0-9]*"),
"X": re.compile("x|i|nlon|ni"),
"longitude": re.compile("x?(nav_lon|lon|glam)[a-z0-9]*"),
"longitude": re.compile("x?(nav_lon|(?=.*lon)|glam)[a-z0-9]*"),
}
guess_regex["T"] = guess_regex["time"]
19 changes: 18 additions & 1 deletion tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_validate():
]
)
with pytest.raises(AttributeError):
df.cf.keys()
df.cf._validate()


def test_match_criteria_key_accessor():
Expand Down Expand Up @@ -128,3 +128,20 @@ def test_get_by_guess_regex():
assert df.cf["longitude"].name == "lon"
assert df.cf["latitude"].name == "lat"
assert df.cf["time"].name == "min"

df = pd.DataFrame(columns=["blah_lon", "table_lat"])
assert df.cf["longitude"].name == "blah_lon"
assert df.cf["latitude"].name == "table_lat"


def test_index():
"""Test when time is in index."""
df = pd.DataFrame(index=["m_time"])
df.index.rename("m_time", inplace=True)
assert df.cf["T"].name == "m_time"


def test_cols():
df = pd.DataFrame(columns=["m_time", "lon", "lat", "temp"])
assert df.cf.axes_cols == ["m_time"]
assert sorted(df.cf.coordinates_cols) == ["lat", "lon", "m_time"]

0 comments on commit 614f6de

Please sign in to comment.