Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 16 additions & 28 deletions src/boututils/run_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pathlib
import re
import subprocess
from subprocess import PIPE, STDOUT, Popen, call

if os.name == "nt":
# Default on Windows
Expand Down Expand Up @@ -32,7 +31,7 @@ def getmpirun(default=DEFAULT_MPIRUN):
return MPIRUN


def shell(command, pipe=False):
def shell(command, pipe=False, env=None):
"""Run a shell command

Parameters
Expand All @@ -48,24 +47,14 @@ def shell(command, pipe=False):
tuple : (int, str)
The return code, and either command output if pipe=True else None
"""
output = None
status = 0
if pipe:
child = Popen(command, stderr=STDOUT, stdout=PIPE, shell=True)
# This returns a b'string' which is casted to string in
# python 2. However, as we want to use f.write() in our
# runtest, we cast this to utf-8 here
output = child.stdout.read().decode("utf-8", "ignore")
# Wait for the process to finish. Note that child.wait()
# would have deadlocked the system as stdout is PIPEd, we
# therefore use communicate, which in the end also waits for
# the process to finish
child.communicate()
status = child.returncode
else:
status = call(command, shell=True)

return status, output

result = subprocess.run(
command, shell=True, capture_output=pipe, env=env, text=True
)
output = result.stdout if pipe else ""
if result.stderr:
output = f"{output}\nSTDERR:\n{result.stderr}"
return result.returncode, output


def determineNumberOfCPUs():
Expand Down Expand Up @@ -236,17 +225,15 @@ def launch(
if output is not None:
cmd = f"{cmd} > {output}"

if mthread is not None:
if os.name == "nt":
# We're on windows, so we have to do it a little different
cmd = f'cmd /C "set OMP_NUM_THREADS={mthread} && {cmd}"'
else:
cmd = f"OMP_NUM_THREADS={mthread} {cmd}"
# Set OMP_NUM_THREADS if mthread is provided (for OpenMP in BOUT++)
env = os.environ.copy()
if mthread:
env["OMP_NUM_THREADS"] = str(mthread)

if verbose:
print(cmd)

return shell(cmd, pipe=pipe)
return shell(cmd, pipe=pipe, env=env)


def shell_safe(command, *args, **kwargs):
Expand Down Expand Up @@ -284,8 +271,9 @@ def launch_safe(command, *args, **kwargs):
Optional arguments passed to `shell`

"""

s, out = launch(command, *args, **kwargs)
if s:
if s != 0:
raise RuntimeError(
f"Run failed with {s}.\nCommand was:\n{command}\n\nOutput was\n\n{out}"
)
Expand Down
Loading