Skip to content

Commit

Permalink
Merge pull request #218 from rayosborn:fix-tests
Browse files Browse the repository at this point in the history
Fix-tests
  • Loading branch information
rayosborn authored Jan 13, 2025
2 parents 6f2175b + fc065f2 commit fec68d2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ install_requires =
scipy
h5py >=2.9
hdf5plugin
packaging

[options.packages.find]
where = src
Expand Down
8 changes: 4 additions & 4 deletions src/nexusformat/nexus/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -----------------------------------------------------------------------------
# Copyright (c) 2013-2021, NeXpy Development Team.
# Copyright (c) 2013-2025, NeXpy Development Team.
#
# Author: Paul Kienzle, Ray Osborn
#
Expand All @@ -12,6 +12,7 @@
import copy

import numpy as np
from packaging.version import Version

from . import NeXusError, NXfield

Expand Down Expand Up @@ -252,8 +253,7 @@ def plot(self, data_group, fmt=None, xmin=None, xmax=None,
else:
kwargs["norm"] = Normalize(vmin, vmax)

from pkg_resources import parse_version as pv
if pv(mplversion) >= pv('3.5.0'):
if Version(mplversion) >= Version('3.5.0'):
from matplotlib import colormaps
cm = copy.copy(colormaps[cmap])
else:
Expand Down Expand Up @@ -292,7 +292,7 @@ def plot(self, data_group, fmt=None, xmin=None, xmax=None,
im.set_clim(-0.5, 9.5)
elif cmin == 1:
im.set_clim(0.5, 10.5)
if pv(mplversion) >= pv('3.5.0'):
if Version(mplversion) >= Version('3.5.0'):
cb.ax.set_ylim(cmin-0.5, cmax+0.5)
cb.set_ticks(range(int(cmin), int(cmax)+1))

Expand Down
10 changes: 7 additions & 3 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import sys

import h5py as h5
import numpy as np
import pytest

from nexusformat.nexus.tree import NXfield, nxgetconfig


Expand Down Expand Up @@ -85,13 +88,14 @@ def test_field_methods(arr, request):

assert np.array_equal(field**2, arr**2)
assert field.min() == np.min(arr)
assert field.min(keepdims=True) == np.min(arr, keepdims=True)
assert field.max() == np.max(arr)
assert field.max(keepdims=True) == np.max(arr, keepdims=True)
assert field.sum() == np.sum(arr)
assert field.sum(dtype=np.float32) == np.sum(arr, dtype=np.float32)
assert field.average() == np.average(arr)
assert field.average(keepdims=True) == np.average(arr, keepdims=True)
if sys.version_info >= (3, 8):
assert field.min(keepdims=True) == np.min(arr, keepdims=True)
assert field.max(keepdims=True) == np.max(arr, keepdims=True)
assert field.average(keepdims=True) == np.average(arr, keepdims=True)


@pytest.mark.parametrize(
Expand Down

0 comments on commit fec68d2

Please sign in to comment.