Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions google/resumable_media/requests/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _write_to_stream(self, response):
self.media_url, expected_md5_hash, actual_md5_hash)
raise common.DataCorruption(response, msg)

def consume(self, transport):
def consume(self, transport, **transport_kwargs):
"""Consume the resource to be downloaded.

If a ``stream`` is attached to this download, then the downloaded
Expand All @@ -162,6 +162,7 @@ def consume(self, transport):
u'headers': headers,
u'retry_strategy': self._retry_strategy,
}
request_kwargs.update(transport_kwargs)
if self._stream is not None:
request_kwargs[u'stream'] = True

Expand Down Expand Up @@ -204,7 +205,7 @@ class ChunkedDownload(_helpers.RequestsMixin, _download.ChunkedDownload):
ValueError: If ``start`` is negative.
"""

def consume_next_chunk(self, transport):
def consume_next_chunk(self, transport, **transport_kwargs):
"""Consume the next chunk of the resource to be downloaded.

Args:
Expand All @@ -221,7 +222,7 @@ def consume_next_chunk(self, transport):
# NOTE: We assume "payload is None" but pass it along anyway.
result = _helpers.http_request(
transport, method, url, data=payload, headers=headers,
retry_strategy=self._retry_strategy)
retry_strategy=self._retry_strategy, **transport_kwargs)
self._process_response(result)
return result

Expand Down
86 changes: 86 additions & 0 deletions google/resumable_media/requests/transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2019 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.auth import transport
from google.auth.transport.requests import AuthorizedSession


class TimeoutAuthorizedSession(AuthorizedSession):
"""A Requests Session class with credentials.

This class is used to perform requests to API endpoints that require
authorization::

from google.resumable_media.requests.transport import
TimeoutAuthorizedSession

authed_session = TimeoutAuthorizedSession(credentials)

response = authed_session.request(
'GET', 'https://www.googleapis.com/storage/v1/b')

The underlying :meth:`request` implementation handles adding the
credentials' headers to the request and refreshing credentials as needed.

Args:
credentials (google.auth.credentials.Credentials): The credentials to
add to the request.
refresh_status_codes (Sequence[int]): Which HTTP status codes indicate
that credentials should be refreshed and the request should be
retried.
max_refresh_attempts (int): The maximum number of times to attempt to
refresh the credentials and retry the request.
refresh_timeout (Optional[int]): The timeout value in seconds for
credential refresh HTTP requests.
auth_request (google.auth.transport.requests.Request):
(Optional) An instance of
:class:`~google.auth.transport.requests.Request` used when
refreshing credentials. If not passed,
an instance of :class:`~google.auth.transport.requests.Request`
is created.
timeout(int) : The timeout value in second for request.

"""

def __init__(self, credentials,
refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES,
max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS,
refresh_timeout=None,
auth_request=None, **kwargs):
self.timeout = None
if "timeout" in kwargs:
self.timeout = kwargs.pop("timeout")
super(TimeoutAuthorizedSession, self).__init__(
credentials,
refresh_status_codes=refresh_status_codes,
max_refresh_attempts=max_refresh_attempts,
refresh_timeout=refresh_timeout,
auth_request=auth_request)

def request(self, method, url, data=None, headers=None, **kwargs):
"""
:param method: request method 'GET', 'POST' etc.
:param url: request url
:param data: request data
:param headers: request header
:param kwargs: extra data
:return: response of timeout set request.
"""
if "timeout" not in kwargs:
kwargs['timeout'] = self.timeout
return super(TimeoutAuthorizedSession, self).request(method=method,
url=url,
data=data,
headers=headers,
**kwargs)
21 changes: 11 additions & 10 deletions google/resumable_media/requests/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class SimpleUpload(_helpers.RequestsMixin, _upload.SimpleUpload):
upload_url (str): The URL where the content will be uploaded.
"""

def transmit(self, transport, data, content_type):
def transmit(self, transport, data, content_type, **transport_kwargs):
"""Transmit the resource to be uploaded.

Args:
Expand All @@ -55,7 +55,7 @@ def transmit(self, transport, data, content_type):
data, content_type)
result = _helpers.http_request(
transport, method, url, data=payload, headers=headers,
retry_strategy=self._retry_strategy)
retry_strategy=self._retry_strategy, **transport_kwargs)
self._process_response(result)
return result

Expand All @@ -75,7 +75,8 @@ class MultipartUpload(_helpers.RequestsMixin, _upload.MultipartUpload):
upload_url (str): The URL where the content will be uploaded.
"""

def transmit(self, transport, data, metadata, content_type):
def transmit(self, transport, data, metadata, content_type,
**transport_kwargs):
"""Transmit the resource to be uploaded.

Args:
Expand All @@ -94,7 +95,7 @@ def transmit(self, transport, data, metadata, content_type):
data, metadata, content_type)
result = _helpers.http_request(
transport, method, url, data=payload, headers=headers,
retry_strategy=self._retry_strategy)
retry_strategy=self._retry_strategy, **transport_kwargs)
self._process_response(result)
return result

Expand Down Expand Up @@ -283,7 +284,7 @@ class ResumableUpload(_helpers.RequestsMixin, _upload.ResumableUpload):
"""

def initiate(self, transport, stream, metadata, content_type,
total_bytes=None, stream_final=True):
total_bytes=None, stream_final=True, **transport_kwargs):
"""Initiate a resumable upload.

By default, this method assumes your ``stream`` is in a "final"
Expand Down Expand Up @@ -323,11 +324,11 @@ def initiate(self, transport, stream, metadata, content_type,
total_bytes=total_bytes, stream_final=stream_final)
result = _helpers.http_request(
transport, method, url, data=payload, headers=headers,
retry_strategy=self._retry_strategy)
retry_strategy=self._retry_strategy, **transport_kwargs)
self._process_initiate_response(result)
return result

def transmit_next_chunk(self, transport):
def transmit_next_chunk(self, transport, **transport_kwargs):
"""Transmit the next chunk of the resource to be uploaded.

If the current upload was initiated with ``stream_final=False``,
Expand Down Expand Up @@ -392,11 +393,11 @@ def transmit_next_chunk(self, transport):
method, url, payload, headers = self._prepare_request()
result = _helpers.http_request(
transport, method, url, data=payload, headers=headers,
retry_strategy=self._retry_strategy)
retry_strategy=self._retry_strategy, **transport_kwargs)
self._process_response(result, len(payload))
return result

def recover(self, transport):
def recover(self, transport, **transport_kwargs):
"""Recover from a failure.

This method should be used when a :class:`ResumableUpload` is in an
Expand All @@ -417,6 +418,6 @@ def recover(self, transport):
# NOTE: We assume "payload is None" but pass it along anyway.
result = _helpers.http_request(
transport, method, url, data=payload, headers=headers,
retry_strategy=self._retry_strategy)
retry_strategy=self._retry_strategy, **transport_kwargs)
self._process_recover_response(result)
return result
52 changes: 52 additions & 0 deletions tests/unit/requests/test__transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2019 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
import requests
from google.resumable_media.requests.transport import TimeoutAuthorizedSession


def make_response(status=200, data=None, *args, **kwargs):
response = requests.Response()
response.status_code = status
response._content = data
response.timeout = kwargs['timeout']
return response


class TestTimeoutSession(object):
TEST_URL = 'http://example.com/'

def test_request_timeout(self):
auth_session = TimeoutAuthorizedSession(mock.sentinel.credentials,
auth_request=None, timeout=50)
assert auth_session.timeout == 50
with mock.patch('google.auth.transport.requests.AuthorizedSession.'
'request', new=make_response):
response = auth_session.request(method='GET',
url=self.TEST_URL)
assert response.timeout == 50

def test_default_request_timeout(self):
auth_session = TimeoutAuthorizedSession(mock.sentinel.credentials,
auth_request=None)
assert auth_session.timeout is None
with mock.patch('google.auth.transport.requests.AuthorizedSession.'
'request', new=make_response):
response = auth_session.request(method='GET',
url=self.TEST_URL)
assert response.timeout is None
response2 = auth_session.request(method='GET', url=self.TEST_URL,
timeout=10)
assert response2.timeout == 10
Loading