From 934e7a20a2d8a2ec5fb7e4b219ee4313ff81d9dd Mon Sep 17 00:00:00 2001 From: Kristen Thyng Date: Fri, 13 Oct 2023 13:02:28 -0500 Subject: [PATCH] more options for running different cases --- docs/conf.py | 2 +- docs/datasets.md | 6 +- docs/whats_new.md | 4 + ocean_model_skill_assessor/featuretype.py | 23 +- ocean_model_skill_assessor/main.py | 282 +++++++++++++----- ocean_model_skill_assessor/plot/__init__.py | 60 +++- ocean_model_skill_assessor/plot/line.py | 4 +- ocean_model_skill_assessor/plot/surface.py | 5 +- ocean_model_skill_assessor/utils.py | 129 +++++++- .../vocab/vocab_labels.json | 3 + tests/test_datasets.py | 21 +- 11 files changed, 423 insertions(+), 116 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 590d1eb..6ba367f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -115,7 +115,7 @@ # https://myst-nb.readthedocs.io/en/v0.9.0/use/execute.html # jupyter_execute_notebooks = "auto" # deprecated -nb_execution_mode = "force" +nb_execution_mode = "cache" # -- nbsphinx specific options ---------------------------------------------- # this allows notebooks to be run even if they produce errors. diff --git a/docs/datasets.md b/docs/datasets.md index 21b78de..c9c836f 100644 --- a/docs/datasets.md +++ b/docs/datasets.md @@ -9,8 +9,12 @@ The NCEI netCDF feature types are useful because they describe what does and doe | | timeSeries | profile | timeSeriesProfile | trajectory (TODO) | trajectoryProfile | grid (TODO) | |--- |--- |--- |--- |--- | --- | --- | | Definition | only t changes | only z changes | t and z change | t, y, and x change | t, z, y, and x change | t changes, y/x grid | -| Data types | mooring, buoy | CTD profile | moored ADCP | flow through, surface/drogued drifter | glider, transect of CTD profiles, towed ADCP | satellite, HF Radar | +| Data types | mooring, buoy | CTD profile | moored ADCP | flow through, 2D drifter | glider, transect of CTD profiles, towed ADCP, 3D drifter | satellite, HF Radar | | maptypes | point | point | point | point(s), line, box | point(s), line, box | box | +| X/Y are pairs (locstream) or grid | either locstream or grid | either locstream or grid | either locstream or grid | locstream | locstream | grid | +| Which dimensions are independent from X/Y choice? 
| +| T | Independent | Independent | Independent | match X/Y | match X/Y | Independent | +| Z | Independent | Independent | Independent | Independent | match X/Y | Independent | diff --git a/docs/whats_new.md b/docs/whats_new.md index d5925f7..514b04c 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -1,5 +1,9 @@ # What's New +## v1.1.0 (October 13, 2023) +* Continuing to improve functionality of flags to be able to control how model output is extracted +* making code more robust to different use cases + ## v1.0.0 (October 5, 2023) * more modularized code structure with much more testing * requires datasets to include catalog metadata of NCEI feature type and maptype (for plotting): diff --git a/ocean_model_skill_assessor/featuretype.py b/ocean_model_skill_assessor/featuretype.py index ca3a008..6daf5ef 100644 --- a/ocean_model_skill_assessor/featuretype.py +++ b/ocean_model_skill_assessor/featuretype.py @@ -3,17 +3,26 @@ ftconfig = {} ftconfig["timeSeries"] = { - "make_time_series": False, + "locstreamT": False, + "locstreamZ": False, } ftconfig["profile"] = { - "make_time_series": False, -} -ftconfig["trajectoryProfile"] = { - "make_time_series": True, + "locstreamT": False, + "locstreamZ": False, } ftconfig["timeSeriesProfile"] = { - "make_time_series": True, + "locstreamT": False, + "locstreamZ": False, +} +ftconfig["trajectory"] = { + "locstreamT": True, + "locstreamZ": False, +} +ftconfig["trajectoryProfile"] = { + "locstreamT": True, + "locstreamZ": True, } ftconfig["grid"] = { - "make_time_series": False, + "locstreamT": False, + "locstreamZ": False, } diff --git a/ocean_model_skill_assessor/main.py b/ocean_model_skill_assessor/main.py index b3d12bb..447b602 100644 --- a/ocean_model_skill_assessor/main.py +++ b/ocean_model_skill_assessor/main.py @@ -44,6 +44,7 @@ check_dataset, coords1Dto2D, find_bbox, + fix_dataset, get_mask, kwargs_search_from_model, open_catalogs, @@ -66,6 +67,7 @@ def make_local_catalog( metadata: dict = None, metadata_catalog: dict = None, skip_entry_metadata: bool = False, + skip_strings: Optional[list] = None, kwargs_open: Optional[Dict] = None, logger=None, ) -> Catalog: @@ -91,6 +93,8 @@ def make_local_catalog( Metadata for catalog. skip_entry_metadata : bool, optional This is useful for testing in which case we don't want to actually read the file. If you are making a catalog file for a model, you may want to set this to `True` to avoid reading it all in for metadata. + skip_strings : list of strings, optional + If provided, source_names in catalog will only be checked for goodness if they do not contain one of skip_strings. For example, if `skip_strings=["_base"]` then any source in the catalog whose name contains that string will be skipped. kwargs_open : dict, optional Keyword arguments to pass on to the appropriate ``intake`` ``open_*`` call for model or dataset. 
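The featuretype table and the reworked `ftconfig` above drive how model output is extracted: each NCEI feature type now maps to `locstreamT`/`locstreamZ` flags instead of the old single `make_time_series` flag. A minimal sketch of looking those flags up from a dataset's `featuretype` metadata (the helper function here is illustrative, not part of the package):

```python
# Illustrative helper (not part of the package) showing how the per-featuretype
# flags in ftconfig can be looked up from a dataset's NCEI featuretype metadata.
from ocean_model_skill_assessor.featuretype import ftconfig


def extraction_flags(featuretype: str) -> dict:
    """Return the locstream flags for an NCEI feature type."""
    if featuretype not in ftconfig:
        raise KeyError(f"unknown NCEI featuretype: {featuretype}")
    return ftconfig[featuretype]


# a glider or towed-ADCP transect: times and depths both vary with location
assert extraction_flags("trajectoryProfile") == {"locstreamT": True, "locstreamZ": True}
# a mooring: a fixed location, so neither flag is set
assert extraction_flags("timeSeries") == {"locstreamT": False, "locstreamZ": False}
```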
@@ -247,7 +251,7 @@ def make_local_catalog( # this allows for not checking a model catalog if not skip_entry_metadata: - check_catalog(cat) + check_catalog(cat, skip_strings=skip_strings) return cat @@ -261,6 +265,7 @@ def make_catalog( kwargs: Optional[Dict[str, Any]] = None, kwargs_search: Optional[Dict[str, Union[str, int, float]]] = None, kwargs_open: Optional[Dict] = None, + skip_strings: Optional[list] = None, vocab: Optional[Union[Vocab, str, PurePath]] = None, return_cat: bool = True, save_cat: bool = False, @@ -296,6 +301,8 @@ def make_catalog( kwargs_open : dict, optional Keyword arguments to save into local catalog for model to pass on to ``xr.open_mfdataset`` call or ``pandas`` ``open_csv``. Only for use with ``catalog_type=local``. + skip_strings : list of strings, optional + If provided, source_names in catalog will only be checked for goodness if they do not contain one of skip_strings. For example, if `skip_strings=["_base"]` then any source in the catalog whose name contains that string will be skipped. vocab : str, Vocab, Path, optional Way to find the criteria to use to map from variable to attributes describing the variable. This is to be used with a key representing what variable to search for. return_cat : bool, optional @@ -361,6 +368,7 @@ def make_catalog( description=description, metadata=metadata, kwargs_open=kwargs_open, + skip_strings=skip_strings, logger=logger, **kwargs, ) @@ -409,7 +417,7 @@ def make_catalog( # this allows for not checking a model catalog if "skip_entry_metadata" in kwargs and not kwargs["skip_entry_metadata"]: - check_catalog(cat) + check_catalog(cat, skip_strings=skip_strings) if save_cat: # save cat to file @@ -911,10 +919,13 @@ def _check_prep_narrow_data( maps.pop(-1) return None, maps - elif isinstance( - dd, xr.DataArray - ) and key_variable_data not in cf_xarray.accessor._get_custom_criteria( - dd, key_variable_data, vocab.vocab + elif ( + isinstance(dd, xr.DataArray) + and vocab is not None + and key_variable_data + not in cf_xarray.accessor._get_custom_criteria( + dd, key_variable_data, vocab.vocab + ) ): msg = f"Key variable {key_variable_data} cannot be identified in dataset {source_name}. Skipping dataset.\n" logger.warning(msg) @@ -931,41 +942,40 @@ def _check_prep_narrow_data( dd.drop(col, axis=1, inplace=True) if isinstance(dd, pd.DataFrame): - # ONLY DO THIS FOR DATAFRAMES - # dd.cf["T"] = to_datetime(dd.cf["T"]) - # dd.set_index(dd.cf["T"], inplace=True) + # shouldn't need to deal with multi-indices anymore # deal with possible time zone - if isinstance(dd.index, pd.core.indexes.multi.MultiIndex): - index = dd.index.get_level_values(dd.cf["T"].name) - else: - index = dd.index + # if isinstance(dd.index, pd.core.indexes.multi.MultiIndex): + # index = dd.index.get_level_values(dd.cf["T"].name) + # else: + # index = dd.index - if hasattr(index, "tz") and index.tz is not None: + # if hasattr(index, "tz") and index.tz is not None: + if dd.cf["T"].dt.tz is not None: logger.warning( "Dataset %s had a timezone %s which is being removed. 
Make sure the timezone matches the model output.", source_name, - str(index.tz), + str(dd.cf["T"].dt.tz), ) # remove time zone - index = index.tz_convert(None) - - if isinstance(dd.index, pd.core.indexes.multi.MultiIndex): - # loop over levels in index so we know which level to replace - inds = [] - for lev in range(dd.index.nlevels): - ind = dd.index.get_level_values(lev) - if dd.index.names[lev] == dd.cf["T"].name: - ind = ind.tz_convert(None) - inds.append(ind) - dd = dd.set_index(inds) - - # ilev = dd.index.names.index(index.name) - # dd.index = dd.index.set_levels(index, level=ilev) - # # dd.index.set_index([]) - else: - dd.index = index # dd.index.tz_convert(None) - dd.cf["T"] = index # dd.index + dd.cf["T"] = dd.cf["T"].dt.tz_convert(None) + + # if isinstance(dd.index, pd.core.indexes.multi.MultiIndex): + # # loop over levels in index so we know which level to replace + # inds = [] + # for lev in range(dd.index.nlevels): + # ind = dd.index.get_level_values(lev) + # if dd.index.names[lev] == dd.cf["T"].name: + # ind = ind.tz_convert(None) + # inds.append(ind) + # dd = dd.set_index(inds) + + # # ilev = dd.index.names.index(index.name) + # # dd.index = dd.index.set_levels(index, level=ilev) + # # # dd.index.set_index([]) + # else: + # dd.index = index # dd.index.tz_convert(None) + # dd.cf["T"] = index # dd.index # # make sure index is sorted ascending so time goes forward # dd = dd.sort_index() @@ -980,7 +990,11 @@ def _check_prep_narrow_data( # if pd.notnull(user_min_time) and pd.notnull(user_max_time) and (abs(data_min_time - user_min_time) <= pd.Timedelta("1 day")) and (abs(data_max_time - user_max_time) >= pd.Timedelta("1 day")): # if pd.notnull(user_min_time) and pd.notnull(user_max_time) and (data_min_time <= user_min_time) and (data_max_time >= user_max_time): # if data_time_range.encompass(model_time_range): - dd = dd.loc[user_min_time:user_max_time] + dd = ( + dd.set_index(dd.cf["T"]) + .loc[user_min_time:user_max_time] + .reset_index(drop=True) + ) else: dd = dd @@ -1138,7 +1152,7 @@ def _return_p1( def _return_data_locations( - maps: list, dd: Union[pd.DataFrame, xr.Dataset], logger=None + maps: list, dd: Union[pd.DataFrame, xr.Dataset], featuretype: str, logger=None ) -> Tuple[Union[float, np.array], Union[float, np.array]]: """Return lon, lat locations from dataset. @@ -1148,6 +1162,8 @@ def _return_data_locations( Each entry is a list of information about a dataset; the last entry is for the present source_name or dataset. Each entry contains [min_lon, max_lon, min_lat, max_lat, source_name] and possibly an additional element containing "maptype". dd : Union[pd.DataFrame, xr.Dataset] Dataset + featuretype : str + NCEI feature type for dataset logger : optional logger, by default None @@ -1161,7 +1177,12 @@ def _return_data_locations( min_lon, max_lon, min_lat, max_lat, source_name = maps[-1][:5] # logic for one or multiple lon/lat locations - if min_lon != max_lon or min_lat != max_lat: + if ( + min_lon != max_lon + or min_lat != max_lat + or featuretype == "trajectory" + or featuretype == "trajectoryProfile" + ): if logger is not None: logger.info( f"Source {source_name} is not stationary so using multiple locations." 
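`_check_prep_narrow_data` now strips any timezone from the data's time coordinate and narrows to the user-requested window by temporarily setting time as the index. A plain-pandas sketch of the same pattern, with an illustrative `time` column standing in for whatever `dd.cf["T"]` resolves to and made-up window bounds:

```python
# Plain-pandas sketch of the timezone stripping and time-window narrowing now done
# in _check_prep_narrow_data; column name, values, and bounds are illustrative.
import pandas as pd

dd = pd.DataFrame(
    {
        "time": pd.date_range("2022-01-01", periods=6, freq="6h", tz="US/Central"),
        "temp": [10.0, 10.2, 10.1, 10.4, 10.3, 10.5],
    }
)

# remove the timezone; the model output is assumed to be timezone-naive
if dd["time"].dt.tz is not None:
    dd["time"] = dd["time"].dt.tz_convert(None)

# narrow to the requested window, then drop time back out of the index
user_min_time = pd.Timestamp("2022-01-01 06:00")
user_max_time = pd.Timestamp("2022-01-01 18:00")
dd = dd.set_index("time").loc[user_min_time:user_max_time].reset_index()
print(dd)
```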
@@ -1584,7 +1605,14 @@ def _select_process_save_model( else: no_Z = False - check_dataset(model_var, no_Z=no_Z) + model_var = model_var.cf.guess_coord_axis() + + try: + check_dataset(model_var, no_Z=no_Z) + except KeyError: + # see if I can fix it + model_var = fix_dataset(model_var, dam) + check_dataset(model_var, no_Z=no_Z) if logger is not None: logger.info(f"Saving model output to file...") @@ -1598,7 +1626,7 @@ def run( project_name: str, key_variable: Union[str, dict], model_name: Union[str, Catalog], - vocabs: Union[str, Vocab, Sequence, PurePath], + vocabs: Optional[Union[str, Vocab, Sequence, PurePath]] = None, vocab_labels: Optional[Union[str, PurePath, dict]] = None, ndatasets: Optional[int] = None, kwargs_map: Optional[Dict] = None, @@ -1609,7 +1637,9 @@ def run( dd: int = 2, preprocess: bool = False, need_xgcm_grid: bool = False, + xcmocean_options: Optional[dict] = None, kwargs_xroms: Optional[dict] = None, + locstream: bool = True, interpolate_horizontal: bool = True, horizontal_interp_code="delaunay", save_horizontal_interp_weights: bool = True, @@ -1630,6 +1660,9 @@ def run( plot_count_title: bool = True, cache_dir: Optional[Union[str, PurePath]] = None, return_fig: bool = False, + override_model: bool = False, + override_processed: bool = False, + override_stats: bool = False, **kwargs, ): """Run the model-data comparison. @@ -1649,7 +1682,7 @@ def run( model_name : str, Catalog Name of catalog for model output, created with ``make_catalog`` call, or Catalog instance. vocabs : str, list, Vocab, PurePath, optional - Criteria to use to map from variable to attributes describing the variable. This is to be used with a key representing what variable to search for. This input is for the name of one or more existing vocabularies which are stored in a user application cache. + Criteria to use to map from variable to attributes describing the variable. This is to be used with a key representing what variable to search for. This input is for the name of one or more existing vocabularies which are stored in a user application cache. This should be supplied, however it is made optional because it could be provided by setting it outside of the OMSA code. vocab_labels : str, dict, Path, optional Ultimately a dictionary whose keys match the input vocab and values have strings to be used in plot labels, such as "Sea water temperature [C]" for the key "temp". They can be input from a stored file or as a dict. ndatasets : int, optional @@ -1672,6 +1705,11 @@ def run( If True, try to set up xgcm grid for run, which will be used for the variable calculation for the model. kwargs_xroms : dict Optional keyword arguments to pass to xroms.open_dataset + locstream: boolean, optional + Which type of interpolation to do, passed to em.select(): + + * False: 2D array of points with 1 dimension the lons and the other dimension the lats. + * True: lons/lats as unstructured coordinate pairs (in xESMF language, LocStream). interpolate_horizontal : bool, optional If True, interpolate horizontally. Otherwise find nearest model points. horizontal_interp_code: str @@ -1714,6 +1752,12 @@ def run( dict with keys that match input vocab for putting labels with units on the plots. User has to make sure they match both the data and model; there is no unit handling. return_fig: bool Set to True to return all outputs from this function. Use for testing. Only works if using a single source. + override_model : bool + Flag to force-redo model selection. Default False. 
+ override_processed : bool + Flag to force-redo model and data processing. Default False. + override_stats : bool + Flag to force-redo stats calculation. Default False. """ paths = Paths(project_name, cache_dir=cache_dir) @@ -1729,17 +1773,20 @@ def run( mask = None # After this, we have a single Vocab object with vocab stored in vocab.vocab - vocab = open_vocabs(vocabs, paths) - # now we shouldn't need to worry about this for the rest of the run right? - cfp_set_options(custom_criteria=vocab.vocab) - cfx_set_options(custom_criteria=vocab.vocab) + if vocabs is not None: + vocab = open_vocabs(vocabs, paths) + # now we shouldn't need to worry about this for the rest of the run right? + cfp_set_options(custom_criteria=vocab.vocab) + cfx_set_options(custom_criteria=vocab.vocab) + else: + vocab = None # After this, we have None or a dict with key, values of vocab keys, string description for plot labels if vocab_labels is not None: vocab_labels = open_vocab_labels(vocab_labels, paths) # Open and check catalogs. - cats = open_catalogs(catalogs, paths) + cats = open_catalogs(catalogs, paths, skip_strings=["_base", "_all", "_tidecons"]) # Warning about number of datasets ndata = np.sum([len(list(cat)) for cat in cats]) @@ -1781,6 +1828,9 @@ def run( "key_variables" in cat[source_name].metadata and key_variable not in cat[source_name].metadata["key_variables"] ): + logger.info( + f"no `key_variables` key found in source metadata or at least not {key_variable}" + ) continue min_lon = cat[source_name].metadata["minLongitude"] @@ -1846,6 +1896,13 @@ def run( maps.pop(-1) continue + except Exception as e: + logger.warning(str(e)) + msg = f"Data cannot be loaded for dataset {source_name}. Skipping dataset.\n" + logger.warning(msg) + maps.pop(-1) + continue + # Need to have this here because if model file has previously been read in but # aligned file doesn't exist yet, this needs to run to update the sign of the # data depths in certain cases. @@ -1902,11 +1959,11 @@ def run( ) # read in previously-saved processed model output and obs. - if fname_processed_data.is_file() or fname_processed_model.is_file(): - # make sure both exist if either exist - assert ( - fname_processed_data.is_file() and fname_processed_model.is_file() - ) + if ( + not override_processed + and fname_processed_data.is_file() + and fname_processed_model.is_file() + ): logger.info( "Reading previously-processed model output and data for %s.", @@ -1916,13 +1973,19 @@ def run( obs = pd.read_csv(fname_processed_data) obs = check_dataframe(obs, no_Z) elif isinstance(dfd, xr.Dataset): - obs = xr.open_dataset(fname_processed_data) + obs = xr.open_dataset(fname_processed_data).cf.guess_coord_axis() check_dataset(obs, is_model=False, no_Z=no_Z) else: raise TypeError("object is neither DataFrame nor Dataset.") - model = xr.open_dataset(fname_processed_model) - check_dataset(model, no_Z=no_Z) + model = xr.open_dataset(fname_processed_model).cf.guess_coord_axis() + # check_dataset(model, no_Z=no_Z) + try: + check_dataset(model, no_Z=no_Z) + except KeyError: + # see if I can fix it + model = fix_dataset(model, dsm) + check_dataset(model, no_Z=no_Z) else: logger.info( @@ -1949,7 +2012,7 @@ def run( continue # Read in model output from cache if possible. 
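For reference, a hypothetical `run()` call exercising the options added in this change. Only the keyword names come from the signature above; the project, catalog, model, and vocab names are placeholders, and the `xcmocean_options` value is illustrative (it is passed straight to `xcmocean.set_options`, which accepts only the `regexin`/`seqin`/`divin` keys):

```python
# Hypothetical invocation; names and option values are placeholders.
import ocean_model_skill_assessor as omsa

omsa.run(
    catalogs="example_data_catalog",      # placeholder catalog name
    project_name="example_project",       # placeholder project name
    key_variable="temp",
    model_name="example_model_catalog",   # placeholder model catalog name
    vocabs="general",                     # may now be None if criteria are set outside OMSA
    interpolate_horizontal=False,
    locstream=True,                       # lons/lats treated as unstructured pairs in em.select
    xcmocean_options={"divin": {"temp": "cmo.balance"}},  # illustrative value
    override_model=True,                  # force-redo model selection even if cached
    override_processed=True,              # force-redo processed model/data files
    override_stats=True,                  # force-redo stats calculation
)
```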
- if model_file_name.is_file(): + if not override_model and model_file_name.is_file(): logger.info("Reading model output from file.") model_var = xr.open_dataset(model_file_name) # model_var = xr.open_dataarray(model_file_name) @@ -1960,7 +2023,13 @@ def run( model_var = model_var.cf.guess_coord_axis() model_var = model_var.cf[key_variable_data] # distance = model_var.attrs["distance_from_location_km"] - check_dataset(model_var, no_Z=no_Z) + # check_dataset(model_var, no_Z=no_Z) + try: + check_dataset(model_var, no_Z=no_Z) + except KeyError: + # see if I can fix it + model_var = fix_dataset(model_var, dsm) + check_dataset(model_var, no_Z=no_Z) if model_only: logger.info("Running model only so moving on to next source...") @@ -1970,7 +2039,9 @@ def run( else: # lons, lats might be one location or many - lons, lats = _return_data_locations(maps, dfd, logger) + lons, lats = _return_data_locations( + maps, dfd, cat[source_name].metadata["featuretype"], logger + ) # narrow time range to limit how much model output to deal with dsm2 = _narrow_model_time_range( @@ -2012,12 +2083,15 @@ def run( # if your model is too large to be treated with this way, subset the model first. dam = coords1Dto2D(dam) # this is fast if not needed - # if make_time_series then want to keep all the data times (like a CTD transect) + # if locstreamT then want to keep all the data times (like a CTD transect) # if not, just want the unique values (like a CTD profile) - make_time_series = ftconfig[ - cat[source_name].metadata["featuretype"] - ]["make_time_series"] - if make_time_series: + locstreamT = ftconfig[cat[source_name].metadata["featuretype"]][ + "locstreamT" + ] + locstreamZ = ftconfig[cat[source_name].metadata["featuretype"]][ + "locstreamZ" + ] + if locstreamT: T = [pd.Timestamp(date) for date in dfd.cf["T"].values] else: T = [ @@ -2036,14 +2110,15 @@ def run( T=T, # # works for both # T=None, # changed this because wasn't working with CTD profiles. Time interpolation happens during _align. - make_time_series=make_time_series, Z=Z, vertical_interp=vertical_interp, iT=None, iZ=None, extrap=extrap, extrap_val=None, - locstream=True, + locstream=locstream, + locstreamT=locstreamT, + locstreamZ=locstreamZ, # locstream_dim="z_rho", weights=None, mask=mask, @@ -2077,11 +2152,61 @@ def run( logger.info( f"Apply a time series modification called {mod['function']}." 
) + if isinstance(dfd, pd.DataFrame): + dfd.set_index(dfd.cf["T"], inplace=True) dfd[dfd.cf[key_variable_data].name] = mod["function"]( dfd.cf[key_variable_data], **mod["inputs"] ) + if isinstance(dfd, pd.DataFrame): + dfd = dfd.reset_index(drop=True) model_var = mod["function"](model_var, **mod["inputs"]) + # there could be a small mismatch in the length of time if times were pulled + # out separately + # import pdb; pdb.set_trace() + if np.unique(model_var.cf["T"]).size != np.unique(dfd.cf["T"]).size: + # if model_var.cf["T"].size != np.unique(dfd.cf["T"]).size: + # if (isinstance(dfd, pd.DataFrame) and model_var.cf["T"].size != dfd.cf["T"].unique().size) or (isinstance(dfd, xr.Dataset) and model_var.cf["T"].size != dfd.cf["T"].drop_duplicates(dim=dfd.cf["T"].name).size): + # if len(model_var.cf["T"]) != len(dfd.cf["T"]): # timeSeries + stime = pd.Timestamp( + max(dfd.cf["T"].values[0], model_var.cf["T"].values[0]) + ) + etime = pd.Timestamp( + min(dfd.cf["T"].values[-1], model_var.cf["T"].values[-1]) + ) + model_var = model_var.cf.sel({"T": slice(stime, etime)}) + + if isinstance(dfd, pd.DataFrame): + dfd = dfd.set_index(dfd.cf["T"].name) + dfd = dfd.loc[stime:etime] + + # interpolate data to model times + # Times between data and model should already match from em.select + # except in the case that model output was cached in convenient time series + # in which case the times aren't already matched. For this case, the data + # also might be missing the occasional data points, and want + # the data index to match the model index since the data resolution might be very high. + # get combined index of model and obs to first interpolate then reindex obs to model + # otherwise only nan's come through + # accounting for known issue for interpolation after sampling if indices changes + # https://github.com/pandas-dev/pandas/issues/14297 + model_index = model_var.cf["T"].to_pandas().index + model_index.name = dfd.index.name + ind = model_index.union(dfd.index) + dfd = ( + dfd.reindex(ind) + .interpolate(method="time", limit=3) + .reindex(model_index) + ) + dfd = dfd.reset_index() + + elif isinstance(dfd, xr.Dataset): + # interpolate data to model times + # model_index = model_var.cf["T"].to_pandas().index + # ind = model_index.union(dfd.cf["T"].to_pandas().index) + dfd = dfd.interp({dfd.cf["T"].name: model_var.cf["T"].values}) + # dfd = dfd.cf.sel({"T": slice(stime, etime)}) + # Save processed data and model files # read in from newly made file to make sure output is loaded if isinstance(dfd, pd.DataFrame): @@ -2090,37 +2215,51 @@ def run( obs = check_dataframe(obs, no_Z) elif isinstance(dfd, xr.Dataset): dfd.to_netcdf(fname_processed_data) - obs = xr.open_dataset(fname_processed_data) + obs = xr.open_dataset(fname_processed_data).cf.guess_coord_axis() check_dataset(obs, is_model=False, no_Z=no_Z) else: raise TypeError("object is neither DataFrame nor Dataset.") model_var.to_netcdf(fname_processed_model) - model = xr.open_dataset(fname_processed_model) - check_dataset(model, no_Z=no_Z) + model = xr.open_dataset(fname_processed_model).cf.guess_coord_axis() + # check_dataset(model, no_Z=no_Z) + try: + check_dataset(model, no_Z=no_Z) + except KeyError: + # see if I can fix it + model = fix_dataset(model, dsm) + check_dataset(model, no_Z=no_Z) logger.info(f"model file name is {model_file_name}.") - if model_file_name.is_file(): + if not override_model and model_file_name.is_file(): logger.info("Reading model output from file.") - model_var = xr.open_dataset(model_file_name) - 
check_dataset(model_var, no_Z=no_Z) + model_var = xr.open_dataset(model_file_name).cf.guess_coord_axis() + # check_dataset(model_var, no_Z=no_Z) + try: + check_dataset(model_var, no_Z=no_Z) + except KeyError: + # see if I can fix it + model_var = fix_dataset(model_var, dsm) + check_dataset(model_var, no_Z=no_Z) if not interpolate_horizontal: distance = model_var["distance"] # distance = model_var.attrs["distance_from_location_km"] else: - raise ValueError("If the aligned file is available need this one too.") + raise ValueError( + "If the processed files are available need this one too." + ) stats_fname = (paths.OUT_DIR / f"{fname_processed.stem}").with_suffix( ".yaml" ) - if stats_fname.is_file(): + if not override_stats and stats_fname.is_file(): logger.info("Reading from previously-saved stats file.") with open(stats_fname, "r") as stream: stats = yaml.safe_load(stream) else: stats = compute_stats( - obs.cf[key_variable_data], model.cf[key_variable_data] + obs.cf[key_variable_data], model.cf[key_variable_data].squeeze() ) # stats = obs.omsa.compute_stats @@ -2156,6 +2295,7 @@ def run( stats, figname, vocab_labels, + xcmocean_options=xcmocean_options, **kwargs, ) msg = f"Plotted time series for {source_name}\n." diff --git a/ocean_model_skill_assessor/plot/__init__.py b/ocean_model_skill_assessor/plot/__init__.py index 8f9a8cb..148eca3 100644 --- a/ocean_model_skill_assessor/plot/__init__.py +++ b/ocean_model_skill_assessor/plot/__init__.py @@ -25,10 +25,25 @@ def selection( stats: dict, figname: Union[str, pathlib.Path], vocab_labels: Optional[dict] = None, + xcmocean_options: Optional[dict] = None, **kwargs, ) -> figure: """Plot.""" + # must contain keys + if xcmocean_options is not None: + if any( + [ + key + for key in xcmocean_options.keys() + if key not in ["regexin", "seqin", "divin"] + ] + ): + raise KeyError( + 'keys for `xcmocean_options` must be ["regexin", "seqin", "divin"]' + ) + xcmocean.set_options(**xcmocean_options) + if vocab_labels is not None: key_variable_label = vocab_labels[key_variable] else: @@ -48,11 +63,16 @@ def selection( # add location info # always show first/only location - loc = f"lon: {obs.cf['longitude'][0]:.2f} lat: {obs.cf['latitude'][0]:.2f}" + if obs.cf["longitude"].size == 1: + loc = f"lon: {float(obs.cf['longitude']):.2f} lat: {float(obs.cf['latitude']):.2f}" + else: + loc = f"lon: {obs.cf['longitude'][0]:.2f} lat: {obs.cf['latitude'][0]:.2f}" # time = f"{str(obs.cf['T'][0].date())}" # worked for DF time = str(pd.Timestamp(obs.cf["T"].values[0]).date()) # works for DF and DS # only shows depths if 1 depth since otherwise will be on plot - if np.unique(obs.cf["Z"]).size == 1: + if np.unique(obs.cf["Z"][~np.isnan(obs.cf["Z"])]).size == 1: + # if (np.unique(obs.cf["Z"]) * ~np.isnan(obs.cf["Z"])).size == 1: + # if np.unique(obs[obs.cf["Z"].notnull()].cf["Z"]).size == 1: # did not work for timeSeriesProfile depth = f"depth: {obs.cf['Z'][0]}" title = f"{source_name}: {stat_sum}\n{time} {depth} {loc}" else: @@ -95,16 +115,32 @@ def selection( ) elif featuretype == "trajectoryProfile": - xname, yname, zname = "distance", "Z", key_variable - xlabel, ylabel, zlabel = ( - "along-transect distance [km]", - "Depth [m]", - key_variable_label, - ) - if "distance" not in obs.cf: - along_transect_distance = True + # Assume want along-transect distance if number of unique locations is + # equal to or more than number of times + if ( + np.unique(obs.cf["longitude"]).size >= np.unique(obs.cf["T"]).size + or np.unique(obs.cf["latitude"]).size >= 
np.unique(obs.cf["T"]).size + ): + xname, yname, zname = "distance", "Z", key_variable + xlabel, ylabel, zlabel = ( + "along-transect distance [km]", + "Depth [m]", + key_variable_label, + ) + if "distance" not in obs.cf: + along_transect_distance = True + else: + along_transect_distance = False + # otherwise use time for x axis else: + xname, yname, zname = "T", "Z", key_variable + xlabel, ylabel, zlabel = ( + "", + "Depth [m]", + key_variable_label, + ) along_transect_distance = False + fig = surface.plot( obs, model, @@ -128,8 +164,8 @@ def selection( xname, yname, zname = "T", "Z", key_variable xlabel, ylabel, zlabel = "", "Depth [m]", key_variable_label fig = surface.plot( - obs, - model, + obs.squeeze(), + model.squeeze(), xname, yname, zname, diff --git a/ocean_model_skill_assessor/plot/line.py b/ocean_model_skill_assessor/plot/line.py index 8a026d6..d0dc87c 100644 --- a/ocean_model_skill_assessor/plot/line.py +++ b/ocean_model_skill_assessor/plot/line.py @@ -70,8 +70,8 @@ def plot( fig, ax = plt.subplots(1, 1, figsize=figsize, layout="constrained") ax.plot(obs.cf[xname], obs.cf[yname], label="data", lw=lw, color=col_obs) ax.plot( - np.array(model.cf[xname]), - np.array(model.cf[yname]), + np.array(model.cf[xname].squeeze()), + np.array(model.cf[yname].squeeze()), label="model", lw=lw, color=col_model, diff --git a/ocean_model_skill_assessor/plot/surface.py b/ocean_model_skill_assessor/plot/surface.py index 0711b09..8a957bb 100644 --- a/ocean_model_skill_assessor/plot/surface.py +++ b/ocean_model_skill_assessor/plot/surface.py @@ -87,7 +87,7 @@ def plot( if isinstance(obs, xr.Dataset): obs = obs.to_dataframe() if isinstance(model, xr.Dataset): - model = model.to_dataframe() + model = model.to_dataframe().reset_index() # using .values on obs prevents name clashes for time and depth model["diff"] = obs.cf[zname].values - model.cf[zname] # want obs and data as Datasets @@ -101,6 +101,7 @@ def plot( model = model.to_xarray() # using .values on obs prevents name clashes for time and depth model["diff"] = obs.cf[zname].values - model.cf[zname] + # model["diff"] = obs.cf[zname].values - model.cf[zname] model["diff"].attrs = {} else: raise ValueError("`kind` should be scatter or pcolormesh.") @@ -190,6 +191,7 @@ def plot( axes[1].set_title("Model", fontsize=fs_title) axes[1].set_xlabel(xlabel, fontsize=fs) axes[1].set_ylabel("") + axes[1].set_xlim(axes[0].get_xlim()) axes[1].set_ylim(axes[0].get_ylim()) # save space by not relabeling y axis axes[1].set_yticklabels("") @@ -215,6 +217,7 @@ def plot( axes[2].set_title("Obs - Model", fontsize=fs_title) axes[2].set_xlabel(xlabel, fontsize=fs) axes[2].set_ylabel("") + axes[2].set_xlim(axes[0].get_xlim()) axes[2].set_ylim(axes[0].get_ylim()) axes[2].set_ylim(obs.cf[yname].min(), obs.cf[yname].max()) axes[2].set_yticklabels("") diff --git a/ocean_model_skill_assessor/utils.py b/ocean_model_skill_assessor/utils.py index 71a4998..70f29a1 100644 --- a/ocean_model_skill_assessor/utils.py +++ b/ocean_model_skill_assessor/utils.py @@ -26,6 +26,55 @@ from .paths import Paths +def fix_dataset( + model_var: Union[xr.DataArray, xr.Dataset], ds: Union[xr.DataArray, xr.Dataset] +) -> Union[xr.DataArray, xr.Dataset]: + """Fill in info necessary to pass `check_dataset()` if possible. + + Right now it is only for converting horizontal indices to lon/lat but conceivably could do more in the future. Looks for lon/lat being 2D coords. 
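Stepping back from the diff for a moment: the time-alignment block added to `run()` above reindexes the observations onto the union of the data and model time indexes, interpolates in time, and only then reindexes onto the model times, so that interpolation after reindexing does not return all NaNs (pandas-dev/pandas#14297). A standalone sketch of that pattern with made-up values:

```python
# Standalone sketch of the reindex-interpolate-reindex pattern used to put the
# observations onto the model times; see
# https://github.com/pandas-dev/pandas/issues/14297 for why the union index is needed.
import pandas as pd

model_index = pd.date_range("2022-01-01 00:00", periods=5, freq="1h", name="time")
obs = pd.DataFrame(
    {"temp": [10.0, 10.3, 10.6, 10.9]},
    index=pd.DatetimeIndex(
        ["2022-01-01 00:10", "2022-01-01 01:10", "2022-01-01 02:10", "2022-01-01 03:10"],
        name="time",
    ),
)

# interpolate on the union of both indexes, then keep only the model times
ind = model_index.union(obs.index)
obs_on_model_times = (
    obs.reindex(ind).interpolate(method="time", limit=3).reindex(model_index)
)
print(obs_on_model_times)
```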
+ + Parameters + ---------- + model_var : Union[xr.DataArray,xr.Dataset] + xarray object that needs some more info filled in + ds : Union[xr.DataArray,xr.Dataset] + xarray object that has info that can be used to fill in model_var + + Returns + ------- + Union[xr.DataArray,xr.Dataset] + model_var with more information included, hopefully + """ + lonkey, latkey = ds.cf["longitude"].name, ds.cf["latitude"].name + X, Y = model_var.cf["X"], model_var.cf["Y"] + + if ( + "longitude" not in model_var.cf + and "X" in model_var.cf + and "longitude" in ds.cf + and ds.cf["longitude"].ndim == 2 + ): + # model_var[lonkey] = ds.cf["longitude"].isel({Y.name: Y, X.name: X}) + # model_var[lonkey].attrs = ds[lonkey].attrs + model_var = model_var.assign_coords( + {lonkey: ds.cf["longitude"].isel({Y.name: Y, X.name: X})} + ) + + if ( + "latitude" not in model_var.cf + and "Y" in model_var.cf + and "latitude" in ds.cf + and ds.cf["latitude"].ndim == 2 + ): + # model_var[latkey] = ds.cf["latitude"].isel({Y.name: Y, X.name: X}) + # model_var[latkey].attrs = ds[latkey].attrs + model_var = model_var.assign_coords( + {latkey: ds.cf["latitude"].isel({Y.name: Y, X.name: X})} + ) + + return model_var + + def check_dataset( ds: Union[xr.DataArray, xr.Dataset], is_model: bool = True, no_Z: bool = False ): @@ -102,13 +151,21 @@ def check_dataframe(dfd: pd.DataFrame, no_Z: bool) -> pd.DataFrame: return dfd -def check_catalog(cat: Catalog): +def check_catalog( + cat: Catalog, + source_names: Optional[list] = None, + skip_strings: Optional[list] = None, +): """Check a catalog for required keys. Parameters ---------- catalogs : Catalog Catalog object + source_names : list + Use these source_names instead of list(cat) if entered, for checking. + skip_strings : list of strings, optional + If provided, source_names in catalog will only be checked for goodness if they do not contain one of skip_strings. For example, if `skip_strings=["_base"]` then any source in the catalog whose name contains that string will be skipped. """ @@ -123,7 +180,19 @@ def check_catalog(cat: Catalog): "maptype", } - for source_name in list(cat): + skip_strings = skip_strings or [] + + if source_names is None: + source_names = list(cat) + + for skip_string in skip_strings: + source_names = [ + source_name + for source_name in source_names + if skip_string not in source_name + ] + + for source_name in source_names: missing_keys = set(required_keys) - set(cat[source_name].metadata.keys()) if len(missing_keys) > 0: @@ -136,8 +205,9 @@ def check_catalog(cat: Catalog): "profile", "trajectoryProfile", "timeSeriesProfile", + "grid", ] - future_featuretypes = ["trajectory", "grid"] + future_featuretypes = ["trajectory"] if cat[source_name].metadata["featuretype"] in future_featuretypes: raise KeyError( @@ -159,6 +229,7 @@ def open_catalogs( catalogs: Union[str, Catalog, Sequence], paths: Optional[Paths] = None, skip_check: bool = False, + skip_strings: Optional[list] = None, ) -> List[Catalog]: """Initialize catalog objects from inputs. @@ -170,6 +241,8 @@ def open_catalogs( Paths object for finding paths to use. Required if any catalog is a string referencing paths. skip_check : bool If True, do not check catalogs. Use this for testing as needed. Default is False. + skip_strings : list of strings, optional + If provided, source_names in catalog will only be checked for goodness if they do not contain one of skip_strings. For example, if `skip_strings=["_base"]` then any source in the catalog whose name contains that string will be skipped. 
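`check_catalog` (and `open_catalogs`/`make_catalog`, which pass the argument through) now accept `skip_strings` so helper sources such as `*_base` entries are not required to carry full metadata. A minimal sketch of the filtering logic on plain source-name strings (the helper and names below are illustrative):

```python
# Minimal sketch (helper and names are illustrative) of the skip_strings filtering
# that check_catalog now applies to source names before validating their metadata.
from typing import List, Optional


def filter_source_names(
    source_names: List[str], skip_strings: Optional[List[str]] = None
) -> List[str]:
    """Drop any source whose name contains one of skip_strings."""
    skip_strings = skip_strings or []
    return [
        name
        for name in source_names
        if not any(skip in name for skip in skip_strings)
    ]


names = ["station_a", "model_base", "station_b_all", "tides_tidecons"]
# run() passes skip_strings=["_base", "_all", "_tidecons"] to open_catalogs,
# which would leave only "station_a" to be checked here.
print(filter_source_names(names, skip_strings=["_base", "_all", "_tidecons"]))
```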
Returns ------- @@ -192,7 +265,7 @@ def open_catalogs( ) if not skip_check: - check_catalog(cat) + check_catalog(cat, skip_strings=skip_strings) cats.append(cat) return cats @@ -705,30 +778,38 @@ def kwargs_search_from_model( return kwargs_search -def calculate_anomaly(dd: Union[pd.Series, xr.DataArray], monthly_mean) -> pd.Series: +def calculate_anomaly( + dd_in: Union[pd.Series, pd.DataFrame, xr.DataArray], monthly_mean +) -> pd.Series: """Given monthly mean that is indexed by month of year, subtract it from time series to get anomaly. - Should work with both pd.Series and xr. DataArray. + Should work with both pd.Series/pd.DataFrame and xr. DataArray. Assume that variable in monthly_mean is the same as in the input time series. The way it works for DataArrays is by changing it to a DataFrame. Assumes this is a time series. - Returns either as a pd.Series. Is that a problem? + Returns dd as the type as DataFrame it is came in as Series and Dataset if it came in DataArray. It is pd.Series in the middle so this probably won't work well for datasets that are more complex than time series. """ - varname = dd.name + varname = dd_in.name varname_mean = f"{varname}_mean" varname_anomaly = f"{varname}_anomaly" # if monthly_mean is None: # monthly_mean = dd[varname].groupby(dd.cf["T"].dt.month).mean() - if isinstance(dd, xr.DataArray): - dd = dd.squeeze().to_dataframe() + # in_type = type(dd) - elif isinstance(dd, pd.Series): - dd = dd.to_frame() # this changes dd into a DataFrame + # if isinstance(dd, xr.DataArray): + # dd = dd.squeeze().to_dataframe() - dd["time"] = dd.index.values # save times + if isinstance(dd_in, pd.Series): + dd_in = dd_in.to_frame() # this changes dd into a DataFrame + + # import pdb; pdb.set_trace() + + dd = pd.DataFrame() + dd["time"] = dd_in.cf["T"].values + # dd["time"] = dd_in.index.values # save times dd = dd.set_index(dd["time"].dt.month) dd[varname_mean] = monthly_mean dd = dd.set_index(dd["time"].name) @@ -743,9 +824,25 @@ def calculate_anomaly(dd: Union[pd.Series, xr.DataArray], monthly_mean) -> pd.Se dd.loc[inan, varname_mean] = pd.NA dd[varname_mean] = dd[varname_mean].interpolate() - dd[varname_anomaly] = dd[varname] - dd[varname_mean] - - return dd[varname_anomaly] + dd[varname_anomaly] = dd_in.squeeze() - dd[varname_mean] + + # return in original container + if isinstance(dd_in, xr.DataArray): + dd_out = xr.DataArray( + coords={dd_in.cf["T"].name: dd.index.values}, + data=dd[varname_anomaly].values, + ).broadcast_like(dd_in) + if len(dd_in.coords) > len(dd_out.coords): + coordstoadd = list(set(dd_in.coords) - set(dd_out.coords)) + for coord in coordstoadd: + dd_out[coord] = dd_in[coord] + dd_out.attrs = dd_in.attrs + dd_out.name = dd_in.name + + elif isinstance(dd_in, (pd.Series, pd.DataFrame)): + dd_out = dd[varname_anomaly] + + return dd_out def calculate_distance(lons, lats): diff --git a/ocean_model_skill_assessor/vocab/vocab_labels.json b/ocean_model_skill_assessor/vocab/vocab_labels.json index c2540a7..624dfef 100644 --- a/ocean_model_skill_assessor/vocab/vocab_labels.json +++ b/ocean_model_skill_assessor/vocab/vocab_labels.json @@ -4,6 +4,9 @@ "u": "x-axis velocity [m/s]", "v": "y-axis velocity [m/s]", "w": "z velocity [m/s]", +"along": "Along-channel velocity [m/s]", +"across": "Across-channel velocity [m/s]", +"speed": "Horizontal speed [m/s]", "east": "Eastward velocity [m/s]", "north": "Northward velicity [m/s]", "water_dir": "Sea water direction [degrees]", diff --git a/tests/test_datasets.py b/tests/test_datasets.py index f863a51..7e7e022 
100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -6,6 +6,7 @@
 import cf_pandas as cfp
 import cf_xarray as cfx
+import numpy as np
 import pandas as pd
 import pytest
 import xarray as xr
@@ -247,7 +248,10 @@ def check_output(cat, featuretype, key_variable, project_cache, no_Z):
         )
         dsexpected = xr.open_dataset(base_dir / rel_path)
         dsactual = xr.open_dataset(project_cache / "tests" / rel_path)
-        assert dsexpected.equals(dsactual)
+        for var in dsexpected.coords:
+            assert dsexpected[var].equals(dsactual[var])
+        for var in dsexpected.data_vars:
+            assert np.allclose(dsexpected[var], dsactual[var], equal_nan=True)

     # compare saved stats
     rel_path = pathlib.Path("out", f"{cat.name}_{featuretype}_{key_variable}.yaml")
@@ -257,9 +261,12 @@ def check_output(cat, featuretype, key_variable, project_cache, no_Z):
         statsactual = yaml.safe_load(fp)
     for key in statsexpected.keys():
         try:
-            TestCase().assertAlmostEqual(
-                statsexpected[key]["value"], statsactual[key]["value"], places=5
-            )
+            if isinstance(statsexpected[key]["value"], list):
+                assert np.allclose(statsexpected[key]["value"], statsactual[key]["value"])
+            else:
+                TestCase().assertAlmostEqual(
+                    statsexpected[key]["value"], statsactual[key]["value"], places=5
+                )
         except AssertionError as msg:
             print(msg)

@@ -292,7 +299,11 @@ def check_output(cat, featuretype, key_variable, project_cache, no_Z):
         )
         dsexpected = xr.open_dataset(base_dir / rel_path)
         dsactual = xr.open_dataset(project_cache / "tests" / rel_path)
-        assert dsexpected.equals(dsactual)
+        # assert dsexpected.equals(dsactual)
+        for var in dsexpected.coords:
+            assert dsexpected[var].equals(dsactual[var])
+        for var in dsexpected.data_vars:
+            assert np.allclose(dsexpected[var], dsactual[var], equal_nan=True)


 def test_bad_catalog(dataset_filenames):
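The reworked `calculate_anomaly` in utils.py above accepts a `pd.Series`/`pd.DataFrame` or `xr.DataArray` plus a monthly-mean climatology indexed by month of year, and returns the anomaly in roughly the container it came in. A hedged usage sketch with a synthetic pandas Series, calling the function the way `run()` does (a named Series with a time index that cf-pandas can identify; the data values are made up):

```python
# Hedged usage sketch: the synthetic series and climatology below are illustrative,
# and the call mirrors how run() applies calculate_anomaly to a named time series.
import numpy as np
import pandas as pd

from ocean_model_skill_assessor.utils import calculate_anomaly

times = pd.date_range("2020-01-01", "2021-12-31", freq="1D", name="time")
temp = pd.Series(
    15 + 8 * np.sin(2 * np.pi * times.dayofyear / 365.25)
    + np.random.normal(0, 0.5, times.size),
    index=times,
    name="temp",
)

# monthly climatology indexed by month of year (1-12)
monthly_mean = temp.groupby(temp.index.month).mean()

anomaly = calculate_anomaly(temp, monthly_mean)
print(anomaly.head())
```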