Skip to content

Commit 0636bd7

Browse files
authored
Reshape Reloaded (#2)
- resurrecting always-copying reshape - adjusting to fixes in imex (dtype-independent wait, async) - enabling tests using reshape - temporarily disabling async execution of reshape until fixed in IMEX
1 parent b0c02fc commit 0636bd7

15 files changed

+120
-182
lines changed

.github/workflows/ci.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: GitHub Actions Demo
1+
name: Sharpy CI
22
run-name: ${{ github.actor }} CI for sharpy
33
on:
44
push:
@@ -41,7 +41,8 @@ jobs:
4141
rm -rf $CONDA_ROOT
4242
cd $GITHUB_WORKSPACE/..
4343
rm -f Miniconda3-*.sh
44-
CPKG=Miniconda3-latest-Linux-x86_64.sh
44+
# CPKG=Miniconda3-latest-Linux-x86_64.sh
45+
CPKG=Miniconda3-py311_24.3.0-0-Linux-x86_64.sh
4546
wget -q https://repo.anaconda.com/miniconda/$CPKG
4647
bash $CPKG -u -b -f -p $CONDA_ROOT
4748
export PATH=$CONDA_ROOT/condabin:$CONDA_ROOT/bin:${PATH}

conda-recipe/build.sh

+1-2
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,12 @@ if [ ! -d "${INSTALLED_DIR}/imex/lib" ]; then
5353
rm -rf ${INSTALLED_DIR}/imex
5454
IMEX_SHA=$(cat imex_version.txt)
5555
if [ ! -d "mlir-extensions" ]; then
56-
git clone --recurse-submodules --branch main --single-branch https://github.com/intel/mlir-extensions
56+
git clone --recurse-submodules https://github.com/intel/mlir-extensions
5757
fi
5858
pushd mlir-extensions
5959
git reset --hard HEAD
6060
git fetch --prune
6161
git checkout $IMEX_SHA
62-
git apply ${RECIPE_DIR}/imex_*.patch
6362
LLVM_SHA=$(cat build_tools/llvm_version.txt)
6463
# if [ ! -d "llvm-project" ]; then ln -s ~/github/llvm-project .; fi
6564
if [ ! -d "llvm-project" ]; then

conda-recipe/imex_findsycl.patch

-40
This file was deleted.

examples/stencil-2d.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,8 @@ def main():
197197
# * Analyze and output results.
198198
# ******************************************************************************
199199

200-
B = np.spmd.gather(B)
201-
norm = np.linalg.norm(np.reshape(B, n * n), ord=1)
200+
B = np.spmd.gather(np.reshape(B, (n * n,)))
201+
norm = np.linalg.norm(B, ord=1)
202202
active_points = (n - 2 * r) ** 2
203203
norm /= active_points
204204

imex_version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
617949ac6105f28faeab4fa3018142195d1125c0
1+
a6109b1005932d8b4c1d2e8ab0ec4abe7411762a

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def build_cmake(self, ext):
4343
os.chdir(str(build_temp))
4444
self.spawn(["cmake", str(cwd)] + cmake_args)
4545
if not self.dry_run:
46-
self.spawn(["cmake", "--build", "."] + build_args)
46+
self.spawn(["cmake", "--build", ".", "-j5"] + build_args)
4747
# Troubleshooting: if fail on line above then delete all possible
4848
# temporary CMake files including "CMakeCache.txt" in top level dir.
4949
os.chdir(str(cwd))

sharpy/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ def _validate_device(device):
107107
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))"
108108
)
109109

110+
111+
for func in api.api_categories["ManipOp"]:
112+
FUNC = func.upper()
113+
if func == "reshape":
114+
exec(
115+
f"{func} = lambda this, shape, cp=None: ndarray(_csp.ManipOp.reshape(this._t, shape, cp))"
116+
)
117+
110118
for func in api.api_categories["ReduceOp"]:
111119
FUNC = func.upper()
112120
exec(

sharpy/array_api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@
175175
"concat", # (arrays, /, *, axis=0)
176176
"expand_dims", # (x, /, *, axis)
177177
"flip", # (x, /, *, axis=None)
178-
"reshape", # (x, /, shape)
178+
"reshape", # (x, /, shape, *, copy: bool | None = None)
179179
"roll", # (x, /, shift, *, axis=None)
180180
"squeeze", # (x, /, axis)
181181
"stack", # (arrays, /, *, axis=0)

src/ManipOp.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "sharpy/jit/mlir.hpp"
1313

1414
#include <imex/Dialect/Dist/IR/DistOps.h>
15+
#include <imex/Dialect/Dist/Utils/Utils.h>
1516
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
1617
#include <mlir/IR/Builders.h>
1718

@@ -41,7 +42,7 @@ struct DeferredReshape : public Deferred {
4142
: ::imex::getIntAttr(builder, COPY_ALWAYS ? true : false, 1);
4243

4344
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
44-
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
45+
auto outTyp = imex::dist::cloneWithShape(aTyp, shape());
4546

4647
auto op =
4748
builder.create<::imex::ndarray::ReshapeOp>(loc, outTyp, av, shp, copyA);

0 commit comments

Comments
 (0)