Skip to content

Commit

Permalink
Merge pull request #1240 from ESMValGroup/julia_example
Browse files Browse the repository at this point in the history
Julia example
  • Loading branch information
mattiarighi authored Dec 19, 2019
2 parents b382005 + 6d99fec commit 82788e0
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 2 deletions.
14 changes: 12 additions & 2 deletions doc/sphinx/source/esmvaldiag/new_diagnostic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,22 @@ First create a recipe in esmvaltool/recipes to define the input data your analys
and optionally preprocessing and other settings. Also create a script in the esmvaltool/diag_scripts directory
and make sure it is referenced from your recipe. The easiest way to do this is probably to copy the example recipe
and diagnostic script and adjust those to your needs.
A good example recipe is esmvaltool/recipes/examples/recipe_python.yml
and a good example diagnostic is esmvaltool/diag_scripts/examples/diagnostic.py.

If you have no preferred programming language yet, Python 3 is highly recommended, because it is the best supported.
However, NCL, R, and Julia scripts are also supported.

Good example recipes for the different languages are:

* python: esmvaltool/recipes/examples/recipe_python.yml
* R: esmvaltool/recipes/examples/recipe_r.yml
* julia: esmvaltool/recipes/examples/recipe_julia.yml
* ncl: esmvaltool/recipes/examples/recipe_ncl.yml

Good example diagnostics are:

* python: esmvaltool/diag_scripts/examples/diagnostic.py
* R: esmvaltool/diag_scripts/examples/diagnostic.R
* julia: esmvaltool/diag_scripts/examples/diagnostic.jl
* ncl: esmvaltool/diag_scripts/examples/diagnostic.ncl

Unfortunately not much documentation is available at this stage,
so have a look at the other recipes and diagnostics for further inspiration.

Expand Down
128 changes: 128 additions & 0 deletions esmvaltool/diag_scripts/examples/diagnostic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# #############################################################################
# diagnostic.jl
# Authors: J. von Hardenberg (ISAC-CNR, Italy)
# #############################################################################
# Description
# Example of ESMValTool diagnostic written in Julia
#
# Modification history
# 20190807-vonhardenberg_jost written for v2.0
# 20191117-vonhardenberg_jost added more realistic writing of file and plot
# ############################################################################

import YAML
import JSON
using NetCDF
# Used to write output NetCDF file with original attributes
using RainFARM
using Statistics

using PyPlot
# Avoid plotting to screen
pygui(false)

# Provides the plotmap() function
include(joinpath(ENV["diag_scripts"], "shared/external.jl"))

"""
    provenance_record(infile)

Return the provenance dictionary for a product derived from `infile`.

`infile` is stored unchanged under the `"ancestors"` key; all other entries
(authors, references, projects, caption, statistics, realms, themes, domains)
are fixed values describing this example diagnostic.
"""
function provenance_record(infile)
    # Values are heterogeneous (strings and string vectors), and callers add
    # further string entries afterwards, hence the `Any` value type.
    record = Dict{String,Any}()
    record["ancestors"] = infile
    record["authors"] = ["vonhardenberg_jost", "arnone_enrico"]
    record["references"] = ["zhang11wcc"]
    record["projects"] = ["crescendo", "c3s-magic"]
    record["caption"] = "Example diagnostic in Julia"
    record["statistics"] = ["other"]
    record["realms"] = ["atmos"]
    record["themes"] = ["phys"]
    record["domains"] = ["global"]
    return record
end

"""
    compute_diagnostic(metadata, varname, diag_base, parameter, work_dir, plot_dir)

Run the example diagnostic on every input file listed in `metadata`.

For each `(infile, value)` entry the variable `varname` is read from `infile`,
averaged along its third dimension and offset by `parameter`. The result is
written to a NetCDF file in `work_dir` and plotted to a PNG file in `plot_dir`.

# Arguments
- `metadata`: mapping from input file name to a dictionary providing the keys
  `"dataset"`, `"reference_dataset"`, `"start_year"`, `"end_year"`, `"exp"`
  and `"ensemble"`
- `varname`: name of the variable to read from each input file
- `diag_base`: label used as prefix of the log messages
- `parameter`: value added to the averaged field
- `work_dir`: directory receiving the NetCDF output
- `plot_dir`: directory receiving the plots

# Returns
A dictionary mapping each output NetCDF file name to its provenance record
(including a `"plot_file"` entry).
"""
function compute_diagnostic(metadata, varname, diag_base, parameter,
                            work_dir, plot_dir)
    provenance = Dict()
    for (infile, value) in metadata
        dataset = value["dataset"]
        reference_dataset = value["reference_dataset"]
        start_year = value["start_year"]
        end_year = value["end_year"]
        exp = value["exp"]
        ensemble = value["ensemble"]
        println(diag_base, ": working on file ", infile)
        println(diag_base, ": calling diagnostic with following parameters")
        # The separator before `parameter` was missing, which glued it to
        # the ensemble name in the log output
        println(dataset, " ", reference_dataset, " ", start_year, " ",
                end_year, " ", exp, " ", ensemble, " ", parameter)
        # Call here actual diagnostic
        println(diag_base, ": I am your Julia diagnostic")

        # Read the variable, lon and lat
        var = ncread(infile, varname)
        lon = ncread(infile, "lon")
        lat = ncread(infile, "lat")

        units = ncgetatt(infile, varname, "units")

        # Compute time average and add parameter
        varm = mean(var, dims = 3) .+ parameter

        # Common stem of the output and plot file names
        stem = string(varname, "_", dataset, "_", exp, "_", ensemble, "_",
                      start_year, "-", end_year, "_timmean")

        # Output filename
        outfile = string(work_dir, "/", stem, ".nc")

        # Use the RainFARM function write_netcdf2d to write variable to
        # output file copying original attributes from infile
        write_netcdf2d(outfile, varm, lon, lat, varname, infile)

        # Create provenance record for the output file
        xprov = provenance_record(infile)

        # Plot the field
        plotfile = string(plot_dir, "/", stem, ".png")
        title = string("Mean ", varname, " ", dataset, " ", exp, " ", ensemble,
                       " ", start_year, "-", end_year)
        # Plot the averaged field that was written to `outfile`. Passing the
        # raw 3D `var` would make plotmap draw only its first frame.
        plotmap(lon, lat, varm, title = title, proj = "robinson",
                clabel = units)
        savefig(plotfile)
        # plotmap draws into the current figure; clear it so that axes and
        # colorbars do not pile up in the plots of subsequent datasets
        clf()
        xprov["plot_file"] = plotfile
        provenance[outfile] = xprov
    end
    return provenance
end

"""
    main(settings)

Entry point of the diagnostic.

Loads the metadata file named by the first element of
`settings["input_files"]`, creates the work, run and plot directories, changes
into the run directory, runs `compute_diagnostic` with the example parameter
`settings["parameter1"]`, and writes the resulting provenance as JSON to
`diagnostic_provenance.yml` in the run directory.
"""
function main(settings)
    metadata = YAML.load_file(settings["input_files"][1])

    # Variable name and diagnostic label are taken from the first entry
    input_names = collect(keys(metadata))
    first_entry = metadata[input_names[1]]
    varname = first_entry["short_name"]
    diag_base = first_entry["diagnostic"]

    println(diag_base, ": starting routine")
    println(diag_base, ": creating work and plot directories")
    work_dir = settings["work_dir"]
    plot_dir = settings["plot_dir"]
    run_dir = settings["run_dir"]
    foreach(mkpath, (work_dir, run_dir, plot_dir))
    cd(run_dir)

    # Reading an example parameter from the settings
    parameter = settings["parameter1"]

    # Compute the main diagnostic
    provenance = compute_diagnostic(metadata, varname, diag_base, parameter,
                                    work_dir, plot_dir)

    # Write the provenance file (JSON, indented by 4 spaces)
    provenance_file = "$(run_dir)/diagnostic_provenance.yml"
    open(io -> JSON.print(io, provenance, 4), provenance_file, "w")
end

# The settings file is passed as first command line argument
settings = YAML.load_file(ARGS[1])
main(settings)
201 changes: 201 additions & 0 deletions esmvaltool/diag_scripts/shared/external.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using PyPlot, PyCall
using NetCDF

"""
create_yaml(filename, dict)
Write a dictionary to a YAML file.
Expand Down Expand Up @@ -30,3 +33,201 @@ function print_yaml(os::IOStream, obj::Array{String}, indent::String)
println(os, indent, "- ", obj[i])
end
end

"""
plotmap(fname, var; ...)
Easy plotting of gridded datasets on a global map.
Plots variable `var` from a netcdf file `fname`.
Optional arguments are available to control the details of the plot.
# Arguments
- 'fname::String': netcdf filename containing the data to plot
- 'var::String': name of the variable to plot
Optional:
- 'lon::String': name of the lon variable ("lon")
- 'lat::String': name of the lat variable ("lat")
- 'lonb::String': name of the lon bounds variable ("lon_bnds")
- 'latb::String': name of the lat bounds variable ("lat_bnds")
- 'title': title string ("")
- 'cstep': divisions of the colorbar axis ([])
- 'cmap': colormap ("RdBu_r")
- 'proj': projection. One of ["platecarree", "robinson", "mollweide"]. Defaults to "platecarree".
- 'cpad': padding (shift) of the colorbar (0.08)
- 'sub': matplotlib subplot option (e.g. "221" for the first panel of 2x2 subplots) ("111")
- 'clabel': label of the colorbar (defaults to the units string read from the netcdf file)
- 'cdir': direction of the colorbar. One of ["horizontal", "vertical"]. ("horizontal")
- 'cscale': scaling of the colorbar (0.65)
- 'cfs': colorbar ticks font size (12)
- 'lfs': colorbar label font size (12)
- 'tfs': title font size (14)
- 'tpad': padding (shift) of the title string
- 'tweight': weight of title font. One of ["normal", "bold", "heavy", "light", "ultrabold", "ultralight"]. ("normal")
- 'grid': grid spacing (defaults to [60,30]). set to empty [] to remove gridlines.
- 'region::NTuple{4,Int64}': region to plot in format (lon1, lon2, lat1, lat2). Defaults to global.
- 'style': one of ["pcolormesh" "contourf"]. Defaults to "pcolormesh".
- 'levels': contour plot levels. Can be an array or a number of levels (auto)
- 'extend': plot levels outside `levels` range. One of ["neither", "both", "min", "max"]. Default: "neither".
Author: Jost von Hardenberg, 2019
"""
function plotmap(fname::String, var::String; lon="lon", lat="lat",
                 lonb="lon_bnds", latb="lat_bnds", title="", cstep=[],
                 cmap="RdBu_r", proj="", cpad=0.08, tpad=24, sub=111,
                 clabel="NONE", cdir="horizontal", cscale=0.65, tfs=14,
                 cfs=12, lfs=12, tweight="normal", grid=[60,30], region=(),
                 style="pcolormesh", levels=0, extend="neither")

    # Build the cell edges from the bounds variable `bnds_name`. If reading
    # or indexing the bounds fails for any reason, fall back to the cell
    # centres stored in `centres_name`.
    function edges_or_centres(bnds_name, centres_name)
        try
            bnds = ncread(fname, bnds_name)
            return vcat(bnds[1, :], bnds[2, end])
        catch
            return ncread(fname, centres_name)
        end
    end

    # pcolormesh needs cell boundaries, other styles use the cell centres
    if style == "pcolormesh"
        lonv = edges_or_centres(lonb, lon)
        latv = edges_or_centres(latb, lat)
    else
        lonv = ncread(fname, lon)
        latv = ncread(fname, lat)
    end

    data = ncread(fname, var)
    units = ncgetatt(fname, var, "units")

    # "NONE" is the sentinel meaning "label the colorbar with the units"
    label = clabel == "NONE" ? units : clabel

    # Delegate the actual drawing to the array-based method
    return plotmap(lonv, latv, data; title=title, cstep=cstep, cmap=cmap,
                   proj=proj, cpad=cpad, tpad=tpad, sub=sub, clabel=label,
                   cdir=cdir, cscale=cscale, tfs=tfs, cfs=cfs, lfs=lfs,
                   tweight=tweight, grid=grid, region=region, style=style,
                   levels=levels, extend=extend)
end

"""
plotmap(lon, lat, data; ...)
Easy plotting of gridded datasets on a global map.
Plots data in 2D array `data` with longitudes `lon` and latitudes `lat`.
Optional arguments are available to control the details of the plot.
# Arguments
- 'data::Array{Float32,2}': data to plot. If a 3D array is passed, only the first frame is plotted: `data[:,:,1]`
- 'lon::Array{Float64,1}' : longitudes
- 'lat::Array{Float64,1}' : latitudes
Optional:
- 'title': title string ("")
- 'cstep': divisions of the colorbar axis ([])
- 'cmap': colormap ("RdBu_r")
- 'proj': projection. One of ["platecarree", "robinson", "mollweide"]. Defaults to "platecarree".
- 'cpad': padding (shift) of the colorbar (0.08)
- 'sub': matplotlib subplot option (e.g. "221" for the first panel of 2x2 subplots) ("111")
- 'clabel': label of the colorbar (defaults to an empty string "")
- 'cdir': direction of the colorbar. One of ["horizontal", "vertical"]. ("horizontal")
- 'cscale': scaling of the colorbar (0.65)
- 'cfs': colorbar ticks font size (12)
- 'lfs': colorbar label font size (12)
- 'tfs': title font size (14)
- 'tpad': padding (shift) of the title string
- 'tweight': weight of title font. One of ["normal", "bold", "heavy", "light", "ultrabold", "ultralight"]. ("normal")
- 'grid': grid spacing (defaults to [60,30]). set to empty [] to remove gridlines.
- 'region::NTuple{4,Int64}': region to plot in format (lon1, lon2, lat1, lat2). Defaults to global.
- 'style': one of ["pcolormesh" "contourf"]. Defaults to "pcolormesh".
- 'levels': contour plot levels. Can be an array or a number of levels (auto)
- 'extend': plot levels outside `levels` range. One of ["neither", "both", "min", "max"]. Default: "neither".
Author: Jost von Hardenberg, 2019
"""
function plotmap(lon, lat, data; title="", cstep=[], cmap="RdBu_r", proj="",
                 cpad=0.08, tpad=24, sub=111, clabel="", cdir="horizontal",
                 cscale=0.65, tfs=14, cfs=12, lfs=12, tweight="normal",
                 grid=[60,30], region=(), style="pcolormesh", levels=0,
                 extend="neither")

    dd = size(data)

    # Only the first frame of a 3D array is plotted
    if length(dd) == 3
        data = data[:, :, 1]
    end

    if style == "pcolormesh"
        # pcolormesh needs cell boundaries: if a coordinate has the same
        # length as one of the data dimensions it holds cell centres, so the
        # boundaries are reconstructed as midpoints between neighbours and
        # extrapolated by half a cell at both ends.
        if length(lon) in dd
            lonb = zeros(2, length(lon))
            lonb[1, 2:end] = 0.5 * (lon[2:end] + lon[1:(end-1)])
            lonb[1, 1] = lon[1] - (lon[2] - lon[1]) * 0.5
            lonb[2, end] = lon[end] + (lon[end] - lon[end-1]) * 0.5
            lon = vcat(lonb[1, :], lonb[2, end])
        end
        if length(lat) in dd
            latb = zeros(2, length(lat))
            latb[1, 2:end] = 0.5 * (lat[2:end] + lat[1:(end-1)])
            latb[1, 1] = lat[1] - (lat[2] - lat[1]) * 0.5
            latb[2, end] = lat[end] + (lat[end] - lat[end-1]) * 0.5
            # Snap the outermost boundaries to the poles
            if latb[1, 1] > 89; latb[1, 1] = 90; end
            if latb[1, 1] < -89; latb[1, 1] = -90; end
            if latb[2, end] > 89; latb[2, end] = 90; end
            if latb[2, end] < -89; latb[2, end] = -90; end
            lat = vcat(latb[1, :], latb[2, end])
        end
        # Transpose when the first data dimension is the longitude
        # (boundaries have one element more than the data)
        if length(lon) == (dd[1] + 1)
            data = data'
        end
    else
        # Transpose when the first data dimension is the longitude
        if length(lon) == dd[1]
            data = data'
        end
    end

    ccrs = pyimport("cartopy.crs")
    cutil = pyimport("cartopy.util")

    # Select the map projection; any unrecognised value (including the
    # default "") gives PlateCarree. Tick labels are requested only for
    # PlateCarree.
    if proj == "robinson"
        proj = ccrs.Robinson()
        dlabels = false
    elseif proj == "mollweide"
        proj = ccrs.Mollweide()
        dlabels = false
    else
        proj = ccrs.PlateCarree()
        dlabels = true
    end

    ax = subplot(sub, projection=proj)
    if length(region) > 0
        ax.set_extent(region, crs=ccrs.PlateCarree())
    end
    ax.coastlines()

    # An empty `grid` disables the gridlines, as documented. Without this
    # guard, indexing grid[1] raised a BoundsError for grid=[].
    if !isempty(grid)
        # Gridline locations symmetric around zero with spacing grid[1]
        # (longitude) and grid[2] (latitude)
        xlocvec = vcat(-reverse(collect(grid[1]:grid[1]:180)),
                       collect(0:grid[1]:180))
        ylocvec = vcat(-reverse(collect(grid[2]:grid[2]:90)),
                       collect(0:grid[2]:90))

        if dlabels
            ax.gridlines(linewidth=1, color="gray", alpha=0.5, linestyle="--",
                         draw_labels=true, xlocs=xlocvec, ylocs=ylocvec)
        else
            ax.gridlines(linewidth=1, color="gray", alpha=0.5, linestyle="--",
                         xlocs=xlocvec, ylocs=ylocvec)
        end
    end

    if style == "contourf"
        # Add a cyclic point in longitude before contouring
        data_cyc, lon_cyc = cutil.add_cyclic_point(data, coord=lon)
        # levels == 0 means "let matplotlib choose the levels"
        if levels == 0
            contourf(lon_cyc, lat, data_cyc, transform=ccrs.PlateCarree(),
                     cmap=cmap, extend=extend)
        else
            contourf(lon_cyc, lat, data_cyc, transform=ccrs.PlateCarree(),
                     cmap=cmap, levels=levels, extend=extend)
        end
    else
        pcolormesh(lon, lat, data, transform=ccrs.PlateCarree(), cmap=cmap)
    end

    if length(cstep) > 0
        clim(cstep[1], cstep[end])
    end
    # Any non-empty title is drawn (a one-character title used to be dropped)
    if length(title) > 0
        PyPlot.title(title, pad=tpad, fontsize=tfs, weight=tweight)
    end
    cbar = colorbar(orientation=cdir, extend="both", pad=cpad, label=clabel,
                    shrink=cscale)
    cbar.set_label(label=clabel, size=lfs)
    cbar.ax.tick_params(labelsize=cfs)
    if length(cstep) > 0
        cbar.set_ticks(cstep)
    end
    tight_layout()

end

5 changes: 5 additions & 0 deletions esmvaltool/install/Julia/julia_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@ Compat
DataFrames
RainFARM
YAML
JSON
PyCall
PyPlot
NetCDF
Statistics
2 changes: 2 additions & 0 deletions esmvaltool/install/Julia/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ if VERSION >= v"0.7.0-DEV.2005"
using Pkg
end

ENV["PYTHON"] = string(ENV["CONDA_PREFIX"], "/bin/python")

@info "Installing the packages from" scriptDir * "/julia_requirements.txt"
pkgName=in
open(scriptDir * "/julia_requirements.txt") do f
Expand Down
Loading

0 comments on commit 82788e0

Please sign in to comment.