Skip to content

Commit eb93372

Browse files
authored
Let grdcut() accept xarray.DataArray as input (#541)
1 parent 87e7d42 commit eb93372

File tree

2 files changed

+30
-22
lines changed

2 files changed

+30
-22
lines changed

pygmt/gridops.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def grdcut(grid, **kwargs):
4747
4848
Parameters
4949
----------
50-
grid : str
51-
The name of the input grid file.
50+
grid : str or xarray.DataArray
51+
The file name of the input grid or the grid loaded as a DataArray.
5252
outgrid : str or None
5353
The name of the output netCDF file with extension .nc to store the grid
5454
in.
@@ -94,12 +94,7 @@ def grdcut(grid, **kwargs):
9494
if kind == "file":
9595
file_context = dummy_context(grid)
9696
elif kind == "grid":
97-
raise NotImplementedError(
98-
"xarray.DataArray is not supported as the input grid yet!"
99-
)
100-
# file_context = lib.virtualfile_from_grid(grid)
101-
# See https://github.com/GenericMappingTools/gmt/pull/3532
102-
# for a feature request.
97+
file_context = lib.virtualfile_from_grid(grid)
10398
else:
10499
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))
105100

pygmt/tests/test_grdcut.py

+27-14
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
from ..helpers import GMTTempFile
1313

1414

15+
@pytest.fixture(scope="module", name="grid")
16+
def fixture_grid():
17+
"Load the grid data from the sample earth_relief file"
18+
return load_earth_relief(registration="pixel")
19+
20+
1521
def test_grdcut_file_in_file_out():
1622
"grduct an input grid file, and output to a grid file"
1723
with GMTTempFile(suffix=".nc") as tmpfile:
@@ -41,23 +47,30 @@ def test_grdcut_file_in_dataarray_out():
4147
assert outgrid.sizes["lon"] == 180
4248

4349

44-
def test_grdcut_dataarray_in_file_out():
45-
"grdcut an input DataArray, and output to a grid file"
46-
# Not supported yet.
47-
# See https://github.com/GenericMappingTools/gmt/pull/3532
48-
49-
50-
def test_grdcut_dataarray_in_dataarray_out():
50+
def test_grdcut_dataarray_in_file_out(grid):
5151
"grdcut an input DataArray, and output to a grid file"
52-
# Not supported yet.
53-
# See https://github.com/GenericMappingTools/gmt/pull/3532
52+
with GMTTempFile(suffix=".nc") as tmpfile:
53+
result = grdcut(grid, outgrid=tmpfile.name, region="0/180/0/90")
54+
assert result is None # grdcut returns None if output to a file
55+
result = grdinfo(tmpfile.name, C=True)
56+
assert result == "0 180 0 90 -8182 5651.5 1 1 180 90 1 1\n"
5457

5558

56-
def test_grdcut_dataarray_in_fail():
57-
"Make sure that grdcut fails correctly if DataArray is the input grid"
58-
with pytest.raises(NotImplementedError):
59-
grid = load_earth_relief()
60-
grdcut(grid, region="0/180/0/90")
59+
def test_grdcut_dataarray_in_dataarray_out(grid):
60+
"grdcut an input DataArray, and output as DataArray"
61+
outgrid = grdcut(grid, region="0/180/0/90")
62+
assert isinstance(outgrid, xr.DataArray)
63+
# check information of the output grid
64+
# the '@earth_relief_01d' is in pixel registration, so the grid range is
65+
# not exactly 0/180/0/90
66+
assert outgrid.coords["lat"].data.min() == 0.5
67+
assert outgrid.coords["lat"].data.max() == 89.5
68+
assert outgrid.coords["lon"].data.min() == 0.5
69+
assert outgrid.coords["lon"].data.max() == 179.5
70+
assert outgrid.data.min() == -8182.0
71+
assert outgrid.data.max() == 5651.5
72+
assert outgrid.sizes["lat"] == 90
73+
assert outgrid.sizes["lon"] == 180
6174

6275

6376
def test_grdcut_fails():

0 commit comments

Comments
 (0)