Skip to content

Commit

Permalink
Make it work with subprocess and truncated buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Jan 23, 2025
1 parent 5cb5292 commit 23d7f12
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 38 deletions.
5 changes: 5 additions & 0 deletions metaflow/cli_components/step_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def spin_internal(
max_user_code_retries=None,
namespace=None,
):
import sys

if ctx.obj.is_quiet:
echo = echo_dev_null
else:
Expand All @@ -248,4 +250,7 @@ def spin_internal(
None, # no unbounded foreach context
)
# echo("Task is: ", task)
print("Task is: ", task)
print("I am here 3")
print("sys.executable: ", sys.executable)
# pass
129 changes: 91 additions & 38 deletions metaflow/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
self._whitelist_decorators = None
self._config_file_name = None
self._max_log_size = max_log_size
self._encoding = sys.stdout.encoding or "UTF-8"

# Create a new run_id for the spin task
self._run_id = self._metadata.new_run_id()
Expand Down Expand Up @@ -203,10 +204,8 @@ def _new_task(self, step, input_paths=None, **kwargs):
)

def execute(self):
exception = None
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file:
# Configurations are passed through a file to avoid overloading the
# command-line. We only need to create this file once and it can be reused
# for any task launch
config_value = dump_config_values(self._flow)
if config_value:
json.dump(config_value, config_file)
Expand All @@ -215,30 +214,36 @@ def execute(self):
else:
self._config_file_name = None

self.task = self._new_task(self._step_func.name, {})
_ds = self._flow_datastore.get_task_datastore(
self._run_id, self._step_func.name, self.task.task_id, attempt=0, mode="w"
)

for deco in self.whitelist_decorators:
deco.runtime_task_created(
_ds,
self.task = self._new_task(self._step_func.name, {})
_ds = self._flow_datastore.get_task_datastore(
self._run_id,
self._step_func.name,
self.task.task_id,
self.split_index,
self.input_paths,
is_cloned=False,
ubf_context=None,
attempt=0,
mode="w",
)

self.launch_spin()
for deco in self.whitelist_decorators:
deco.runtime_task_created(
_ds,
self.task.task_id,
self.split_index,
self.input_paths,
is_cloned=False,
ubf_context=None,
)

# Start a new worker to spin a step
# on finish clean tasks
exception = None
for deco in self.whitelist_decorators:
deco.runtime_finished(exception)
try:
self._launch_and_monitor_task()
except Exception as ex:
self._logger("Task failed.", system_msg=True, bad=True)
exception = ex
raise
finally:
for deco in self.whitelist_decorators:
deco.runtime_finished(exception)

def launch_spin(self):
def _launch_and_monitor_task(self):
args = CLIArgs(self.task, spin=True)
env = dict(os.environ)

Expand All @@ -255,28 +260,76 @@ def launch_spin(self):
if self._config_file_name:
args.top_level_options["local-config-file"] = self._config_file_name

print(f"Args Entrypoint updated is {args.entrypoint}")
env.update(args.get_env())
env["PYTHONUNBUFFERED"] = "x"

stdout_buffer = TruncatedBuffer("stdout", self._max_log_size)
stderr_buffer = TruncatedBuffer("stderr", self._max_log_size)

cmdline = args.get_args()
print(f"Command line is: {cmdline}")
self._logger(f"Launching command: {' '.join(cmdline)}", system_msg=True)

process = subprocess.Popen(
cmdline,
env=env,
bufsize=1,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
try:
process = subprocess.Popen(
cmdline,
env=env,
bufsize=1,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
text=True,
)
except Exception as e:
raise TaskFailed(self.task, f"Failed to launch task: {str(e)}")

while True:
stdout_line = process.stdout.readline()
if stdout_line:
self._process_output(stdout_line, stdout_buffer)

stderr_line = process.stderr.readline()
if stderr_line:
self._process_output(stderr_line, stderr_buffer, is_stderr=True)

if process.poll() is not None:
break

# Process any remaining output
for line in process.stdout:
self._process_output(line, stdout_buffer)
for line in process.stderr:
self._process_output(line, stderr_buffer, is_stderr=True)

returncode = process.wait()

self.task.save_metadata(
"runtime",
{
"return_code": returncode,
"success": returncode == 0,
},
)

if returncode != 0:
raise TaskFailed(self.task, f"Task failed with return code {returncode}")
else:
self._logger("Task finished successfully.", system_msg=True)

self.task.save_logs(
{
"stdout": stdout_buffer.get_buffer(),
"stderr": stderr_buffer.get_buffer(),
}
)

# Read and print subprocess output
stdout, stderr = process.communicate()
print("STDOUT:\n")
print(f"{stdout.decode()}")
print("-" * 100)
print("STDERR:\n")
print(f"stderr: {stderr.decode()}")
def _process_output(self, line, buffer, is_stderr=False):
buffer.write(line.encode(self._encoding))
text = line.strip()
self.task.log(
text,
system_msg=False,
timestamp=datetime.now(),
)


class NativeRuntime(object):
Expand Down

0 comments on commit 23d7f12

Please sign in to comment.