Skip to content

Commit aa27198

Browse files
committed
fix dim 0 for random.uniform
1 parent bd1e873 commit aa27198

File tree

2 files changed

+12
-20
lines changed

2 files changed

+12
-20
lines changed

sharpy/random/__init__.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import numpy as np
2+
import sharpy as sp
23
from sharpy import float64
34
from sharpy.numpy import fromfunction
45

56
def uniform(low, high, size, device='', team=1):
67
data = np.random.uniform(low, high, size)
7-
# TODO handle 0 dim
8-
if isinstance(data, float):
9-
return data
8+
if len(data.shape) == 0:
9+
sp_data = sp.empty(())
10+
sp_data[()] = data[()]
11+
return sp_data
1012
return fromfunction(lambda *index: data[index], data.shape, dtype=float64, device=device, team=team)
1113

1214
def rand(*shape, device='', team=1):

test/test_random.py

+7-17
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,12 @@ def test_random_rand(shape, seed):
2323
else:
2424
assert np.allclose(sp.to_numpy(sp_data), np_data)
2525

26-
# @pytest.mark.parametrize("low,high", [(0, 1), (4, 10), (-100, 100)])
27-
# def test_random_uniform(low, high, shape, seed):
28-
# sp.random.seed(seed)
29-
# sp_data = sp.random.uniform(low, high, shape)
30-
31-
# np.random.seed(seed)
32-
# np_data = np.random.uniform(low, high, shape)
33-
34-
# print('np', np_data)
35-
# print('sp', sp_data)
26+
@pytest.mark.parametrize("low,high", [(0, 1), (4, 10), (-100, 100)])
27+
def test_random_uniform(low, high, shape, seed):
28+
sp.random.seed(seed)
29+
sp_data = sp.random.uniform(low, high, shape)
3630

37-
# # if isinstance(np_data, float):
38-
# # assert isinstance(sp_data, float) and sp_data == np_data
39-
# # else:
40-
# # assert np.allclose(sp.to_numpy(sp_data), np_data)
31+
np.random.seed(seed)
32+
np_data = np.random.uniform(low, high, shape)
4133

42-
sp.init()
43-
test_random_uniform(0, 1, (), 0)
44-
sp.fini()
34+
assert np.allclose(sp.to_numpy(sp_data), np_data)

0 commit comments

Comments
 (0)