Skip to content

feat: improve task polling #409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions src/posit/connect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from __future__ import annotations

import time

from typing_extensions import overload

from . import resources
Expand Down Expand Up @@ -97,39 +95,53 @@ def update(self, *args, **kwargs) -> None:
result = response.json()
super().update(**result)

def wait_for(self, *, initial_wait: int = 1, max_wait: int = 10, backoff: float = 1.5) -> None:
def wait_for(self, *, wait: int = 1, max_attempts: int | None = None) -> None:
"""Wait for the task to finish.

Parameters
----------
initial_wait : int, default 1
Initial wait time in seconds. First API request will use this as the wait parameter.
max_wait : int, default 10
wait : int, default 1
Maximum wait time in seconds between polling requests.
backoff : float, default 1.5
Backoff multiplier for increasing wait times.
max_attempts : int | None, default None
Maximum number of polling attempts. If None, polling will continue indefinitely.

Raises
------
TimeoutError
If the task does not finish within the maximum attempts.

Notes
-----
If the task finishes before the wait time or maximum attempts are reached, the function will return immediately. For example, if the wait time is set to 5 seconds and the task finishes in 2 seconds, the function will return after 2 seconds.

If the task does not finished after the maximum attempts, a TimeoutError will be raised. By default, the maximum attempts is None, which means the function will wait indefinitely until the task finishes.

Examples
--------
>>> task.wait_for()
None

Notes
-----
This method implements an exponential backoff strategy to reduce the number of API calls
while waiting for long-running tasks. The first request uses the initial_wait value,
and subsequent requests increase the wait time by the backoff factor, up to max_wait. To disable exponential backoff, set backoff to 1.0.
"""
wait_time = initial_wait
Waiting for a task to finish with a custom wait time.

while not self.is_finished:
self.update()
>>> task.wait_for(wait=5)
None

# Wait client-side
time.sleep(wait_time)
Waiting for a task with a maximum number of attempts.

# Calculate next wait time with backoff
wait_time = min(wait_time * backoff, max_wait)
>>> task.wait_for(max_attempts=3)
None
"""
attempts = 0
while not self.is_finished:
if max_attempts is not None and attempts >= max_attempts:
break
self.update(wait=wait)
attempts += 1

if not self.is_finished:
raise TimeoutError(
f"Task {self['id']} did not finish within the specified wait time or maximum attempts."
)


class Tasks(resources.Resources):
Expand Down
63 changes: 33 additions & 30 deletions tests/posit/connect/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest import mock

import pytest
import responses
from responses import BaseResponse, matchers

Expand Down Expand Up @@ -118,6 +119,7 @@ def test(self):
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": True},
match=[matchers.query_param_matcher({"wait": 1})],
),
]

Expand All @@ -127,35 +129,27 @@ def test(self):
assert not task.is_finished

# invoke
task.wait_for()
task.wait_for(wait=1)

# assert
assert task.is_finished
assert mock_tasks_get[0].call_count == 1
assert mock_tasks_get[1].call_count == 1

@responses.activate
@mock.patch("time.sleep", autospec=True)
def test_exponential_backoff(self, mock_sleep):
def test_with_custom_wait(self):
uid = "jXhOhdm5OOSkGhJw"

# behavior
mock_tasks_get = [
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": False},
),
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": False},
),
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": False},
),
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": True},
match=[matchers.query_param_matcher({"wait": 5})],
),
]

Expand All @@ -165,31 +159,19 @@ def test_exponential_backoff(self, mock_sleep):
assert not task.is_finished

# invoke
task.wait_for(initial_wait=1, max_wait=5, backoff=2.0)
task.wait_for(wait=5)

# assert
assert task.is_finished
assert mock_tasks_get[0].call_count == 1
assert mock_tasks_get[1].call_count == 1

# Verify sleep calls
mock_sleep.assert_has_calls([mock.call(1), mock.call(2), mock.call(4)], any_order=False)

@responses.activate
@mock.patch("time.sleep", autospec=True)
def test_no_backoff(self, mock_sleep):
def test_immediate_completion(self):
uid = "jXhOhdm5OOSkGhJw"

# behavior
mock_tasks_get = [
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": False},
),
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": False},
),
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": True},
Expand All @@ -199,18 +181,39 @@ def test_no_backoff(self, mock_sleep):
# setup
c = connect.Client("https://connect.example", "12345")
task = c.tasks.get(uid)
assert not task.is_finished
assert task.is_finished

# invoke
task.wait_for(initial_wait=2, max_wait=5, backoff=1.0)
task.wait_for(wait=1)

# assert
assert task.is_finished
assert mock_tasks_get[0].call_count == 1
assert mock_tasks_get[1].call_count == 1

# Verify sleep calls
mock_sleep.assert_has_calls([mock.call(2), mock.call(2)], any_order=False)
@responses.activate
def test_maximum_attempts(self):
uid = "jXhOhdm5OOSkGhJw"

# behavior
mock_tasks_get = [
responses.get(
f"https://connect.example/__api__/v1/tasks/{uid}",
json={**load_mock_dict(f"v1/tasks/{uid}.json"), "finished": False},
),
]

# setup
c = connect.Client("https://connect.example", "12345")
task = c.tasks.get(uid)
assert not task.is_finished

# invoke and assert
with pytest.raises(TimeoutError):
task.wait_for(wait=1, max_attempts=1)

# assert
assert not task.is_finished
assert mock_tasks_get[0].call_count == 2 # 1 for initial check, 1 for timeout check


class TestTasksGet:
Expand Down