From f4df5db66a4757802a28bd1ddd76756839038d52 Mon Sep 17 00:00:00 2001
From: Shashank Srikanth
Date: Wed, 15 Jan 2025 14:03:11 -0800
Subject: [PATCH] dummy commit

---
Review note: the blob ids on the `index` lines below predate the review edits.

 metaflow/datastore/spin_datastore/__init__.py |   0
 .../spin_datastore/inputs_datastore.py        | 144 ++++++++++++++++
 .../spin_datastore/step_datastore.py          | 162 ++++++++++++++++++
 3 files changed, 306 insertions(+)
 create mode 100644 metaflow/datastore/spin_datastore/__init__.py
 create mode 100644 metaflow/datastore/spin_datastore/inputs_datastore.py
 create mode 100644 metaflow/datastore/spin_datastore/step_datastore.py

diff --git a/metaflow/datastore/spin_datastore/__init__.py b/metaflow/datastore/spin_datastore/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/metaflow/datastore/spin_datastore/inputs_datastore.py b/metaflow/datastore/spin_datastore/inputs_datastore.py
new file mode 100644
index 00000000000..898048869a7
--- /dev/null
+++ b/metaflow/datastore/spin_datastore/inputs_datastore.py
@@ -0,0 +1,144 @@
+"""Input datastores for join steps, part of the ``spin_datastore`` package."""
+
+# NOTE(review): this patch adds ``spin_datastore/__init__.py`` as an empty file,
+# so ``SpinDatastore`` has to be defined there before this import can succeed.
+from . import SpinDatastore
+
+
+class SpinInput(object):
+    """Attribute-style view over the artifacts of one input task.
+
+    Artifacts supplied by the user shadow the ones recorded by the task.
+    """
+
+    def __init__(self, artifacts, task):
+        # ``artifacts`` may be None; ``__getattr__`` then goes straight to the task.
+        self.artifacts = artifacts
+        self.task = task
+
+    def __getattr__(self, name):
+        """Return artifact ``name``, preferring user-provided values.
+
+        Raises
+        ------
+        AttributeError
+            If neither the user artifacts nor the task provide ``name``.
+        """
+        # Dunder probes (copy/pickle) and our own fields on an instance whose
+        # ``__init__`` has not run must fail here instead of recursing.
+        if name in ("artifacts", "task") or (
+            name.startswith("__") and name.endswith("__")
+        ):
+            raise AttributeError(name)
+
+        # We always look for any artifacts provided by the user first
+        if self.artifacts is not None and name in self.artifacts:
+            return self.artifacts[name]
+
+        try:
+            return getattr(self.task.artifacts, name).data
+        except AttributeError as err:
+            # This class has no ``step_name``; reading it while formatting the
+            # message re-entered ``__getattr__`` and ended in a RecursionError.
+            task_id = getattr(self.task, "pathspec", self.task)
+            raise AttributeError(
+                f"Attribute '{name}' not found in the previous execution of the "
+                f"task `{task_id}`."
+            ) from err
+
+
+class StaticSpinInputsDatastore(SpinDatastore):
+    """Inputs of a join step, addressed by the name of each previous step.
+
+    ``previous_steps``, ``step_name``, ``spin_parser_validator`` and
+    ``get_all_previous_tasks`` are not defined here; presumably they come from
+    ``SpinDatastore`` - TODO confirm.
+    """
+
+    def __init__(self, spin_parser_validator):
+        super(StaticSpinInputsDatastore, self).__init__(spin_parser_validator)
+        # Maps each previous step name to what ``get_all_previous_tasks`` returned.
+        self._previous_tasks = {}
+
+    def __getattr__(self, name):
+        """Build, cache and return the ``SpinInput`` for previous step ``name``.
+
+        Raises
+        ------
+        AttributeError
+            If ``name`` is not one of ``previous_steps``.
+        """
+        if name not in self.previous_steps:
+            raise AttributeError(
+                f"Attribute '{name}' not found in the previous execution of the task for "
+                f"`{self.step_name}`."
+            )
+
+        input_step = SpinInput(
+            self.spin_parser_validator.artifacts["join"][name],
+            self.get_previous_tasks[name],
+        )
+        # Cache on the instance so later lookups bypass ``__getattr__``.
+        setattr(self, name, input_step)
+        return input_step
+
+    def __iter__(self):
+        """Yield one ``SpinInput`` per entry of ``previous_steps``, in order."""
+        # Inputs are exposed as attributes and this class defines no
+        # ``__getitem__``, so go through attribute access, not ``self[name]``.
+        for prev_step_name in self.previous_steps:
+            yield getattr(self, prev_step_name)
+
+    def __len__(self):
+        return len(self.get_previous_tasks)
+
+    @property
+    def get_previous_tasks(self):
+        """Mapping of previous step name to its task lookup, filled lazily."""
+        if self._previous_tasks:
+            return self._previous_tasks
+
+        for prev_step_name in self.previous_steps:
+            previous_task = self.get_all_previous_tasks(prev_step_name)
+            self._previous_tasks[prev_step_name] = previous_task
+        return self._previous_tasks
+
+
+class SpinInputsDatastore(SpinDatastore):
+    """Inputs of a join over a foreach split, addressed by position."""
+
+    def __init__(self, spin_parser_validator):
+        super(SpinInputsDatastore, self).__init__(spin_parser_validator)
+        # None means "not fetched yet"; see ``get_previous_tasks``.
+        self._previous_tasks = None
+
+    def __len__(self):
+        return len(self.get_previous_tasks)
+
+    def __getitem__(self, idx):
+        """Return the ``SpinInput`` for the previous task at position ``idx``."""
+        _item_task = self.get_previous_tasks[idx]
+        # TODO(review): all inputs share the same user artifacts; the original
+        # carried a commented-out ``artifacts[idx]`` variant - confirm intent.
+        _item_artifacts = self.spin_parser_validator.artifacts
+        return SpinInput(_item_artifacts, _item_task)
+
+    def __iter__(self):
+        for idx in range(len(self.get_previous_tasks)):
+            yield self[idx]
+
+    @property
+    def get_previous_tasks(self):
+        """Previous tasks sorted by their ``index``, fetched on first use."""
+        # Compare against None so an empty result is cached as well instead
+        # of being fetched again on every access.
+        if self._previous_tasks is not None:
+            return self._previous_tasks
+
+        # This is a join step for a foreach split, so only has one previous step
+        prev_step_name = self.previous_steps[0]
+        # Sort the tasks by index
+        self._previous_tasks = sorted(
+            self.get_all_previous_tasks(prev_step_name), key=lambda x: x.index
+        )
+        return self._previous_tasks
diff --git a/metaflow/datastore/spin_datastore/step_datastore.py b/metaflow/datastore/spin_datastore/step_datastore.py
new file mode 100644
index 00000000000..a33bbfca2cb
--- /dev/null
+++ b/metaflow/datastore/spin_datastore/step_datastore.py
@@ -0,0 +1,162 @@
+"""Datastore for a linear step, backed by a previously executed task."""
+
+from itertools import chain
+
+
+class LinearStepDatastore(object):
+    """Read-through view of the artifacts visible to a linear step.
+
+    Names are resolved, in order, from values stored on this object, from
+    user-provided artifacts, and finally from the immediate ancestor of the
+    task identified by ``task_pathspec``.
+    """
+
+    def __init__(self, task_pathspec, artifacts=None):
+        """
+        Parameters
+        ----------
+        task_pathspec : str
+            Pathspec of the task; assumed to look like
+            ``flow/run_id/step_name/task_id`` - TODO confirm.
+        artifacts : dict, optional
+            User-provided artifacts, consulted before the stored ones.
+        """
+        # Function-scope import kept from the original; presumably it avoids a
+        # circular import with the ``metaflow`` package - TODO confirm.
+        from metaflow import Task
+
+        self._task_pathspec = task_pathspec
+        self._task = Task(task_pathspec, _namespace_check=False)
+        self._previous_task = None
+        self._data = {}
+
+        # ``__getattr__`` reads these three names, so they must exist as real
+        # attributes: while they were undefined, reading them re-entered
+        # ``__getattr__`` and every lookup ended in a RecursionError.
+        self.artifacts = {} if artifacts is None else artifacts
+        pathspec_parts = task_pathspec.split("/")
+        self.run_id = pathspec_parts[1] if len(pathspec_parts) > 1 else None
+        self.step_name = pathspec_parts[2] if len(pathspec_parts) > 2 else None
+
+        # Set them to empty dictionaries in order to persist artifacts
+        # See `persist` method in `TaskDatastore` for more details
+        self._objects = {}
+        self._info = {}
+
+    def __contains__(self, name):
+        try:
+            _ = self.__getattr__(name)
+        except AttributeError:
+            return False
+        return True
+
+    def __getitem__(self, name):
+        return self.__getattr__(name)
+
+    def __setitem__(self, name, value):
+        self._data[name] = value
+
+    def __getattr__(self, name):
+        """Resolve ``name`` following the order given in the class docstring.
+
+        Raises
+        ------
+        AttributeError
+            If ``name`` cannot be resolved from any source.
+        """
+        # These only arrive here when normal lookup failed (``__init__`` did
+        # not run, ``previous_task`` raised AttributeError, copy/pickle
+        # probes); carrying on would recurse without end.
+        if name in ("_data", "previous_task") or (
+            name.startswith("__") and name.endswith("__")
+        ):
+            raise AttributeError(name)
+
+        # Check internal data first
+        if name in self._data:
+            return self._data[name]
+
+        # We always look for any artifacts provided by the user first
+        if name in self.artifacts:
+            return self.artifacts[name]
+
+        if self.run_id is None:
+            raise AttributeError(
+                f"Attribute '{name}' not provided by the user and no `run_id` was provided. "
+            )
+
+        # If the linear step is part of a foreach step, we need to set the input attribute
+        # and the index attribute
+        if name == "input":
+            # 0 is a valid foreach index, so test for None rather than falsiness.
+            if self._task.index is None:
+                raise AttributeError(
+                    f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step."
+                )
+
+            foreach_stack = self._task["_foreach_stack"].data
+            foreach_index = foreach_stack[-1].index
+            foreach_var = foreach_stack[-1].var
+
+            # Fetch the artifact corresponding to the foreach var and index from the previous task
+            input_val = self.previous_task[foreach_var].data[foreach_index]
+            setattr(self, name, input_val)
+            return input_val
+
+        # If the linear step is part of a foreach step, we need to set the index attribute
+        if name == "index":
+            # 0 is a valid foreach index, so test for None rather than falsiness.
+            if self._task.index is None:
+                raise AttributeError(
+                    f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step."
+                )
+            foreach_stack = self._task["_foreach_stack"].data
+            foreach_index = foreach_stack[-1].index
+            setattr(self, name, foreach_index)
+            return foreach_index
+
+        # If the user has not provided the artifact, we look for it in the
+        # task using the client API
+        try:
+            return getattr(self.previous_task.artifacts, name).data
+        except AttributeError as err:
+            raise AttributeError(
+                f"Attribute '{name}' not found in the previous execution of the task for "
+                f"`{self.step_name}`."
+            ) from err
+
+    @property
+    def previous_task(self):
+        """Immediate ancestor of the task, fetched on first use.
+
+        Raises
+        ------
+        IndexError
+            If the task has no immediate ancestor.
+        """
+        # This is a linear step, so we only have one immediate ancestor
+        if self._previous_task is not None:
+            return self._previous_task
+
+        # ``Task`` is only bound inside ``__init__``, so import it here as well.
+        from metaflow import Task
+
+        prev_task_pathspecs = self._task.immediate_ancestors
+        prev_task_pathspec = next(
+            chain.from_iterable(prev_task_pathspecs.values()), None
+        )
+        if prev_task_pathspec is None:
+            # Deliberately not AttributeError: raised from a property it would
+            # make Python fall back to ``__getattr__("previous_task")``.
+            raise IndexError(
+                f"Task `{self._task_pathspec}` has no immediate ancestor."
+            )
+        self._previous_task = Task(prev_task_pathspec, _namespace_check=False)
+        return self._previous_task
+
+    def get(self, key, default=None):
+        """Return the value for ``key``, or ``default`` if it cannot be resolved."""
+        try:
+            return self.__getattr__(key)
+        except AttributeError:
+            return default