@@ -98,7 +98,7 @@ def set_custom_status(self, custom_status: Any) -> None:
9898 pass
9999
100100 @abstractmethod
101- def create_timer (self , fire_at : Union [datetime , timedelta ]) -> Task :
101+ def create_timer (self , fire_at : Union [datetime , timedelta ]) -> CancellableTask :
102102 """Create a Timer Task to fire after at the specified deadline.
103103
104104 Parameters
@@ -228,10 +228,10 @@ def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput
228228 """
229229 pass
230230
231- # TOOD: Add a timeout parameter, which allows the task to be canceled if the event is
231+ # TOOD: Add a timeout parameter, which allows the task to be cancelled if the event is
232232 # not received within the specified timeout. This requires support for task cancellation.
233233 @abstractmethod
234- def wait_for_external_event (self , name : str ) -> CompletableTask :
234+ def wait_for_external_event (self , name : str ) -> CancellableTask :
235235 """Wait asynchronously for an event to be raised with the name `name`.
236236
237237 Parameters
@@ -324,6 +324,10 @@ class OrchestrationStateError(Exception):
324324 pass
325325
326326
327+ class TaskCancelledError (Exception ):
328+ """Exception type for cancelled orchestration tasks."""
329+
330+
327331class Task (ABC , Generic [T ]):
328332 """Abstract base class for asynchronous tasks in a durable orchestration."""
329333 _result : T
@@ -435,6 +439,48 @@ def fail(self, message: str, details: Union[Exception, pb.TaskFailureDetails]):
435439 self ._parent .on_child_completed (self )
436440
437441
442+ class CancellableTask (CompletableTask [T ]):
443+ """A completable task that can be cancelled before it finishes."""
444+
445+ def __init__ (self ) -> None :
446+ super ().__init__ ()
447+ self ._is_cancelled = False
448+ self ._cancel_handler : Optional [Callable [[], None ]] = None
449+
450+ @property
451+ def is_cancelled (self ) -> bool :
452+ """Returns True if the task was cancelled, False otherwise."""
453+ return self ._is_cancelled
454+
455+ def get_result (self ) -> T :
456+ if self ._is_cancelled :
457+ raise TaskCancelledError ('The task was cancelled.' )
458+ return super ().get_result ()
459+
460+ def set_cancel_handler (self , cancel_handler : Callable [[], None ]) -> None :
461+ self ._cancel_handler = cancel_handler
462+
463+ def cancel (self ) -> bool :
464+ """Attempts to cancel this task.
465+
466+ Returns
467+ -------
468+ bool
469+ True if cancellation was applied, False if the task had already completed.
470+ """
471+ if self ._is_complete :
472+ return False
473+
474+ if self ._cancel_handler is not None :
475+ self ._cancel_handler ()
476+
477+ self ._is_cancelled = True
478+ self ._is_complete = True
479+ if self ._parent is not None :
480+ self ._parent .on_child_completed (self )
481+ return True
482+
483+
438484class RetryableTask (CompletableTask [T ]):
439485 """A task that can be retried according to a retry policy."""
440486
@@ -474,14 +520,29 @@ def compute_next_delay(self) -> Optional[timedelta]:
474520 return None
475521
476522
477- class TimerTask (CompletableTask [ T ]):
478-
479- def __init__ ( self ) -> None :
523+ class TimerTask (CancellableTask [ None ]):
524+ def __init__ ( self , final_fire_at : Optional [ datetime ] = None ,
525+ maximum_timer_interval : Optional [ timedelta ] = None ) :
480526 super ().__init__ ()
527+ self ._final_fire_at = final_fire_at
528+ self ._maximum_timer_interval = maximum_timer_interval
481529
482530 def set_retryable_parent (self , retryable_task : RetryableTask ):
483531 self ._retryable_parent = retryable_task
484532
533+ def _handle_timer_fired (self , current_utc_datetime : datetime ) -> Optional [datetime ]:
534+ if (self ._final_fire_at is not None
535+ and self ._maximum_timer_interval is not None
536+ and current_utc_datetime < self ._final_fire_at ):
537+ return self ._get_next_fire_at (current_utc_datetime )
538+ super ().complete (None )
539+ return None
540+
541+ def _get_next_fire_at (self , current_utc_datetime : datetime ) -> datetime :
542+ if current_utc_datetime + self ._maximum_timer_interval < self ._final_fire_at :
543+ return current_utc_datetime + self ._maximum_timer_interval
544+ return self ._final_fire_at
545+
485546
486547class WhenAnyTask (CompositeTask [Task ]):
487548 """A task that completes when any of its child tasks complete."""
0 commit comments