Skip to content

Commit

Permalink
dummy commit
Browse files Browse the repository at this point in the history
  • Loading branch information
talsperre committed Jan 23, 2025
1 parent 23d7f12 commit f4df5db
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 0 deletions.
Empty file.
88 changes: 88 additions & 0 deletions metaflow/datastore/spin_datastore/inputs_datastore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from . import SpinDatastore


class SpinInput(object):
def __init__(self, artifacts, task):
self.artifacts = artifacts
self.task = task

def __getattr__(self, name):
# We always look for any artifacts provided by the user first
if self.artifacts is not None and name in self.artifacts:
return self.artifacts[name]

try:
return getattr(self.task.artifacts, name).data
except AttributeError:
raise AttributeError(
f"Attribute '{name}' not found in the previous execution of the task for "
f"`{self.step_name}`."
)


class StaticSpinInputsDatastore(SpinDatastore):
def __init__(self, spin_parser_validator):
super(StaticSpinInputsDatastore, self).__init__(spin_parser_validator)
self._previous_tasks = {}

def __getattr__(self, name):
if name not in self.previous_steps:
raise AttributeError(
f"Attribute '{name}' not found in the previous execution of the task for "
f"`{self.step_name}`."
)

input_step = SpinInput(
self.spin_parser_validator.artifacts["join"][name],
self.get_previous_tasks[name],
)
setattr(self, name, input_step)
return input_step

def __iter__(self):
for prev_step_name in self.previous_steps:
yield self[prev_step_name]

def __len__(self):
return len(self.get_previous_tasks)

@property
def get_previous_tasks(self):
if self._previous_tasks:
return self._previous_tasks

for prev_step_name in self.previous_steps:
previous_task = self.get_all_previous_tasks(prev_step_name)
self._previous_tasks[prev_step_name] = previous_task
return self._previous_tasks


class SpinInputsDatastore(SpinDatastore):
def __init__(self, spin_parser_validator):
super(SpinInputsDatastore, self).__init__(spin_parser_validator)
self._previous_tasks = None

def __len__(self):
return len(self.get_previous_tasks)

def __getitem__(self, idx):
_item_task = self.get_previous_tasks[idx]
_item_artifacts = self.spin_parser_validator.artifacts
# _item_artifacts = self.spin_parser_validator.artifacts[idx]
return SpinInput(_item_artifacts, _item_task)

def __iter__(self):
for idx in range(len(self.get_previous_tasks)):
yield self[idx]

@property
def get_previous_tasks(self):
if self._previous_tasks:
return self._previous_tasks

# This a join step for a foreach split, so only has one previous step
prev_step_name = self.previous_steps[0]
self._previous_tasks = self.get_all_previous_tasks(prev_step_name)
# Sort the tasks by index
self._previous_tasks = sorted(self._previous_tasks, key=lambda x: x.index)
return self._previous_tasks
95 changes: 95 additions & 0 deletions metaflow/datastore/spin_datastore/step_datastore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
class LinearStepDatastore(object):
def __init__(self, task_pathspec):
from metaflow import Task

self._task_pathspec = task_pathspec
self._task = Task(task_pathspec, _namespace_check=False)
self._previous_task = None
self._data = {}

# Set them to empty dictionaries in order to persist artifacts
# See `persist` method in `TaskDatastore` for more details
self._objects = {}
self._info = {}

def __contains__(self, name):
try:
_ = self.__getattr__(name)
except AttributeError:
return False
return True

def __getitem__(self, name):
return self.__getattr__(name)

def __setitem__(self, name, value):
self._data[name] = value

def __getattr__(self, name):
# Check internal data first
if name in self._data:
return self._data[name]

# We always look for any artifacts provided by the user first
if name in self.artifacts:
return self.artifacts[name]

if self.run_id is None:
raise AttributeError(
f"Attribute '{name}' not provided by the user and no `run_id` was provided. "
)

# If the linear step is part of a foreach step, we need to set the input attribute
# and the index attribute
if name == "input":
if not self._task.index:
raise AttributeError(
f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step."
)

foreach_stack = self._task["_foreach_stack"].data
foreach_index = foreach_stack[-1].index
foreach_var = foreach_stack[-1].var

# Fetch the artifact corresponding to the foreach var and index from the previous task
input_val = self.previous_task[foreach_var].data[foreach_index]
setattr(self, name, input_val)
return input_val

# If the linear step is part of a foreach step, we need to set the index attribute
if name == "index":
if not self._task.index:
raise AttributeError(
f"Attribute '{name}' does not exist for step `{self.step_name}` as it is not part of a foreach step."
)
foreach_stack = self._task["_foreach_stack"].data
foreach_index = foreach_stack[-1].index
setattr(self, name, foreach_index)
return foreach_index

# If the user has not provided the artifact, we look for it in the
# task using the client API
try:
return getattr(self.previous_task.artifacts, name).data
except AttributeError:
raise AttributeError(
f"Attribute '{name}' not found in the previous execution of the task for "
f"`{self.step_name}`."
)

@property
def previous_task(self):
# This is a linear step, so we only have one immediate ancestor
if self._previous_task:
return self._previous_task

prev_task_pathspecs = self._task.immediate_ancestors
prev_task_pathspec = list(chain.from_iterable(prev_task_pathspecs.values()))[0]
self._previous_task = Task(prev_task_pathspec, _namespace_check=False)
return self._previous_task

def get(self, key, default=None):
try:
return self.__getattr__(key)
except AttributeError:
return default

0 comments on commit f4df5db

Please sign in to comment.