From 43fc9b21938937b256693785df66402f96afa63b Mon Sep 17 00:00:00 2001
From: Franziska Winterstein
Date: Thu, 6 Jun 2024 13:57:47 +0200
Subject: [PATCH] corrected _get_plot_kwargs, usable for timeseries and
 zonalmean

---
 esmvaltool/diag_scripts/lifetime/lifetime.py  | 169 ++++++++++++------
 .../diag_scripts/lifetime/lifetime_base.py    |  56 ++----
 2 files changed, 127 insertions(+), 98 deletions(-)

diff --git a/esmvaltool/diag_scripts/lifetime/lifetime.py b/esmvaltool/diag_scripts/lifetime/lifetime.py
index aac7229e28..617e8c6b52 100644
--- a/esmvaltool/diag_scripts/lifetime/lifetime.py
+++ b/esmvaltool/diag_scripts/lifetime/lifetime.py
@@ -36,7 +36,7 @@
     Facet used to label different datasets in plot titles and legends. For
     example, ``facet_used_for_labels: dataset`` will use dataset names in plot
     titles and legends; ``facet_used_for_labels: exp`` will use experiments in
-    plot titles and legends. In addition, ``facet_used_for_labels`` is used tosa
+    plot titles and legends. In addition, ``facet_used_for_labels`` is used to
     select the correct ``plot_kwargs`` for the different datasets (see
     configuration options for the different plot types below).
 figure_kwargs: dict, optional
@@ -285,7 +285,8 @@
     calculate_gridmassdry,
     calculate_lifetime,
     calculate_reaction_rate,
-    calculate_rho
+    calculate_rho,
+    climatological_tropopause
 )
 from esmvaltool.diag_scripts.shared import (
     ProvenanceLogger,
@@ -326,15 +327,16 @@ def __init__(self, config):
         # Load input data
         self.input_data_dataset = self._calculate_coefficients()
 
-        # name
-        self.names = {'short_name':
+        # base info
+        self.info = {'short_name':
                      f"tau_{self._get_name('reactant').upper()}"}
         oxidants = [ox.upper() for ox in self._get_name('oxidant')]
-        self.names['long_name'] = ("Lifetime of"
+        self.info['long_name'] = ("Lifetime of"
                                    f" {self._get_name('reactant').upper()}"
                                    " with respect to"
                                    f" {', '.join(oxidants)}")
         self.units = self.cfg['units']
+        self.info['units'] = self.cfg['units']
 
         # Check given plot types and set default settings for them
         self.supported_plot_types = [
@@ -359,6 +361,7 @@ def __init__(self, config):
             self.plots[plot_type].setdefault('legend_kwargs', {})
             self.plots[plot_type].setdefault('plot_kwargs', {})
             self.plots[plot_type].setdefault('pyplot_kwargs', {})
+            self.plots[plot_type].setdefault('by_timestep', False)
 
             if plot_type == 'annual_cycle':
                 self.plots[plot_type].setdefault('gridline_kwargs', {})
@@ -537,11 +542,11 @@ def _calculate_coefficients(self):
         # loops over different variables ch4, oh, ta etc.
         input_data_dataset = {}
         for dataset in list_of_datasets:
-            input_data_dataset[dataset] = {}
+            input_data_dataset[dataset] = self._get_dataset_data(dataset)[0]
             input_data_dataset[dataset]['dataset_data'] = (
                 self._get_dataset_data(dataset))
             variables = {}
             for variable in input_data_dataset[dataset]['dataset_data']:
                 filename = variable['filename']
@@ -585,6 +592,34 @@ def _calculate_coefficients(self):
                     " for the present type of vertical coordinate."
                 )
+            # defaults; overwritten below if a TROP/STRA region is requested
+            use_z_coord = 'air_pressure'
+            tropopause = None
+
+            if not set(['TROP', 'STRA']).isdisjoint(self.cfg['regions']):
+
+                # calculate climatological tropopause pressure (tp_clim)
+                # but only if no tropopause is given by data
+                if (
+                        'ptp' not in variables
+                        and 'tp_i' not in variables):
+                    tropopause = climatological_tropopause(
+                        variables['ta'][:, 0, :, :])
+
+                # If z_coord is defined as:
+                # - air_pressure, use:
+                #   - ptp and air_pressure
+                #   - tp_clim and air_pressure
+                # - atmosphere_hybrid_sigma_pressure_coordinate, use:
+                #   - tp_i and model_level_number
+                #   - ptp and (derived) air_pressure
+                #   - tp_clim and (derived) air_pressure
+                if z_coord.name() == 'air_pressure':
+                    if 'ptp' in variables:
+                        tropopause = variables['ptp']
+                elif z_coord.name() == 'atmosphere_hybrid_sigma_pressure_coordinate':
+                    if 'tp_i' in variables:
+                        tropopause = variables['tp_i']
+                        use_z_coord = 'model_level_number'
+                    elif 'ptp' in variables:
+                        tropopause = variables['ptp']
+
             weight = self._define_weight(variables)
 
             if reaction.coords('atmosphere_hybrid_sigma_pressure_coordinate',
@@ -593,6 +628,8 @@ def _calculate_coefficients(self):
                 add_model_level(reaction)
 
             input_data_dataset[dataset]['z_coord'] = z_coord
+            input_data_dataset[dataset]['use_z_coord'] = use_z_coord
+            input_data_dataset[dataset]['tropopause'] = tropopause
             input_data_dataset[dataset]['variables'] = variables
             input_data_dataset[dataset]['reaction'] = reaction
             input_data_dataset[dataset]['weight'] = weight
@@ -732,7 +769,7 @@ def plot_zonalmean_with_ref(self, plot_func, region,
                                        region)
 
             # convert units
-            cube.convert_units(self.units)
+            cube.convert_units(self.info['units'])
 
         # Create single figure with multiple axes
         with mpl.rc_context(self._get_custom_mpl_rc_params(plot_type)):
@@ -823,27 +860,27 @@ def plot_zonalmean_with_ref(self, plot_func, region,
 
         return (plot_path, netcdf_paths)
 
-    def plot_zonalmean_without_ref(self, plot_func, region, dataset,
-                                   base_dataset):
+    def plot_zonalmean_without_ref(self, plot_func, region, dataset,
+                                   base_datasets, label):
         """Plot zonal mean profile for single dataset without reference."""
         plot_type = 'zonalmean'
         logger.info("Plotting zonal mean profile without reference dataset"
-                    " for '%s'",
-                    self._get_label(dataset))
+                    " for '%s'", label)
 
-        # Make sure that the data has the correct dimensions
+        # zonal mean lifetime is calculated for each time step
+        # (sum over longitude)
        cube = calculate_lifetime(dataset,
                                   plot_type,
                                   region)
+        # lifetime is averaged over time
+        cube = cube.collapsed(['time'], iris.analysis.MEAN)
 
         # convert units
-        cube.convert_units(self.units)
+        cube.convert_units(self.info['units'])
 
         # Create plot with desired settings
         with mpl.rc_context(self._get_custom_mpl_rc_params(plot_type)):
             fig = plt.figure(**self.cfg['figure_kwargs'])
             axes = fig.add_subplot()
-            plot_kwargs = self._get_plot_kwargs(plot_type, base_dataset)
+            plot_kwargs = self._get_plot_kwargs(plot_type, base_datasets)
             plot_kwargs['axes'] = axes
             plot_zonalmean = plot_func(cube, **plot_kwargs)
 
@@ -851,16 +888,17 @@ def plot_zonalmean_without_ref(self, plot_func, region, dataset,
             fontsize = self.plots[plot_type]['fontsize']
             colorbar = fig.colorbar(plot_zonalmean, ax=axes,
                                     **self._get_cbar_kwargs(plot_type))
-            colorbar.set_label(self._get_cbar_label(plot_type, dataset),
+            colorbar.set_label(self._get_cbar_label(
+                plot_type, self.info),
                                fontsize=fontsize)
             colorbar.ax.tick_params(labelsize=fontsize)
 
             # Customize plot
-            axes.set_title(self._get_label(dataset))
-            fig.suptitle(f"{dataset['long_name']} ({dataset['start_year']}-"
+            axes.set_title(label)
fig.suptitle(f"{self.info['long_name']} ({dataset['start_year']}-" f"{dataset['end_year']})") axes.set_xlabel('latitude [°N]') - z_coord = cube.coord(axis='Z') + z_coord = cube.coord(name_or_coord='air_pressure') axes.set_ylabel(f'{z_coord.long_name} [{z_coord.units}]') if self.plots[plot_type]['log_y']: axes.set_yscale('log') @@ -878,7 +916,7 @@ def plot_zonalmean_without_ref(self, plot_func, region, dataset, self._set_rasterized([axes]) # File paths - plot_path = self.get_plot_path(plot_type, base_dataset) + plot_path = self.get_plot_path(plot_type, dataset) netcdf_path = get_diagnostic_filename(Path(plot_path).stem, self.cfg) return (plot_path, {netcdf_path: cube}) @@ -949,7 +987,7 @@ def _get_multi_dataset_facets(datasets): def _get_reference_dataset(datasets, short_name): """Extract reference dataset.""" ref_datasets = [d for d in datasets if - d.get('reference_for_monitor_diags', False)] + datasets[d].get('reference_for_monitor_diags', False)] if len(ref_datasets) > 1: raise ValueError( f"Expected at most 1 reference dataset (with " @@ -975,15 +1013,37 @@ def create_timeseries_plot(self, region, input_data, base_datasets): # Plot all datasets in one single figure ancestors = [] cubes = {} + for label, dataset in input_data.items(): ancestors.extend(variable['filename'] for variable in dataset['dataset_data']) - cube = calculate_lifetime(dataset, - plot_type, - region) + # call by timestep will take longer, however is less memory intensive + if self.plots[plot_type]['by_timestep']: + slice_dataset = {} + slice_dataset['z_coord'] = dataset['z_coord'] + slice_dataset['use_z_coord'] = dataset['use_z_coord'] + cube_slices = iris.cube.CubeList() + for (reaction_slice, + weight_slice, + tp_slice) in zip(dataset['reaction'].slices_over('time'), + dataset['weight'].slices_over('time'), + dataset['tropopause'].slices_over('time')): + + slice_dataset['reaction'] = reaction_slice + slice_dataset['weight'] = weight_slice + slice_dataset['tropopause'] = tp_slice + + cube_slices.append(calculate_lifetime(slice_dataset, + plot_type, + region)) + cube = cube_slices.merge_cube() + else: + cube = calculate_lifetime(dataset, + plot_type, + region) # convert units - cube.convert_units(self.units) + cube.convert_units(self.info['units']) cubes[label] = cube self._check_cube_dimensions(cube, plot_type) @@ -1019,10 +1079,10 @@ def create_timeseries_plot(self, region, input_data, base_datasets): # Default plot appearance multi_dataset_facets = self._get_multi_dataset_facets( list(base_datasets.values())) - axes.set_title(f'{self.names["long_name"]} in region {region}') + axes.set_title(f'{self.info["long_name"]} in region {region}') axes.set_xlabel('Time') axes.set_ylabel(f"{chr(964)}({self._get_name('reactant').upper()})" - f" [{self.units}]") + f" [{self.info['units']}]") gridline_kwargs = self._get_gridline_kwargs(plot_type) if gridline_kwargs is not False: axes.grid(**gridline_kwargs) @@ -1044,9 +1104,9 @@ def create_timeseries_plot(self, region, input_data, base_datasets): # Save netCDF file netcdf_path = get_diagnostic_filename(Path(plot_path).stem, self.cfg) var_attrs = { - 'short_name': self.names['short_name'], - 'long_name': self.names['long_name'], - 'units': self.units + 'short_name': self.info['short_name'], + 'long_name': self.info['long_name'], + 'units': self.info['units'] } io.save_1d_data(cubes, netcdf_path, 'time', var_attrs) @@ -1088,7 +1148,7 @@ def create_annual_cycle_plot(self, region, input_data, base_datasets): plot_type, region) # convert units - cube.convert_units(self.units) + 
+            cube.convert_units(self.info['units'])
             cubes[label] = cube
 
             self._check_cube_dimensions(cube, plot_type)
@@ -1102,10 +1162,10 @@ def create_annual_cycle_plot(self, region, input_data, base_datasets):
             # Default plot appearance
             multi_dataset_facets = self._get_multi_dataset_facets(
                 list(base_datasets.values()))
-            axes.set_title(f'{self.names["long_name"]} in region {region}')
+            axes.set_title(f'{self.info["long_name"]} in region {region}')
             axes.set_xlabel('Month')
             axes.set_ylabel(f"$\tau$({self._get_name('reactant').upper()})"
-                            " [{self.units}]")
+                            f" [{self.info['units']}]")
             axes.set_xticks(range(1, 13), [str(m) for m in range(1, 13)])
             gridline_kwargs = self._get_gridline_kwargs(plot_type)
             if gridline_kwargs is not False:
@@ -1128,14 +1188,14 @@ def create_annual_cycle_plot(self, region, input_data, base_datasets):
         # Save netCDF file
         netcdf_path = get_diagnostic_filename(Path(plot_path).stem, self.cfg)
         var_attrs = {
-            'short_name': self.names['short_name'],
-            'long_name': self.names['long_name'],
-            'units': self.units
+            'short_name': self.info['short_name'],
+            'long_name': self.info['long_name'],
+            'units': self.info['units']
         }
         io.save_1d_data(cubes, netcdf_path, 'month_number', var_attrs)
 
         # Provenance tracking
-        caption = (f"Annual cycle of {self.names['long_name']} for "
+        caption = (f"Annual cycle of {self.info['long_name']} for "
                    f"various datasets.")
         provenance_record = {
             'ancestors': ancestors,
@@ -1172,18 +1232,16 @@ def create_zonalmean_plot(self, region, input_data,
         # Create a single plot for each dataset (incl. reference dataset if
         # given)
         ancestors = []
-        for _, dataset in input_data.items():
+        for label, dataset in input_data.items():
             if dataset == ref_dataset:
                 continue
             ancestors.extend(variable['filename']
                              for variable in dataset['dataset_data'])
-
             if ref_dataset is None:
                 (plot_path, netcdf_paths) = (
                     self.plot_zonalmean_without_ref(
                         plot_func, region,
-                        dataset,
-                        base_datasets['label'])
+                        dataset, base_datasets[label], label)
                 )
                 caption = (
                     f"Zonal mean profile of {dataset['long_name']} of dataset "
@@ -1256,7 +1314,7 @@ def create_1d_profile_plot(self, region, input_data, base_datasets):
                                       plot_type,
                                       region)
 
             # convert units
-            cube.convert_units(self.units)
+            cube.convert_units(self.info['units'])
             cubes[label] = cube
 
             self._check_cube_dimensions(cube, plot_type)
@@ -1271,9 +1329,9 @@ def create_1d_profile_plot(self, region, input_data, base_datasets):
             # Default plot appearance
             multi_dataset_facets = self._get_multi_dataset_facets(
                 list(base_datasets.values()))
-            axes.set_title(f'{self.names["long_name"]} in region {region}')
+            axes.set_title(f'{self.info["long_name"]} in region {region}')
             axes.set_xlabel(f"$\tau$({self._get_name('reactant').upper()})"
-                            f" [{self.units}]")
+                            f" [{self.info['units']}]")
             z_coord = cube.coord(axis='Z')
             axes.set_ylabel(f'{z_coord.long_name} [{z_coord.units}]')
 
@@ -1324,15 +1382,15 @@ def create_1d_profile_plot(self, region, input_data, base_datasets):
         # Save netCDF file
         netcdf_path = get_diagnostic_filename(Path(plot_path).stem, self.cfg)
         var_attrs = {
-            'short_name': self.names['short_name'],
-            'long_name': self.names['long_name'],
-            'units': self.units
+            'short_name': self.info['short_name'],
+            'long_name': self.info['long_name'],
+            'units': self.info['units']
         }
         io.save_1d_data(cubes, netcdf_path, z_coord.standard_name, var_attrs)
 
         # Provenance tracking
         caption = ("Vertical one-dimensional profile of "
-                   f"{self.names['long_name']}"
+                   f"{self.info['long_name']}"
                    " for various datasets.")
         provenance_record = {
             'ancestors': ancestors,
@@ -1347,16
 +1405,19 @@ def create_1d_profile_plot(self, region, input_data, base_datasets):
 
     def compute(self):
         """Plot preprocessed data."""
+        input_data = self.input_data_dataset
+        base_datasets = {label: dataset['dataset_data'][0]
+                         for label, dataset in input_data.items()}
+
+        # at the moment, regions only apply to 'TROP' and 'STRA'
         for region in self.cfg['regions']:
             logger.info("Plotting lifetime for region %s", region)
-            input_data = self.input_data_dataset
-            base_datasets = {label: dataset['dataset_data'][0]
-                             for label, dataset in input_data.items()}
             self.create_timeseries_plot(region, input_data, base_datasets)
             self.create_annual_cycle_plot(region, input_data, base_datasets)
-            self.create_zonalmean_plot(region, input_data,
-                                       base_datasets)
-            self.create_1d_profile_plot(region, input_data, base_datasets)
+
+            self.create_zonalmean_plot(region, input_data,
+                                       base_datasets)
+            self.create_1d_profile_plot(region, input_data, base_datasets)
 
 
 def main():
diff --git a/esmvaltool/diag_scripts/lifetime/lifetime_base.py b/esmvaltool/diag_scripts/lifetime/lifetime_base.py
index 26951269fc..e873ac0362 100644
--- a/esmvaltool/diag_scripts/lifetime/lifetime_base.py
+++ b/esmvaltool/diag_scripts/lifetime/lifetime_base.py
@@ -373,6 +373,7 @@ def dpres_plevel_1d(plev, pmin, pmax):
 
 def calculate_lifetime(dataset, plot_type, region):
     """Calculate the lifetime for the given plot_type and region."""
+    # extract region from weights and reaction
     reaction = extract_region(dataset, region, case='reaction')
     weight = extract_region(dataset, region, case='weight')
 
@@ -391,45 +392,16 @@ def extract_region(dataset, region, case='reaction'):
     """Return cube with everything outside region set to zero.
 
-    If z_coord is defined as:
-    - air_pressure, use:
-        - ptp and air_pressure
-        - tp_clim and air_pressure
-    - atmosphere_hybrid_sigma_pressure_coordinate, use:
-        - tp_i and model_level_number
-        - ptp and (derived) air_pressure
-        - tp_clim and (derived) air_pressure
 
     Current aware regions:
-    - TROP: troposphere (excl. tropopause), requires tropopause pressure
-    - STRA: stratosphere (incl. tropopause), requires tropopause pressure
+    - TROP: troposphere (excl. tropopause)
+    - STRA: stratosphere (incl. tropopause)
     """
     var = dataset[case]
-    z_coord = dataset['z_coord']
-
-    # calculate climatological tropopause pressure
-    # but only if tropopause is not given by data
-    if (
-            'ptp' not in dataset['variables']
-            and 'tp_i' not in dataset['variables']):
-        tp_clim = climatological_tropopause(var[:, 0, :, :])
+    use_z_coord = dataset['use_z_coord']
 
     # mask regions outside
     if region in ['TROP', 'STRA']:
-        use_z_coord = 'air_pressure'
-        if z_coord.name() == 'air_pressure':
-            if 'ptp' in dataset['variables']:
-                tropopause = dataset['variables']['ptp']
-            else:
-                tropopause = tp_clim
-        elif z_coord.name() == 'atmosphere_hybrid_sigma_pressure_coordinate':
-            if 'tp_i' in dataset['variables']:
-                tropopause = dataset['variables']['tp_i']
-                use_z_coord = 'model_level_number'
-            elif 'ptp' in dataset['variables']:
-                tropopause = dataset['variables']['ptp']
-            else:
-                tropopause = tp_clim
 
         z_4d = broadcast_to_shape(
             var.coord(use_z_coord).points,
             var.shape,
             var.coord_dims(use_z_coord),
         )
 
         tp_4d = broadcast_to_shape(
-            tropopause.data,
+            dataset['tropopause'].data,
             var.shape,
-            var.coord_dims('time')
-            + var.coord_dims('latitude')
-            + var.coord_dims('longitude'),
+            tuple(var.coord_dims(item)[0]
+                  for item in var.dim_coords
+                  if item != dataset['z_coord']),
         )
 
         if region == 'TROP':
@@ -498,7 +470,7 @@ def sum_up_to_plot_dimensions(var, plot_type):
     if plot_type == 'timeseries':
         cube = var.collapsed(['longitude', 'latitude', z_coord],
                              iris.analysis.SUM)
-    elif plot_type == 'zonal_mean_profile':
+    elif plot_type == 'zonalmean':
         cube = var.collapsed(['longitude'], iris.analysis.SUM)
     elif plot_type == '1d_profile':
         cube = var.collapsed(['longitude', 'latitude'], iris.analysis.SUM)
@@ -525,13 +497,9 @@ def calculate_reaction_rate(temp, reaction_type,
     # special reaction rate
     if coeff_b is not None:
-        reaction_rate = (da.multiply(coeff_a,
-                         iris.analysis.maths.exp(
-                             da.multiply(coeff_b,
-                                         iris.analysis.maths.log(
-                                             reaction_rate))
-                             - da.divide(coeff_er
-                                         ,reaction_rate))))
+        reaction_rate = coeff_a * iris.analysis.maths.exp(
+            coeff_b * iris.analysis.maths.log(reaction_rate)
+            - coeff_er / reaction_rate)
     else:
         # standard reaction rate (arrhenius)
         reaction_rate = coeff_a * iris.analysis.maths.exp(