Skip to content

Commit 818afc1

Browse files
committed
2-year minibatch w/ 3D fields
1 parent 5a056d2 commit 818afc1

File tree

7 files changed

+67
-52
lines changed

7 files changed

+67
-52
lines changed

experiments/calibration/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
3131

3232
[compat]
3333
ClimaCalibrate = "0.0.11"
34+
ClimaLand = "=0.15.9"
3435
ClimaTimeSteppers = "0.8.2"
3536
ClimaCore = "0.14.26"
36-
EnsembleKalmanProcesses = "2.0"
37+
EnsembleKalmanProcesses = "2.1.2"
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Generate and save experiment observations to disk
2-
using ClimaAnalysis, JLD2
3-
include("experiments/calibration/coarse_amip/observation_utils.jl")
2+
using ClimaAnalysis, JLD2, ClimaCoupler
3+
include(joinpath(pkgdir(ClimaCoupler),"experiments/calibration/coarse_amip/observation_utils.jl"))
44

55
const obs_dir = "/home/ext_nefrathe_caltech_edu/calibration_obs"
6-
const diagnostic_dir = "experiments/calibration/output/old/iteration_000/member_001/model_config/output_0000/clima_atmos/"
6+
const simdir = SimDir(joinpath(pkgdir(ClimaCoupler),"experiments/calibration/output/iteration_000/member_001/model_config/"))
77

8-
diagnostic_var2d = OutputVar(joinpath(diagnostic_dir, "rsdt_1M_average.nc"));
9-
pressure = OutputVar(joinpath(diagnostic_dir, "pfull_1M_average.nc"));
10-
diagnostic_var3d = OutputVar(joinpath(diagnostic_dir, "ta_1M_average.nc"));
8+
diagnostic_var2d = get_monthly_averages(simdir, "rsut")
9+
pressure = get_monthly_averages(simdir, "pfull")
10+
diagnostic_var3d = get_monthly_averages(simdir, "ta")
1111
diagnostic_var3d = ClimaAnalysis.Atmos.to_pressure_coordinates(diagnostic_var3d, pressure)
1212

1313
nt = get_all_output_vars(obs_dir, diagnostic_var2d, diagnostic_var3d)
1414
nyears = 18
1515
observation_vec = create_observation_vector(nt, nyears)
16-
JLD2.save_object("experiments/calibration/coarse_amip/observations.jld2", observation_vec)
16+
JLD2.save_object("experiments/calibration/coarse_amip/observations_3d.jld2", observation_vec)

experiments/calibration/coarse_amip/model_config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
FLOAT_TYPE: "Float32"
2-
use_itime: true
2+
# use_itime: true
33
albedo_model: "CouplerAlbedo"
44
atmos_config_file: "config/longrun_configs/amip_target_edonly.yml"
55
checkpoint_dt: "99999days"

experiments/calibration/coarse_amip/model_interface.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ function ClimaCalibrate.forward_model(iter, member)
1515
output_dir_root = config_dict["coupler_output_dir"]
1616
eki = JLD2.load_object(ClimaCalibrate.ekp_path(output_dir_root, iter))
1717
minibatch = EKP.get_current_minibatch(eki)
18-
@show minibatch
18+
@info "Current minibatch: $minibatch"
1919
config_dict["start_date"] = minibatch_to_start_date(minibatch)
2020

21-
spinup_time = 93days
21+
spinup_days = 92
2222
nyears = length(minibatch)
23-
t_end_days = spinup_time + 365 * nyears
23+
t_end_days = spinup_days + 365 * nyears
2424
config_dict["t_end"] = "$(t_end_days)days"
2525

2626
# Set member parameter file

experiments/calibration/coarse_amip/observation_map.jl

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ function ClimaCalibrate.observation_map(iteration)
1010
single_member_dims = length(EKP.get_obs(first(observation_vec))) * batch_size
1111
G_ensemble = Array{Float64}(undef, single_member_dims, ensemble_size)
1212
for m in 1:ensemble_size
13-
@info "Processing member $m"
1413
member_path = ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, m)
1514
simdir_path = joinpath(member_path, "model_config/output_active")
15+
@info "Processing member $m: $simdir_path"
1616
try
1717
G_ensemble[:, m] .= process_member_data(SimDir(simdir_path))
1818

@@ -45,7 +45,7 @@ function process_member_data(simdir::SimDir)
4545
ta = process_outputvar(simdir, "ta")
4646
hur = process_outputvar(simdir, "hur")
4747
hus = process_outputvar(simdir, "hus")
48-
48+
# Map over each year
4949
year_observations = map(1:4:length(rsut)) do year_start
5050
year_end = min(year_start + 3, length(rsut))
5151
yr_ind = year_start:year_end
@@ -62,7 +62,7 @@ function process_member_data(simdir::SimDir)
6262
hur_yr = downsample_and_vectorize(hur[yr_ind])
6363
hus_yr = downsample_and_vectorize(hus[yr_ind])
6464
# ql, qi
65-
vcat(net_rad_yr, rsut_yr, rlut_yr, cre_yr, pr_yr, ts_yr)#, ta, hur, hus)
65+
vcat(net_rad_yr, rsut_yr, rlut_yr, cre_yr, pr_yr, ts_yr, ta_yr, hur_yr, hus_yr)
6666
end
6767
return vcat(year_observations...)
6868
end
@@ -72,6 +72,16 @@ function downsample_and_vectorize(seasonal_avgs)
7272
return vcat(vec.(downsampled_seasonal_avg_arrays)...)
7373
end
7474

75+
# Process an outputvar into a vector of seasonal averages
76+
function process_outputvar(simdir, name)
77+
monthly_avgs = preprocess_monthly_averages(simdir, name)
78+
seasons = split_monthly_averages_into_seasons(monthly_avgs)
79+
# Ensure each season has three months
80+
@assert all(map(x -> length(times(x)) == 3, seasons))
81+
seasonal_avgs = average_time.(seasons)
82+
return seasonal_avgs
83+
end
84+
7585
# Preprocess monthly averages to the right dimensions and dates, remove NaNs
7686
function preprocess_monthly_averages(simdir, name)
7787
monthly_avgs = get_monthly_averages(simdir, name)
@@ -81,26 +91,23 @@ function preprocess_monthly_averages(simdir, name)
8191
monthly_avgs = ClimaAnalysis.Atmos.to_pressure_coordinates(monthly_avgs, pressure)
8292
monthly_avgs = limit_pressure_dim_to_era5_range(monthly_avgs)
8393
end
84-
# TODO: Ask ollie how to replace nans in a way that makes sense
85-
monthly_avgs = ClimaAnalysis.replace(monthly_avgs, missing => 0.0, NaN => 0.0)
94+
# TODO: Replace NaNs with global mean
95+
monthly_avgs = ClimaAnalysis.replace(monthly_avgs, NaN => 0.0)
8696
monthly_avgs = ClimaAnalysis.shift_to_start_of_previous_month(monthly_avgs)
8797
# Remove spinup time
8898
monthly_avgs = window(monthly_avgs, "time"; left = spinup_time)
8999
return monthly_avgs
90100
end
91101

92-
# Process an outputvar into a vector of seasonal averages
93-
function process_outputvar(simdir, name)
94-
monthly_avgs = preprocess_monthly_averages(simdir, name)
95-
seasons = split_by_season_across_time(monthly_avgs)
96-
# Ensure each season has three months
97-
if !all(map(x -> length(times(x)) == 3, seasons))
98-
@info "Uneven months per season, rebalancing..."
99-
rebalance_months_per_season!(seasons)
100-
end
101-
seasonal_avgs = average_time.(seasons)
102+
function split_monthly_averages_into_seasons(monthly_avgs)
103+
all_times = times(monthly_avgs)
104+
@assert length(all_times) % 3 == 0
102105

103-
return seasonal_avgs
106+
# Window over 3 months at a time to create seasonal outputvars
107+
split_by_seasons = map(all_times[1:3:end]) do t
108+
window(monthly_avgs, "time"; left = t, right = t + 2months)
109+
end
110+
return split_by_seasons
104111
end
105112

106113
function rebalance_months_per_season!(seasons)

experiments/calibration/coarse_amip/observation_utils.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,19 @@ function get_all_output_vars(obs_dir, diagnostic_var2d, diagnostic_var3d)
7575
# specific humidity
7676
hus = resample(era5_outputvar(joinpath(obs_dir, "era5_monthly_avg_pressure_level_q.nc")))
7777

78-
# Cloud specific liquid water content
79-
ql = era5_outputvar(joinpath(obs_dir, "era5_specific_cloud_liquid_water_content_1deg.nc"))
80-
# Cloud specific ice water content
81-
qi = era5_outputvar(joinpath(obs_dir, "era5_specific_cloud_ice_water_content_1deg.nc"))
82-
foreach((ql, qi)) do var
83-
# Convert from hPa to Pa in-place so we don't create more huge OutputVars
84-
@assert var.dim_attributes[pressure_name(var)]["units"] == "hPa"
85-
var.dims[pressure_name(var)] .*= 100.0
86-
set_dim_units!(var, pressure_name(var), "Pa")
87-
end
78+
# # Cloud specific liquid water content
79+
# ql = era5_outputvar(joinpath(obs_dir, "era5_specific_cloud_liquid_water_content_1deg.nc"))
80+
# # Cloud specific ice water content
81+
# qi = era5_outputvar(joinpath(obs_dir, "era5_specific_cloud_ice_water_content_1deg.nc"))
82+
# foreach((ql, qi)) do var
83+
# # Convert from hPa to Pa in-place so we don't create more huge OutputVars
84+
# @assert var.dim_attributes[pressure_name(var)]["units"] == "hPa"
85+
# var.dims[pressure_name(var)] .*= 100.0
86+
# set_dim_units!(var, pressure_name(var), "Pa")
87+
# end
8888
# TODO: determine where time is spent here
89-
ql = resample(reverse_dim(reverse_dim(ql, latitude_name(ql)), pressure_name(ql)))
90-
qi = resample(reverse_dim(reverse_dim(qi, latitude_name(qi)), pressure_name(qi)))
89+
# ql = resample(reverse_dim(reverse_dim(ql, latitude_name(ql)), pressure_name(ql)))
90+
# qi = resample(reverse_dim(reverse_dim(qi, latitude_name(qi)), pressure_name(qi)))
9191

9292
return (; rlut, rsut, rsutcs, rlutcs, pr, net_rad, cre, shf, ts, ta, hur, hus)#, ql, qi)
9393
end
@@ -111,7 +111,10 @@ end
111111
#####
112112
# Processing to create EKP.ObservationSeries
113113
#####
114-
114+
function to_datetime(var::OutputVar)
115+
start_date = DateTime(var.attributes["start_date"], dateformat"yyyy-mm-ddTHH:MM:SS")
116+
return [start_date + Second(t) for t in times(var)]
117+
end
115118
to_datetime(start_date, time) = DateTime(start_date) + Second(time)
116119
to_datetime(time) = DateTime(start_date) + Second(time)
117120

@@ -140,14 +143,17 @@ function get_yearly_averages(var)
140143
return year_averaged_matrices
141144
end
142145

143-
# Given an outputvar, compute the standard deviation at each point for each season.
146+
# TODO: compute seasonal stdev properly
144147
function get_seasonal_stdev(output_var)
145148
all_seasonal_averages = get_seasonal_averages(output_var)
146149
all_seasonal_averages = downsample.(all_seasonal_averages, 3)
147-
seasonal_average_matrix = cat(all_seasonal_averages...; dims = 3)
148-
interannual_stdev = std(seasonal_average_matrix, dims = 3)
149-
# TODO: Add spatial variance
150-
return dropdims(interannual_stdev; dims = 3)
150+
151+
# Determine dimensionality of the data
152+
dims = ndims(all_seasonal_averages[1]) + 1
153+
154+
seasonal_average_matrix = cat(all_seasonal_averages...; dims)
155+
interannual_stdev = std(seasonal_average_matrix; dims)
156+
return dropdims(interannual_stdev; dims)
151157
end
152158

153159
# Given an outputvar, compute the covariance for each season.
@@ -226,15 +232,15 @@ function create_observation_vector(nt, yrs = 19)
226232
# shf_obs = make_single_year_of_seasonal_observations(shf, yr)
227233
ts_obs = make_single_year_of_seasonal_observations(ts, yr)
228234

229-
# ta_obs = make_single_year_of_seasonal_observations(ta, yr)
230-
# hur_obs = make_single_year_of_seasonal_observations(hur, yr)
231-
# hus_obs = make_single_year_of_seasonal_observations(hus, yr)
232-
EKP.combine_observations([net_rad_obs, rsut_obs, rlut_obs, cre_obs, pr_obs, ts_obs])#, ta_obs, hur_obs, hus_obs])
235+
ta_obs = make_single_year_of_seasonal_observations(ta, yr)
236+
hur_obs = make_single_year_of_seasonal_observations(hur, yr)
237+
hus_obs = make_single_year_of_seasonal_observations(hus, yr)
238+
EKP.combine_observations([net_rad_obs, rsut_obs, rlut_obs, cre_obs, pr_obs, ts_obs, ta_obs, hur_obs, hus_obs])
233239
end
234240

235241
return all_observations # NOT an EKP.ObservationSeries
236242
end
237-
# TODO: Ask kevin to implement in
243+
# TODO: Ask kevin to implement in ClimaAnalysis
238244
downsample(var::ClimaAnalysis.OutputVar, n) = downsample(var.data, n)
239245

240246
function downsample(arr::AbstractArray, n)

experiments/calibration/coarse_amip/run_calibration.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ priors = [
2929
constrained_gaussian("precipitation_timescale", 600, 400, 0, 1200),
3030
]
3131
prior = combine_distributions(priors)
32-
observation_path = joinpath(experiment_dir, "observations.jld2")
32+
observation_path = joinpath(experiment_dir, "observations_3d.jld2")
3333
observation_vec = JLD2.load_object(observation_path)
3434

3535
# Create the EKP.ObservationSeries
@@ -49,6 +49,7 @@ eki = EKP.EnsembleKalmanProcess(
4949
EKP.construct_initial_ensemble(prior, ensemble_size),
5050
observation_series,
5151
EKP.TransformInversion(),
52+
verbose=true
5253
)
5354

5455
eki = CAL.calibrate(CAL.WorkerBackend, eki, ensemble_size, n_iterations, prior, output_dir)

0 commit comments

Comments
 (0)