Skip to content

Reshape Reloaded #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: GitHub Actions Demo
name: Sharpy CI
run-name: ${{ github.actor }} CI for sharpy
on:
push:
Expand Down Expand Up @@ -41,7 +41,8 @@ jobs:
rm -rf $CONDA_ROOT
cd $GITHUB_WORKSPACE/..
rm -f Miniconda3-*.sh
CPKG=Miniconda3-latest-Linux-x86_64.sh
# CPKG=Miniconda3-latest-Linux-x86_64.sh
CPKG=Miniconda3-py311_24.3.0-0-Linux-x86_64.sh
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rm comment

wget -q https://repo.anaconda.com/miniconda/$CPKG
bash $CPKG -u -b -f -p $CONDA_ROOT
export PATH=$CONDA_ROOT/condabin:$CONDA_ROOT/bin:${PATH}
Expand Down
3 changes: 1 addition & 2 deletions conda-recipe/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ if [ ! -d "${INSTALLED_DIR}/imex/lib" ]; then
rm -rf ${INSTALLED_DIR}/imex
IMEX_SHA=$(cat imex_version.txt)
if [ ! -d "mlir-extensions" ]; then
git clone --recurse-submodules --branch main --single-branch https://github.com/intel/mlir-extensions
git clone --recurse-submodules https://github.com/intel/mlir-extensions
fi
pushd mlir-extensions
git reset --hard HEAD
git fetch --prune
git checkout $IMEX_SHA
git apply ${RECIPE_DIR}/imex_*.patch
LLVM_SHA=$(cat build_tools/llvm_version.txt)
# if [ ! -d "llvm-project" ]; then ln -s ~/github/llvm-project .; fi
if [ ! -d "llvm-project" ]; then
Expand Down
40 changes: 0 additions & 40 deletions conda-recipe/imex_findsycl.patch

This file was deleted.

2 changes: 1 addition & 1 deletion imex_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
617949ac6105f28faeab4fa3018142195d1125c0
a6109b1005932d8b4c1d2e8ab0ec4abe7411762a
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def build_cmake(self, ext):
os.chdir(str(build_temp))
self.spawn(["cmake", str(cwd)] + cmake_args)
if not self.dry_run:
self.spawn(["cmake", "--build", "."] + build_args)
self.spawn(["cmake", "--build", ".", "-j5"] + build_args)
# Troubleshooting: if fail on line above then delete all possible
# temporary CMake files including "CMakeCache.txt" in top level dir.
os.chdir(str(cwd))
Expand Down
8 changes: 8 additions & 0 deletions sharpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def _validate_device(device):
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))"
)


for func in api.api_categories["ManipOp"]:
FUNC = func.upper()
if func == "reshape":
exec(
f"{func} = lambda this, shape, cp=None: ndarray(_csp.ManipOp.reshape(this._t, shape, cp))"
)

for func in api.api_categories["ReduceOp"]:
FUNC = func.upper()
exec(
Expand Down
2 changes: 1 addition & 1 deletion sharpy/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
"concat", # (arrays, /, *, axis=0)
"expand_dims", # (x, /, *, axis)
"flip", # (x, /, *, axis=None)
"reshape", # (x, /, shape)
"reshape", # (x, /, shape, *, copy: bool | None = None)
"roll", # (x, /, shift, *, axis=None)
"squeeze", # (x, /, axis)
"stack", # (arrays, /, *, axis=0)
Expand Down
3 changes: 2 additions & 1 deletion src/ManipOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "sharpy/jit/mlir.hpp"

#include <imex/Dialect/Dist/IR/DistOps.h>
#include <imex/Dialect/Dist/Utils/Utils.h>
#include <imex/Dialect/NDArray/IR/NDArrayOps.h>
#include <mlir/IR/Builders.h>

Expand Down Expand Up @@ -41,7 +42,7 @@ struct DeferredReshape : public Deferred {
: ::imex::getIntAttr(builder, COPY_ALWAYS ? true : false, 1);

auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
auto outTyp = imex::dist::cloneWithShape(aTyp, shape());

auto op =
builder.create<::imex::ndarray::ReshapeOp>(loc, outTyp, av, shp, copyA);
Expand Down
Loading
Loading