Skip to content

Commit 13a38db

Browse files
feat(AsyncRetriever): slice information during the polling request, Support creation_response interpolation in body (#541)
1 parent d7ebfd9 commit 13a38db

File tree

4 files changed: 155 additions and 114 deletions.

airbyte_cdk/sources/declarative/requesters/http_job_repository.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -320,14 +320,14 @@ def _get_polling_response_interpolation_context(self, job: AsyncJob) -> Dict[str
320320
return polling_response_context
321321

322322
def _get_create_job_stream_slice(self, job: AsyncJob) -> StreamSlice:
323-
stream_slice = StreamSlice(
324-
partition={},
325-
cursor_slice={},
326-
extra_fields={
323+
return StreamSlice(
324+
partition=job.job_parameters().partition,
325+
cursor_slice=job.job_parameters().cursor_slice,
326+
extra_fields=dict(job.job_parameters().extra_fields)
327+
| {
327328
"creation_response": self._get_creation_response_interpolation_context(job),
328329
},
329330
)
330-
return stream_slice
331331

332332
def _get_download_targets(self, job: AsyncJob) -> Iterable[str]:
333333
if not self.download_target_requester:

airbyte_cdk/sources/declarative/requesters/request_options/interpolated_nested_request_input_provider.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
)
1212
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
1313
from airbyte_cdk.sources.types import Config, StreamSlice
14+
from airbyte_cdk.utils.mapping_helpers import get_interpolation_context
1415

1516

1617
@dataclass
@@ -52,8 +53,8 @@ def eval_request_inputs(
5253
:param next_page_token: The pagination token
5354
:return: The request inputs to set on an outgoing HTTP request
5455
"""
55-
kwargs = {
56-
"stream_slice": stream_slice,
57-
"next_page_token": next_page_token,
58-
}
56+
kwargs = get_interpolation_context(
57+
stream_slice=stream_slice,
58+
next_page_token=next_page_token,
59+
)
5960
return self._interpolator.eval(self.config, **kwargs) # type: ignore # self._interpolator is always initialized with a value and will not be None

airbyte_cdk/sources/declarative/requesters/request_options/interpolated_request_input_provider.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
99
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
1010
from airbyte_cdk.sources.types import Config, StreamSlice, StreamState
11+
from airbyte_cdk.utils.mapping_helpers import get_interpolation_context
1112

1213

1314
@dataclass
@@ -51,10 +52,10 @@ def eval_request_inputs(
5152
:param valid_value_types: A tuple of types that the interpolator should allow
5253
:return: The request inputs to set on an outgoing HTTP request
5354
"""
54-
kwargs = {
55-
"stream_slice": stream_slice,
56-
"next_page_token": next_page_token,
57-
}
55+
kwargs = get_interpolation_context(
56+
stream_slice=stream_slice,
57+
next_page_token=next_page_token,
58+
)
5859
interpolated_value = self._interpolator.eval( # type: ignore # self._interpolator is always initialized with a value and will not be None
5960
self.config,
6061
valid_key_types=valid_key_types,

unit_tests/sources/declarative/requesters/test_http_job_repository.py

Lines changed: 140 additions & 101 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33

44
import json
5+
from typing import Optional
56
from unittest import TestCase
67
from unittest.mock import Mock
78

@@ -28,6 +29,8 @@
2829
)
2930
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
3031
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
32+
from airbyte_cdk.sources.message import MessageRepository
33+
from airbyte_cdk.sources.streams.http.error_handlers import ErrorHandler
3134
from airbyte_cdk.sources.types import StreamSlice
3235
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
3336
from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
@@ -45,111 +48,12 @@
4548
a_record_id,a_value
4649
"""
4750
_A_CURSOR_FOR_PAGINATION = "a-cursor-for-pagination"
51+
_ERROR_HANDLER = DefaultErrorHandler(config=_ANY_CONFIG, parameters={})
4852

4953

5054
class HttpJobRepositoryTest(TestCase):
5155
def setUp(self) -> None:
52-
message_repository = Mock()
53-
error_handler = DefaultErrorHandler(config=_ANY_CONFIG, parameters={})
54-
55-
self._create_job_requester = HttpRequester(
56-
name="stream <name>: create_job",
57-
url_base=_URL_BASE,
58-
path=_EXPORT_PATH,
59-
error_handler=error_handler,
60-
http_method=HttpMethod.POST,
61-
config=_ANY_CONFIG,
62-
disable_retries=False,
63-
parameters={},
64-
message_repository=message_repository,
65-
use_cache=False,
66-
stream_response=False,
67-
)
68-
69-
self._polling_job_requester = HttpRequester(
70-
name="stream <name>: polling",
71-
url_base=_URL_BASE,
72-
path=_EXPORT_PATH + "/{{creation_response['id']}}",
73-
error_handler=error_handler,
74-
http_method=HttpMethod.GET,
75-
config=_ANY_CONFIG,
76-
disable_retries=False,
77-
parameters={},
78-
message_repository=message_repository,
79-
use_cache=False,
80-
stream_response=False,
81-
)
82-
83-
self._download_retriever = SimpleRetriever(
84-
requester=HttpRequester(
85-
name="stream <name>: fetch_result",
86-
url_base="",
87-
path="{{download_target}}",
88-
error_handler=error_handler,
89-
http_method=HttpMethod.GET,
90-
config=_ANY_CONFIG,
91-
disable_retries=False,
92-
parameters={},
93-
message_repository=message_repository,
94-
use_cache=False,
95-
stream_response=True,
96-
),
97-
record_selector=RecordSelector(
98-
extractor=ResponseToFileExtractor({}),
99-
record_filter=None,
100-
transformations=[],
101-
schema_normalization=TypeTransformer(TransformConfig.NoTransform),
102-
config=_ANY_CONFIG,
103-
parameters={},
104-
),
105-
primary_key=None,
106-
name="any name",
107-
paginator=DefaultPaginator(
108-
decoder=NoopDecoder(),
109-
page_size_option=None,
110-
page_token_option=RequestOption(
111-
field_name="locator",
112-
inject_into=RequestOptionType.request_parameter,
113-
parameters={},
114-
),
115-
pagination_strategy=CursorPaginationStrategy(
116-
cursor_value="{{ headers['Sforce-Locator'] }}",
117-
decoder=NoopDecoder(),
118-
config=_ANY_CONFIG,
119-
parameters={},
120-
),
121-
url_base=_URL_BASE,
122-
config=_ANY_CONFIG,
123-
parameters={},
124-
),
125-
config=_ANY_CONFIG,
126-
parameters={},
127-
)
128-
129-
self._repository = AsyncHttpJobRepository(
130-
creation_requester=self._create_job_requester,
131-
polling_requester=self._polling_job_requester,
132-
download_retriever=self._download_retriever,
133-
abort_requester=None,
134-
delete_requester=None,
135-
status_extractor=DpathExtractor(
136-
decoder=JsonDecoder(parameters={}),
137-
field_path=["status"],
138-
config={},
139-
parameters={} or {},
140-
),
141-
status_mapping={
142-
"ready": AsyncJobStatus.COMPLETED,
143-
"failure": AsyncJobStatus.FAILED,
144-
"pending": AsyncJobStatus.RUNNING,
145-
},
146-
download_target_extractor=DpathExtractor(
147-
decoder=JsonDecoder(parameters={}),
148-
field_path=["urls"],
149-
config={},
150-
parameters={} or {},
151-
),
152-
)
56+
self._repository = self._create_async_job_repository()
15357

15458
self._http_mocker = HttpMocker()
15559
self._http_mocker.__enter__()
@@ -178,6 +82,35 @@ def test_given_different_statuses_when_update_jobs_status_then_update_status_pro
17882
self._repository.update_jobs_status([job])
17983
assert job.status() == AsyncJobStatus.COMPLETED
18084

85+
def test_when_update_jobs_status_then_allow_access_to_stream_slice_information(self) -> None:
86+
stream_slice = StreamSlice(partition={"path": "path_from_slice"}, cursor_slice={})
87+
self._mock_create_response(_A_JOB_ID)
88+
self._http_mocker.get(
89+
HttpRequest(url=f"{_EXPORT_URL}/{stream_slice['path']}/{_A_JOB_ID}"),
90+
HttpResponse(body=json.dumps({"id": _A_JOB_ID, "status": "ready"})),
91+
)
92+
repository = self._create_async_job_repository(
93+
HttpRequester(
94+
name="stream <name>: polling",
95+
url_base=_URL_BASE,
96+
path=_EXPORT_PATH + "/{{stream_slice['path']}}/{{creation_response['id']}}",
97+
error_handler=_ERROR_HANDLER,
98+
http_method=HttpMethod.GET,
99+
config=_ANY_CONFIG,
100+
disable_retries=False,
101+
parameters={},
102+
message_repository=Mock(),
103+
# NOTE: this Mock is a separate instance from the message_repository used by the other components of the async job repository. If message_repository becomes important for tests, share a single instance across all components.
104+
use_cache=False,
105+
stream_response=False,
106+
)
107+
)
108+
109+
job = repository.start(stream_slice)
110+
repository.update_jobs_status([job])
111+
112+
assert job.status() == AsyncJobStatus.COMPLETED
113+
181114
def test_given_unknown_status_when_update_jobs_status_then_raise_error(self) -> None:
182115
self._mock_create_response(_A_JOB_ID)
183116
self._http_mocker.get(
@@ -277,3 +210,109 @@ def _mock_create_response(self, job_id: str) -> None:
277210
HttpRequest(url=_EXPORT_URL),
278211
HttpResponse(body=json.dumps({"id": job_id})),
279212
)
213+
214+
def _create_async_job_repository(
215+
self, polling_job_requester: Optional[HttpRequester] = None
216+
) -> AsyncHttpJobRepository:
217+
message_repository = Mock()
218+
create_job_requester = HttpRequester(
219+
name="stream <name>: create_job",
220+
url_base=_URL_BASE,
221+
path=_EXPORT_PATH,
222+
error_handler=_ERROR_HANDLER,
223+
http_method=HttpMethod.POST,
224+
config=_ANY_CONFIG,
225+
disable_retries=False,
226+
parameters={},
227+
message_repository=message_repository,
228+
use_cache=False,
229+
stream_response=False,
230+
)
231+
polling_job_requester = (
232+
polling_job_requester
233+
if polling_job_requester
234+
else HttpRequester(
235+
name="stream <name>: polling",
236+
url_base=_URL_BASE,
237+
path=_EXPORT_PATH + "/{{creation_response['id']}}",
238+
error_handler=_ERROR_HANDLER,
239+
http_method=HttpMethod.GET,
240+
config=_ANY_CONFIG,
241+
disable_retries=False,
242+
parameters={},
243+
message_repository=message_repository,
244+
use_cache=False,
245+
stream_response=False,
246+
)
247+
)
248+
249+
download_retriever = SimpleRetriever(
250+
requester=HttpRequester(
251+
name="stream <name>: fetch_result",
252+
url_base="",
253+
path="{{download_target}}",
254+
error_handler=_ERROR_HANDLER,
255+
http_method=HttpMethod.GET,
256+
config=_ANY_CONFIG,
257+
disable_retries=False,
258+
parameters={},
259+
message_repository=message_repository,
260+
use_cache=False,
261+
stream_response=True,
262+
),
263+
record_selector=RecordSelector(
264+
extractor=ResponseToFileExtractor({}),
265+
record_filter=None,
266+
transformations=[],
267+
schema_normalization=TypeTransformer(TransformConfig.NoTransform),
268+
config=_ANY_CONFIG,
269+
parameters={},
270+
),
271+
primary_key=None,
272+
name="any name",
273+
paginator=DefaultPaginator(
274+
decoder=NoopDecoder(),
275+
page_size_option=None,
276+
page_token_option=RequestOption(
277+
field_name="locator",
278+
inject_into=RequestOptionType.request_parameter,
279+
parameters={},
280+
),
281+
pagination_strategy=CursorPaginationStrategy(
282+
cursor_value="{{ headers['Sforce-Locator'] }}",
283+
decoder=NoopDecoder(),
284+
config=_ANY_CONFIG,
285+
parameters={},
286+
),
287+
url_base=_URL_BASE,
288+
config=_ANY_CONFIG,
289+
parameters={},
290+
),
291+
config=_ANY_CONFIG,
292+
parameters={},
293+
)
294+
295+
return AsyncHttpJobRepository(
296+
creation_requester=create_job_requester,
297+
polling_requester=polling_job_requester,
298+
download_retriever=download_retriever,
299+
abort_requester=None,
300+
delete_requester=None,
301+
status_extractor=DpathExtractor(
302+
decoder=JsonDecoder(parameters={}),
303+
field_path=["status"],
304+
config={},
305+
parameters={} or {},
306+
),
307+
status_mapping={
308+
"ready": AsyncJobStatus.COMPLETED,
309+
"failure": AsyncJobStatus.FAILED,
310+
"pending": AsyncJobStatus.RUNNING,
311+
},
312+
download_target_extractor=DpathExtractor(
313+
decoder=JsonDecoder(parameters={}),
314+
field_path=["urls"],
315+
config={},
316+
parameters={} or {},
317+
),
318+
)

0 commit comments

Comments
 (0)