-
Notifications
You must be signed in to change notification settings - Fork 789
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
183 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from . import SpinDatastore | ||
|
||
|
||
class SpinInput(object):
    """Attribute-style view over the data of one previous task.

    Lookups are resolved in two stages: the user-supplied ``artifacts``
    mapping is consulted first, then ``task.artifacts.<name>.data``.

    Parameters
    ----------
    artifacts : mapping or None
        User-provided overrides keyed by artifact name. ``None`` means no
        overrides were supplied.
    task : object
        Object exposing an ``artifacts`` attribute whose members have a
        ``.data`` attribute (presumably a Metaflow client ``Task`` --
        confirm against callers).
    """

    def __init__(self, artifacts, task):
        self.artifacts = artifacts
        self.task = task

    def __getattr__(self, name):
        """Resolve ``name`` from the user artifacts, then from the task.

        Raises
        ------
        AttributeError
            If ``name`` is in neither the user artifacts nor the task.
        """
        # __getattr__ only runs when normal lookup has already failed. If our
        # own fields are missing (instance created without __init__), reading
        # them below would re-enter this method forever, so fail fast.
        if name in ("artifacts", "task"):
            raise AttributeError(name)

        # We always look for any artifacts provided by the user first
        if self.artifacts is not None and name in self.artifacts:
            return self.artifacts[name]

        try:
            return getattr(self.task.artifacts, name).data
        except AttributeError:
            # This class has no `step_name`; formatting it here used to
            # recurse through __getattr__ until RecursionError. Identify the
            # task by its pathspec when available, else by the object itself.
            task_label = getattr(self.task, "pathspec", self.task)
            raise AttributeError(
                f"Attribute '{name}' not found in the previous execution of the task for "
                f"`{task_label}`."
            ) from None
|
||
|
||
class StaticSpinInputsDatastore(SpinDatastore):
    """Inputs of a join step, addressable by previous step name.

    ``inputs.<step_name>`` yields a ``SpinInput`` combining the user-provided
    artifacts for that step with the task of that step.

    NOTE(review): ``previous_steps``, ``step_name``, ``spin_parser_validator``
    and ``get_all_previous_tasks`` are not defined here; they are presumably
    provided by ``SpinDatastore`` -- confirm.
    """

    def __init__(self, spin_parser_validator):
        super().__init__(spin_parser_validator)
        # Cache: previous step name -> result of get_all_previous_tasks().
        self._previous_tasks = {}

    def __getattr__(self, name):
        """Build, cache and return the ``SpinInput`` for step ``name``.

        Raises
        ------
        AttributeError
            If ``name`` is not one of ``self.previous_steps``.
        """
        if name not in self.previous_steps:
            raise AttributeError(
                f"Attribute '{name}' not found in the previous execution of the task for "
                f"`{self.step_name}`."
            )

        # NOTE(review): assumes the user artifacts always carry a "join"
        # entry keyed by step name; a missing key raises KeyError here.
        input_step = SpinInput(
            self.spin_parser_validator.artifacts["join"][name],
            self.get_previous_tasks[name],
        )
        # Store on the instance so later accesses bypass __getattr__.
        setattr(self, name, input_step)
        return input_step

    def __iter__(self):
        """Yield one ``SpinInput`` per previous step, in ``previous_steps`` order."""
        for prev_step_name in self.previous_steps:
            # This class defines no __getitem__; the per-step object is
            # produced by attribute access, so resolve it that way.
            yield getattr(self, prev_step_name)

    def __len__(self):
        return len(self.get_previous_tasks)

    @property
    def get_previous_tasks(self):
        """Mapping of previous step name to its task(s), computed once.

        An empty cache is treated as "not yet computed".
        """
        if self._previous_tasks:
            return self._previous_tasks

        for prev_step_name in self.previous_steps:
            self._previous_tasks[prev_step_name] = self.get_all_previous_tasks(
                prev_step_name
            )
        return self._previous_tasks
|
||
|
||
class SpinInputsDatastore(SpinDatastore):
    """Index-addressable inputs for the join step of a foreach split.

    NOTE(review): ``previous_steps``, ``spin_parser_validator`` and
    ``get_all_previous_tasks`` are not defined here; they are presumably
    provided by ``SpinDatastore`` -- confirm.
    """

    def __init__(self, spin_parser_validator):
        super().__init__(spin_parser_validator)
        # Filled on first access of `get_previous_tasks`.
        self._previous_tasks = None

    def __len__(self):
        tasks = self.get_previous_tasks
        return len(tasks)

    def __getitem__(self, idx):
        """Return a ``SpinInput`` for the task at position ``idx``."""
        task_at_idx = self.get_previous_tasks[idx]
        # Every item receives the complete user artifact mapping; per-index
        # artifacts (`artifacts[idx]`) are intentionally not used yet.
        user_artifacts = self.spin_parser_validator.artifacts
        return SpinInput(user_artifacts, task_at_idx)

    def __iter__(self):
        """Yield a ``SpinInput`` for each previous task, in index order."""
        yield from (self[pos] for pos in range(len(self.get_previous_tasks)))

    @property
    def get_previous_tasks(self):
        """Tasks of the single previous step, sorted by their ``index``."""
        if not self._previous_tasks:
            # This is a join step for a foreach split, so it only has one
            # previous step.
            only_step = self.previous_steps[0]
            self._previous_tasks = self.get_all_previous_tasks(only_step)
            # Order the tasks by their foreach index.
            self._previous_tasks = sorted(
                self._previous_tasks, key=lambda task: task.index
            )
        return self._previous_tasks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
class LinearStepDatastore(object):
    """Attribute-style access to the data visible to a linear step's task.

    A name is resolved in this order:

    1. values stored through ``datastore[name] = value``;
    2. artifacts supplied by the user;
    3. the special foreach names ``input`` and ``index``;
    4. artifacts of the immediately preceding task (via the client API).

    Parameters
    ----------
    task_pathspec : str
        Pathspec of the task, of the form ``Flow/Run/Step/Task``.
    artifacts : mapping, optional
        User-provided artifacts that take precedence over the previous
        task's artifacts. Defaults to no overrides.
    """

    # Names managed by this class itself. `__getattr__` must never try to
    # resolve them as artifacts: doing so reads the very attribute that is
    # missing and recurses without bound.
    _RESERVED_NAMES = frozenset(
        (
            "_task_pathspec",
            "_task",
            "_previous_task",
            "_data",
            "_objects",
            "_info",
            "artifacts",
            "run_id",
            "step_name",
            "previous_task",
        )
    )

    def __init__(self, task_pathspec, artifacts=None):
        from metaflow import Task

        self._task_pathspec = task_pathspec
        self._task = Task(task_pathspec, _namespace_check=False)
        self._previous_task = None
        self._data = {}

        # `__getattr__` reads these three attributes; they must exist.
        self.artifacts = {} if artifacts is None else artifacts
        components = str(task_pathspec).split("/")
        has_full_path = len(components) == 4
        self.run_id = components[1] if has_full_path else None
        self.step_name = components[2] if has_full_path else None

        # Set them to empty dictionaries in order to persist artifacts
        # See `persist` method in `TaskDatastore` for more details
        self._objects = {}
        self._info = {}

    def __contains__(self, name):
        try:
            self.__getattr__(name)
        except AttributeError:
            return False
        return True

    def __getitem__(self, name):
        return self.__getattr__(name)

    def __setitem__(self, name, value):
        self._data[name] = value

    def __getattr__(self, name):
        """Resolve ``name`` following the order described on the class.

        Raises
        ------
        AttributeError
            If ``name`` cannot be resolved, or if ``input``/``index`` is
            requested for a task that is not inside a foreach.
        """
        if name in self._RESERVED_NAMES:
            raise AttributeError(name)

        # Check internal data first
        if name in self._data:
            return self._data[name]

        # We always look for any artifacts provided by the user first
        if name in self.artifacts:
            return self.artifacts[name]

        if self.run_id is None:
            raise AttributeError(
                f"Attribute '{name}' not provided by the user and no `run_id` was provided. "
            )

        # If the linear step is part of a foreach step, `input` and `index`
        # are derived from the innermost foreach frame.
        if name in ("input", "index"):
            frame = self._innermost_foreach_frame(name)
            if name == "index":
                value = frame.index
            else:
                # Fetch the artifact corresponding to the foreach var and
                # index from the previous task
                value = self.previous_task[frame.var].data[frame.index]
            # Cache on the instance so later accesses skip __getattr__.
            setattr(self, name, value)
            return value

        # If the user has not provided the artifact, we look for it in the
        # task using the client API
        try:
            return getattr(self.previous_task.artifacts, name).data
        except AttributeError:
            raise AttributeError(
                f"Attribute '{name}' not found in the previous execution of the task for "
                f"`{self.step_name}`."
            ) from None

    def _innermost_foreach_frame(self, name):
        """Return the last ``_foreach_stack`` frame, or raise for non-foreach tasks."""
        # Compare against None: index 0 is a valid foreach index and must
        # not be mistaken for "not part of a foreach".
        if self._task.index is None:
            raise AttributeError(
                f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step."
            )
        return self._task["_foreach_stack"].data[-1]

    @property
    def previous_task(self):
        """The single immediate ancestor task, loaded on first use."""
        # This is a linear step, so we only have one immediate ancestor
        if self._previous_task is not None:
            return self._previous_task

        # Imported here because neither name is available at module scope.
        from itertools import chain

        from metaflow import Task

        prev_task_pathspecs = self._task.immediate_ancestors
        prev_task_pathspec = list(chain.from_iterable(prev_task_pathspecs.values()))[0]
        self._previous_task = Task(prev_task_pathspec, _namespace_check=False)
        return self._previous_task

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` if it cannot be resolved."""
        try:
            return self.__getattr__(key)
        except AttributeError:
            return default