Skip to content

Commit 9e8bafa

Browse files
authored
Verify environment variables (#94)
1 parent 0518487 commit 9e8bafa

22 files changed

+390
-138
lines changed

.bandit.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@
8383
# IPAS Required Checkers. Do not disable these
8484
# Additional checkers may be added if desired
8585
tests:
86-
[ 'B301', 'B302', 'B303', 'B304', 'B305', 'B306', 'B308', 'B310', 'B311', 'B312', 'B313', 'B314', 'B315', 'B316', 'B317', 'B318', 'B319', 'B320', 'B321', 'B323', 'B324', 'B401', 'B402', 'B403', 'B404', 'B405', 'B406', 'B407', 'B408', 'B409', 'B410', 'B411', 'B412', 'B413']
86+
[ 'B301', 'B302', 'B303', 'B304', 'B305', 'B306', 'B308', 'B310', 'B311', 'B312', 'B313', 'B314', 'B315', 'B316', 'B317', 'B318', 'B319', 'B320', 'B321', 'B323', 'B324', 'B401', 'B402', 'B403', 'B405', 'B406', 'B407', 'B408', 'B409', 'B410', 'B411', 'B412', 'B413']
8787

8888
# (optional) list skipped test IDs here, eg '[B101, B406]':
8989
# The following checkers are not required but be added to tests list if desired
9090
skips:
91-
[ 'B101', 'B102', 'B103', 'B104', 'B105', 'B106', 'B107', 'B108', 'B110', 'B112', 'B201', 'B501', 'B502', 'B503', 'B504', 'B505', 'B506', 'B507', 'B601', 'B602', 'B603', 'B604', 'B605', 'B606', 'B607', 'B608', 'B609', 'B610', 'B611', 'B701', 'B702', 'B703']
91+
[ 'B101', 'B102', 'B103', 'B104', 'B105', 'B106', 'B107', 'B108', 'B110', 'B112', 'B201', 'B404', 'B501', 'B502', 'B503', 'B504', 'B505', 'B506', 'B507', 'B601', 'B602', 'B603', 'B604', 'B605', 'B606', 'B607', 'B608', 'B609', 'B610', 'B611', 'B701', 'B702', 'B703']
9292

9393
### (optional) plugin settings - some test plugins require configuration data
9494
### that may be given here, per-plugin. All bandit test plugins have a built in

.github/workflows/ci.yml

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@ jobs:
2626
token: ${{ secrets.PATIMEX }}
2727
path: third_party/imex
2828
ref: ${{ env.IMEX_SHA }}
29-
- name: Cache miniconda
30-
id: cache-miniconda
31-
uses: actions/cache@v3
32-
env:
33-
MINICONDA_CACHE_NUMBER: 2 # Increase to reset cache
34-
with:
35-
path: third_party/install/miniconda/**
36-
key: ${{ runner.os }}-miniconda-${{ env.MINICONDA_CACHE_NUMBER }}-${{ hashFiles('conda-env.txt') }}
29+
# - name: Cache miniconda
30+
# id: cache-miniconda
31+
# uses: actions/cache@v3
32+
# env:
33+
# MINICONDA_CACHE_NUMBER: 2 # Increase to reset cache
34+
# with:
35+
# path: third_party/install/miniconda/**
36+
# key: ${{ runner.os }}-miniconda-${{ env.MINICONDA_CACHE_NUMBER }}-${{ hashFiles('conda-env.txt') }}
3737
- name: Miniconda
38-
if: steps.cache-miniconda.outputs.cache-hit != 'true'
38+
# if: steps.cache-miniconda.outputs.cache-hit != 'true'
3939
run: |
4040
rm -rf "$GITHUB_WORKSPACE"/third_party/install/miniconda
4141
mkdir -p "$GITHUB_WORKSPACE"/third_party/install
@@ -50,12 +50,12 @@ jobs:
5050
cd -
5151
conda create --file conda-env.txt --name sharpy
5252
conda clean -a -y
53-
- name: Setup miniconda
54-
if: steps.cache-miniconda.outputs.cache-hit == 'true'
55-
run: |
56-
echo "$GITHUB_WORKSPACE/third_party/install/miniconda/bin" >> $GITHUB_PATH
57-
echo "$GITHUB_WORKSPACE/third_party/install/miniconda/condabin" >> $GITHUB_PATH
58-
export PATH=$GITHUB_WORKSPACE/third_party/install/miniconda/bin:${PATH}
53+
# - name: Setup miniconda
54+
# if: steps.cache-miniconda.outputs.cache-hit == 'true'
55+
# run: |
56+
# echo "$GITHUB_WORKSPACE/third_party/install/miniconda/bin" >> $GITHUB_PATH
57+
# echo "$GITHUB_WORKSPACE/third_party/install/miniconda/condabin" >> $GITHUB_PATH
58+
# export PATH=$GITHUB_WORKSPACE/third_party/install/miniconda/bin:${PATH}
5959
- name: Setup LLVM Cache Var
6060
run: |
6161
echo 'LLVM_SHA<<EOF' >> $GITHUB_ENV

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ If your compiler does not default to a recent (e.g. g++ >= 9) version, try somet
3434
# single rank
3535
pytest test
3636
# distributed on multiple ($N) ranks/processes
37-
SHARPY_IDTR_SO=`pwd`/sharpy/libidtr.so mpirun -n $N python -m pytest test
37+
mpirun -n $N python -m pytest test
3838
```
3939

4040
## Running
@@ -57,13 +57,13 @@ sp.fini()
5757
Assuming the above is in file `simple.py` a single-process run is executed like
5858

5959
```bash
60-
SHARPY_IDTR_SO=`pwd`/sharpy/libidtr.so python simple.py
60+
python simple.py
6161
```
6262

6363
and multi-process run is executed like
6464

6565
```bash
66-
SHARPY_IDTR_SO=`pwd`/sharpy/libidtr.so mpirun -n 5 python simple.py
66+
mpirun -n 5 python simple.py
6767
```
6868

6969
### Distributed Execution without mpirun
@@ -76,7 +76,7 @@ Additionally SHARPY_MPI_HOSTS can be used to control the host to use for spawnin
7676
The following command will run the stencil example on 3 MPI ranks:
7777

7878
```bash
79-
SHARPY_IDTR_SO=`pwd`/sharpy/libidtr.so \
79+
SHARPY_FALLBACK=numpy \
8080
SHARPY_MPI_SPAWN=2 \
8181
SHARPY_MPI_EXECUTABLE=`which python` \
8282
SHARPY_MPI_EXE_ARGS="examples/stencil-2d.py 10 2000 star 2" \

examples/black_scholes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def run(nopt, backend, iterations, datatype):
176176
import sharpy as np
177177
from sharpy import fini, init, sync
178178

179-
device = os.getenv("SHARPY_USE_GPU", "")
179+
device = os.getenv("SHARPY_DEVICE", "")
180180
create_full = partial(np.full, device=device)
181181
erf = np.erf
182182

examples/shallow_water.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def run(n, backend, datatype, benchmark_mode):
5555
from sharpy import fini, init, sync
5656
from sharpy.numpy import fromfunction as _fromfunction
5757

58-
device = os.getenv("SHARPY_USE_GPU", "")
58+
device = os.getenv("SHARPY_DEVICE", "")
5959
create_full = partial(np.full, device=device)
6060
fromfunction = partial(_fromfunction, device=device)
6161

examples/stencil-2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def main():
118118
print("Data type = double precision")
119119
print("Compact representation of stencil loop body")
120120

121-
device = os.getenv("SHARPY_USE_GPU", "")
121+
device = os.getenv("SHARPY_DEVICE", "")
122122

123123
# there is certainly a more Pythonic way to initialize W,
124124
# but it will have no impact on performance.

examples/wave_equation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def run(n, backend, datatype, benchmark_mode):
5555
from sharpy import fini, init, sync
5656
from sharpy.numpy import fromfunction as _fromfunction
5757

58-
device = os.getenv("SHARPY_USE_GPU", "")
58+
device = os.getenv("SHARPY_DEVICE", "")
5959
create_full = partial(np.full, device=device)
6060
fromfunction = partial(_fromfunction, device=device)
6161

sharpy/__init__.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# At this point there are no checks of input arguments whatsoever, arguments
1414
# are simply forwarded as-is.
1515

16+
import os
17+
import re
1618
from importlib import import_module
1719
from os import getenv
1820
from typing import Any
@@ -42,8 +44,11 @@
4244

4345

4446
def init(cw=None):
47+
libidtr = os.path.join(os.path.dirname(__file__), "libidtr.so")
48+
assert os.path.isfile(libidtr), "libidtr.so not found"
49+
4550
cw = _sharpy_cw if cw is None else cw
46-
_init(cw)
51+
_init(cw, libidtr)
4752

4853

4954
def to_numpy(a):
@@ -64,31 +69,42 @@ def to_numpy(a):
6469
f"{op} = lambda this: ndarray(_csp.EWUnyOp.op(_csp.{OP}, this._t))"
6570
)
6671

72+
73+
def _validate_device(device):
74+
if len(device) == 0 or re.search(
75+
r"^((opencl|level-zero|cuda):)?(host|gpu|cpu|accelerator)(:\d+)?$",
76+
device,
77+
):
78+
return device
79+
else:
80+
raise ValueError(f"Invalid device string: {device}")
81+
82+
6783
for func in api.api_categories["Creator"]:
6884
FUNC = func.upper()
6985
if func == "full":
7086
exec(
71-
f"{func} = lambda shape, val, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, val, dtype, device, team))"
87+
f"{func} = lambda shape, val, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, val, dtype, _validate_device(device), team))"
7288
)
7389
elif func == "empty":
7490
exec(
75-
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, None, dtype, device, team))"
91+
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, None, dtype, _validate_device(device), team))"
7692
)
7793
elif func == "ones":
7894
exec(
79-
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 1, dtype, device, team))"
95+
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 1, dtype, _validate_device(device), team))"
8096
)
8197
elif func == "zeros":
8298
exec(
83-
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, device, team))"
99+
f"{func} = lambda shape, dtype=float64, device='', team=1: ndarray(_csp.Creator.full(shape, 0, dtype, _validate_device(device), team))"
84100
)
85101
elif func == "arange":
86102
exec(
87-
f"{func} = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, device, team))"
103+
f"{func} = lambda start, end, step, dtype=int64, device='', team=1: ndarray(_csp.Creator.arange(start, end, step, dtype, _validate_device(device), team))"
88104
)
89105
elif func == "linspace":
90106
exec(
91-
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, device, team))"
107+
f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))"
92108
)
93109

94110
for func in api.api_categories["ReduceOp"]:
@@ -116,10 +132,17 @@ def to_numpy(a):
116132

117133
_fb_env = getenv("SHARPY_FALLBACK")
118134
if _fb_env is not None:
135+
if not _fb_env.isalnum():
136+
raise ValueError(f"Invalid SHARPY_FALLBACK value '{_fb_env}'")
119137

120138
class _fallback:
121139
"Fallback to whatever is provided in SHARPY_FALLBACK"
122-
_fb_lib = import_module(_fb_env)
140+
try:
141+
_fb_lib = import_module(_fb_env)
142+
except ModuleNotFoundError:
143+
raise ValueError(
144+
f"Invalid SHARPY_FALLBACK value '{_fb_env}': module not found"
145+
)
123146

124147
def __init__(self, fname: str, mod=None) -> None:
125148
"""get callable with name 'fname' from fallback-lib

src/Creator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
namespace SHARPY {
2424

25-
static const char *FORCE_DIST = getenv("SHARPY_FORCE_DIST");
25+
static bool FORCE_DIST = get_bool_env("SHARPY_FORCE_DIST");
2626

2727
inline uint64_t mkTeam(uint64_t team) {
2828
if (team && (FORCE_DIST || getTransceiver()->nranks() > 1)) {

src/Deferred.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ void Runable::fini() { _deferred.clear(); }
8181
// When execution is needed, the function signature (input args, return
8282
// statement) is finalized, the function gets compiled and executed. The loop
8383
// completes by calling run() on the requesting object.
84-
void process_promises() {
84+
void process_promises(const std::string &libidtr) {
8585
int vtProcessSym, vtSHARPYClass, vtPopSym;
8686
VT(VT_classdef, "sharpy", &vtSHARPYClass);
8787
VT(VT_funcdef, "process", vtSHARPYClass, &vtProcessSym);
8888
VT(VT_funcdef, "pop", vtSHARPYClass, &vtPopSym);
8989
VT(VT_begin, vtProcessSym);
9090

9191
bool done = false;
92-
jit::JIT jit;
92+
jit::JIT jit(libidtr);
9393
std::vector<Runable::ptr_type> deleters;
9494

9595
do {

src/MPITransceiver.cpp

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
*/
66

77
#include "sharpy/MPITransceiver.hpp"
8+
#include "sharpy/UtilsAndTypes.hpp"
9+
#include <fstream>
810
#include <iostream>
911
#include <limits>
1012
#include <mpi.h>
@@ -44,28 +46,32 @@ MPITransceiver::MPITransceiver(bool is_cw)
4446
// Ok, let's spawn the clients.
4547
// I need some information for the startup.
4648
// 1. Name of the executable (default is the current exe)
47-
const char *_tmp = getenv("SHARPY_MPI_SPAWN");
48-
if (_tmp) {
49-
int nClientsToSpawn = atol(_tmp);
50-
std::string clientExe;
49+
int nClientsToSpawn = get_int_env("SHARPY_MPI_SPAWN", 0);
50+
if (nClientsToSpawn) {
5151
std::vector<std::string> args;
52-
_tmp = getenv("SHARPY_MPI_EXECUTABLE");
53-
if (!_tmp) {
54-
_tmp = getenv("PYTHON_EXE");
55-
if (!_tmp)
52+
std::string clientExe = get_text_env("SHARPY_MPI_EXECUTABLE");
53+
std::string exeArgs;
54+
if (clientExe.empty()) {
55+
auto pythonExe = get_text_env("PYTHON_EXE");
56+
if (pythonExe.empty())
5657
throw std::runtime_error("Spawning MPI processes requires setting "
5758
"'SHARPY_MPI_EXECUTABLE' or 'PYTHON_EXE'");
58-
clientExe = _tmp;
59+
if (!std::ifstream(pythonExe)) {
60+
throw std::runtime_error("Invalid PYTHON_EXE");
61+
}
62+
clientExe = pythonExe;
5963
// 2. arguments
60-
_tmp = "-c import sharpy as sp; sp.init(True)";
64+
exeArgs += " -c import sharpy as sp; sp.init(True)";
6165
args.push_back("-c");
6266
args.push_back("import sharpy as sp; sp.init(True)");
6367
} else {
64-
clientExe = _tmp;
68+
if (!std::ifstream(clientExe)) {
69+
throw std::runtime_error("Invalid SHARPY_MPI_EXECUTABLE.");
70+
}
6571
// 2. arguments
66-
_tmp = getenv("SHARPY_MPI_EXE_ARGS");
67-
if (_tmp) {
68-
std::istringstream iss(_tmp);
72+
exeArgs = get_text_env("SHARPY_MPI_EXE_ARGS");
73+
if (!exeArgs.empty()) {
74+
std::istringstream iss(exeArgs);
6975
std::copy(std::istream_iterator<std::string>(iss),
7076
std::istream_iterator<std::string>(),
7177
std::back_inserter(args));
@@ -78,20 +84,19 @@ MPITransceiver::MPITransceiver(bool is_cw)
7884
clientArgs[args.size()] = nullptr;
7985

8086
// 3. Special setting for MPI_Info: hosts
81-
const char *clientHost = getenv("SHARPY_MPI_HOSTS");
82-
87+
auto clientHost = get_text_env("SHARPY_MPI_HOSTS");
8388
// Prepare MPI_Info object:
8489
MPI_Info clientInfo = MPI_INFO_NULL;
85-
if (clientHost) {
90+
if (!clientHost.empty()) {
8691
MPI_Info_create(&clientInfo);
8792
MPI_Info_set(clientInfo, const_cast<char *>("host"),
88-
const_cast<char *>(clientHost));
93+
const_cast<char *>(clientHost.c_str()));
8994
std::cerr << "[SHARPY " << rank << "] Set MPI_Info_set(\"host\", \""
9095
<< clientHost << "\")\n";
9196
}
9297
// Now spawn the client processes:
9398
std::cerr << "[SHARPY " << rank << "] Spawning " << nClientsToSpawn
94-
<< " MPI processes (" << clientExe << " " << _tmp << ")"
99+
<< " MPI processes (" << clientExe << " " << exeArgs << ")"
95100
<< std::endl;
96101
int *errCodes = new int[nClientsToSpawn];
97102
MPI_Comm interComm;

src/_sharpy.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ using namespace pybind11::literals; // to bring _a
4444
#include "sharpy/itac.hpp"
4545
#include "sharpy/jit/mlir.hpp"
4646

47+
#include <fstream>
4748
#include <iostream>
4849

4950
namespace SHARPY {
@@ -96,23 +97,27 @@ void fini() {
9697
finied = true;
9798
}
9899

99-
void init(bool cw) {
100+
void init(bool cw, std::string libidtr) {
100101
if (inited)
101102
return;
102103

104+
if (!std::ifstream(libidtr)) {
105+
throw std::runtime_error(std::string("Cannot find libidtr.so"));
106+
}
107+
103108
init_transceiver(new MPITransceiver(cw));
104109
init_mediator(new MPIMediator());
105110
int cpu = sched_getcpu();
106111
std::cerr << "rank " << getTransceiver()->rank() << " is running on core "
107112
<< cpu << std::endl;
108113
if (cw) {
109114
if (getTransceiver()->rank()) {
110-
process_promises();
115+
process_promises(libidtr);
111116
fini();
112117
exit(0);
113118
}
114119
}
115-
pprocessor = new std::thread(process_promises);
120+
pprocessor = new std::thread(process_promises, libidtr);
116121
inited = true;
117122
finied = false;
118123
}

0 commit comments

Comments
 (0)