- 
                Notifications
    You must be signed in to change notification settings 
- Fork 931
Fixing task ID replacement for MNP jobs on AWS Batch #2574
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
2a5c211
              ce15127
              86c3b84
              f2ee285
              21a62ac
              26fa49c
              cc5b44e
              96259ac
              9fa4391
              526c81c
              a0f68ca
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -421,6 +421,86 @@ def _wait_for_mapper_tasks(self, flow, step_name): | |
| TIMEOUT = 600 | ||
| last_completion_timeout = time.time() + TIMEOUT | ||
| print("Waiting for batch secondary tasks to finish") | ||
|  | ||
| # Prefer Batch API when metadata is local (nodes can't share local metadata files). | ||
| # If metadata isn't bound yet but we are on Batch, also prefer Batch API. | ||
| md = getattr(self, "metadata", None) | ||
| if md is not None and md.TYPE == "local": | ||
| return self._wait_for_mapper_tasks_datastore( | ||
| flow, step_name, last_completion_timeout | ||
| ) | ||
| if md is None and "AWS_BATCH_JOB_ID" in os.environ: | ||
| return self._wait_for_mapper_tasks_datastore( | ||
| flow, step_name, last_completion_timeout | ||
| ) | ||
| return self._wait_for_mapper_tasks_metadata( | ||
| flow, step_name, last_completion_timeout | ||
| ) | ||
|  | ||
| def _wait_for_mapper_tasks_datastore( | ||
| self, flow, step_name, last_completion_timeout | ||
| ): | ||
| """ | ||
| Poll the shared datastore (S3) for DONE markers for each mapper task. | ||
| This avoids relying on a metadata service or local metadata files. | ||
| """ | ||
| from metaflow.datastore.task_datastore import TaskDataStore | ||
|  | ||
| pathspecs = getattr(flow, "_control_mapper_tasks", []) | ||
| total = len(pathspecs) | ||
| if total == 0: | ||
| print("No mapper tasks discovered for datastore wait; returning") | ||
| return True | ||
|  | ||
| print("Waiting for mapper DONE markers in datastore for %d tasks" % total) | ||
| poll_sleep = 3.0 | ||
| while last_completion_timeout > time.time(): | ||
| time.sleep(poll_sleep) | ||
| completed = 0 | ||
| for ps in pathspecs: | ||
| try: | ||
| parts = ps.split("/") | ||
| if len(parts) == 3: | ||
| run_id, step, task_id = parts | ||
| else: | ||
| # Fallback in case of unexpected format | ||
| run_id, step, task_id = self.run_id, step_name, parts[-1] | ||
| tds = TaskDataStore( | ||
| tds = TaskDataStore( | ||
> Review comment: I don't think so — the TaskDataStore instances are constructed with a fixed run_id/step_name/task_id, so it won't be able to query the mapper tasks.
| self.flow_datastore, | ||
| run_id, | ||
| step, | ||
| task_id, | ||
| mode="r", | ||
| allow_not_done=True, | ||
| ) | ||
| if tds.has_metadata(TaskDataStore.METADATA_DONE_SUFFIX): | ||
| completed += 1 | ||
| except Exception as e: | ||
| self.logger.warning("Datastore wait: error checking %s: %s", ps, e) | ||
| continue | ||
| if completed == total: | ||
| self.logger.info( | ||
| "All mapper tasks have written DONE markers to datastore" | ||
| ) | ||
| return True | ||
| self.logger.info( | ||
| "Waiting for mapper DONE markers. Finished: %d/%d" % (completed, total) | ||
| ) | ||
| poll_sleep = min(poll_sleep * 1.25, 10.0) | ||
|  | ||
| raise Exception( | ||
| "Batch secondary workers did not finish in %s seconds (datastore wait)" | ||
| % (time.time() - (last_completion_timeout - 600)) | ||
| ) | ||
|  | ||
| def _wait_for_mapper_tasks_metadata(self, flow, step_name, last_completion_timeout): | ||
| """ | ||
| Polls Metaflow metadata (Step client) for task completion. | ||
| Works with service-backed metadata providers but can fail with local metadata | ||
| in multi-node setups due to isolated per-node filesystems. | ||
| """ | ||
| from metaflow import Step | ||
|  | ||
| while last_completion_timeout > time.time(): | ||
| time.sleep(2) | ||
| try: | ||
|  | @@ -432,7 +512,7 @@ def _wait_for_mapper_tasks(self, flow, step_name): | |
| ): # for some reason task.finished fails | ||
| return True | ||
| else: | ||
| print( | ||
| self.logger.info( | ||
| "Waiting for all parallel tasks to finish. Finished: {}/{}".format( | ||
| len(tasks), | ||
| len(flow._control_mapper_tasks), | ||
|  | @@ -441,7 +521,8 @@ def _wait_for_mapper_tasks(self, flow, step_name): | |
| except Exception: | ||
| pass | ||
| raise Exception( | ||
| "Batch secondary workers did not finish in %s seconds" % TIMEOUT | ||
| "Batch secondary workers did not finish in %s seconds" | ||
| % (time.time() - (last_completion_timeout - 600)) | ||
| ) | ||
|  | ||
| @classmethod | ||
|  | ||
> Review comment: Is this change relevant to the batch parallel issue, or something different? The PR seems to work fine without this part as well.
> Author reply: Indeed it works; this part was to cover the cases where particular AWS keys had already been set in the environment, which interfered with obtaining the AWS client. This is relevant for the batch process, given that we're using the batch client now.