Skip to content

Commit c81d2e2

Browse files
authored
Merge pull request #776 from pydata/reshape-func
ENH: Implement `reshape` function
2 parents 373f29f + 9325477 commit c81d2e2

File tree

4 files changed

+130
-9
lines changed

4 files changed

+130
-9
lines changed

sparse/mlir_backend/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
)
1616
from ._ops import (
1717
add,
18+
reshape,
1819
)
1920

2021
__all__ = [
2122
"add",
2223
"asarray",
2324
"asdtype",
25+
"reshape",
2426
]

sparse/mlir_backend/_constructors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def from_sps(cls, arr: np.ndarray) -> "Dense":
239239

240240
return dense_instance
241241

242-
def to_sps(self, shape: tuple[int, ...]) -> sps.csr_array:
242+
def to_sps(self, shape: tuple[int, ...]) -> np.ndarray:
243243
data = ranked_memref_to_numpy(self.data)
244244
return data.reshape(shape)
245245

sparse/mlir_backend/_ops.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from mlir import ir
66
from mlir.dialects import arith, func, linalg, sparse_tensor, tensor
77

8+
import numpy as np
9+
810
from ._common import fn_cache
9-
from ._constructors import Tensor
11+
from ._constructors import Tensor, numpy_to_ranked_memref
1012
from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx, pm
11-
from ._dtypes import DType, FloatingDType
13+
from ._dtypes import DType, FloatingDType, Index
1214

1315

1416
@fn_cache
@@ -68,11 +70,35 @@ def add(a, b):
6870
return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
6971

7072

73+
@fn_cache
74+
def get_reshape_module(
75+
a_tensor_type: ir.RankedTensorType,
76+
shape_tensor_type: ir.RankedTensorType,
77+
out_tensor_type: ir.RankedTensorType,
78+
) -> ir.Module:
79+
with ir.Location.unknown(ctx):
80+
module = ir.Module.create()
81+
82+
with ir.InsertionPoint(module.body):
83+
84+
@func.FuncOp.from_py_func(a_tensor_type, shape_tensor_type)
85+
def reshape(a, shape):
86+
return tensor.reshape(out_tensor_type, a, shape)
87+
88+
reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
89+
if DEBUG:
90+
(CWD / "reshape_module.mlir").write_text(str(module))
91+
pm.run(module.operation)
92+
if DEBUG:
93+
(CWD / "reshape_module_opt.mlir").write_text(str(module))
94+
95+
return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
96+
97+
7198
def add(x1: Tensor, x2: Tensor) -> Tensor:
7299
ret_obj = x1._format_class()
73100
out_tensor_type = x1._obj.get_tensor_definition(x1.shape)
74101

75-
# TODO: Add proper caching
76102
# TODO: Decide what will be the output tensor_type
77103
add_module = get_add_module(
78104
x1._obj.get_tensor_definition(x1.shape),
@@ -88,3 +114,24 @@ def add(x1: Tensor, x2: Tensor) -> Tensor:
88114
*x2._obj.to_module_arg(),
89115
)
90116
return Tensor(ret_obj, shape=out_tensor_type.shape)
117+
118+
119+
def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
120+
ret_obj = x._format_class()
121+
x_tensor_type = x._obj.get_tensor_definition(x.shape)
122+
out_tensor_type = x._obj.get_tensor_definition(shape)
123+
124+
with ir.Location.unknown(ctx):
125+
shape_tensor_type = ir.RankedTensorType.get([len(shape)], Index.get_mlir_type())
126+
127+
reshape_module = get_reshape_module(x_tensor_type, shape_tensor_type, out_tensor_type)
128+
129+
shape = np.array(shape)
130+
reshape_module.invoke(
131+
"reshape",
132+
ctypes.pointer(ctypes.pointer(ret_obj)),
133+
*x._obj.to_module_arg(),
134+
ctypes.pointer(ctypes.pointer(numpy_to_ranked_memref(shape))),
135+
)
136+
137+
return Tensor(ret_obj, shape=out_tensor_type.shape)

sparse/mlir_backend/tests/test_simple.py

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,15 @@ def sampler_real_floating(size: tuple[int, ...]):
7575
raise NotImplementedError(f"{dtype=} not yet supported.")
7676

7777

78+
def get_exampe_csf_arrays(dtype: np.dtype) -> tuple:
79+
pos_1 = np.array([0, 1, 3], dtype=np.int64)
80+
crd_1 = np.array([1, 0, 1], dtype=np.int64)
81+
pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
82+
crd_2 = np.array([0, 1, 3, 0, 3, 0, 1], dtype=np.int64)
83+
data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
84+
return pos_1, crd_1, pos_2, crd_2, data
85+
86+
7887
@parametrize_dtypes
7988
@pytest.mark.parametrize("shape", [(100,), (10, 200), (5, 10, 20)])
8089
def test_dense_format(dtype, shape):
@@ -176,11 +185,7 @@ def test_add(rng, dtype):
176185
@parametrize_dtypes
177186
def test_csf_format(dtype):
178187
SHAPE = (2, 2, 4)
179-
pos_1 = np.array([0, 1, 3], dtype=np.int64)
180-
crd_1 = np.array([1, 0, 1], dtype=np.int64)
181-
pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
182-
crd_2 = np.array([0, 1, 3, 0, 3, 0, 1], dtype=np.int64)
183-
data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
188+
pos_1, crd_1, pos_2, crd_2, data = get_exampe_csf_arrays(dtype)
184189
csf = [pos_1, crd_1, pos_2, crd_2, data]
185190

186191
csf_tensor = sparse.asarray(csf, shape=SHAPE, dtype=sparse.asdtype(dtype), format="csf")
@@ -192,3 +197,70 @@ def test_csf_format(dtype):
192197
csf_2 = [pos_1, crd_1, pos_2, crd_2, data * 2]
193198
for actual, expected in zip(res_tensor, csf_2, strict=False):
194199
np.testing.assert_array_equal(actual, expected)
200+
201+
202+
@parametrize_dtypes
203+
def test_reshape(rng, dtype):
204+
DENSITY = 0.5
205+
sampler = generate_sampler(dtype, rng)
206+
207+
# CSR, CSC, COO
208+
for shape, new_shape in [((100, 50), (25, 200)), ((80, 1), (8, 10))]:
209+
for format in ["csr", "csc", "coo"]:
210+
if format == "coo":
211+
# NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
212+
continue
213+
if format == "csc":
214+
# NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
215+
continue
216+
217+
arr = sps.random_array(
218+
shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
219+
)
220+
if format == "coo":
221+
arr.sum_duplicates()
222+
223+
tensor = sparse.asarray(arr)
224+
225+
actual = sparse.reshape(tensor, shape=new_shape).to_scipy_sparse()
226+
expected = arr.todense().reshape(new_shape)
227+
228+
np.testing.assert_array_equal(actual.todense(), expected)
229+
230+
# CSF
231+
csf_shape = (2, 2, 4)
232+
for shape, new_shape, expected_arrs in [
233+
(
234+
csf_shape,
235+
(4, 4, 1),
236+
[
237+
np.array([0, 0, 3, 5, 7]),
238+
np.array([0, 1, 3, 0, 3, 0, 1]),
239+
np.array([0, 1, 2, 3, 4, 5, 6, 7]),
240+
np.array([0, 0, 0, 0, 0, 0, 0]),
241+
np.array([1, 2, 3, 4, 5, 6, 7]),
242+
],
243+
),
244+
(
245+
csf_shape,
246+
(2, 1, 8),
247+
[
248+
np.array([0, 1, 2]),
249+
np.array([0, 0]),
250+
np.array([0, 3, 7]),
251+
np.array([4, 5, 7, 0, 3, 4, 5]),
252+
np.array([1, 2, 3, 4, 5, 6, 7]),
253+
],
254+
),
255+
]:
256+
csf = get_exampe_csf_arrays(dtype)
257+
csf_tensor = sparse.asarray(csf, shape=shape, dtype=sparse.asdtype(dtype), format="csf")
258+
259+
result = sparse.reshape(csf_tensor, shape=new_shape).to_scipy_sparse()
260+
261+
for actual, expected in zip(result, expected_arrs, strict=False):
262+
np.testing.assert_array_equal(actual, expected)
263+
264+
# DENSE
265+
# NOTE: dense reshape is probably broken in MLIR
266+
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)

0 commit comments

Comments
 (0)