Skip to content

Commit e3aa0db

Browse files
committed
Clean up code and lint
1 parent cae8f33 commit e3aa0db

File tree

5 files changed

+146
-119
lines changed

5 files changed

+146
-119
lines changed

firebase_admin/_utils.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import json
1818
from platform import python_version
19-
from typing import Callable, Optional, Union
19+
from typing import Callable, Optional
2020

2121
import google.auth
2222
import requests
@@ -131,8 +131,8 @@ def handle_platform_error_from_requests(error, handle_func=None):
131131
return exc if exc else _handle_func_requests(error, message, error_dict)
132132

133133
def handle_platform_error_from_httpx(
134-
error: httpx.HTTPError,
135-
handle_func: Optional[Callable[...,Optional[exceptions.FirebaseError]]] = None
134+
error: httpx.HTTPError,
135+
handle_func: Optional[Callable[..., Optional[exceptions.FirebaseError]]] = None
136136
) -> exceptions.FirebaseError:
137137
"""Constructs a ``FirebaseError`` from the given httpx error.
138138
@@ -158,8 +158,7 @@ def handle_platform_error_from_httpx(
158158
exc = handle_func(error, message, error_dict)
159159

160160
return exc if exc else _handle_func_httpx(error, message, error_dict)
161-
else:
162-
return handle_httpx_error(error)
161+
return handle_httpx_error(error)
163162

164163

165164
def handle_operation_error(error):

firebase_admin/messaging.py

+75-63
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,29 @@
1515
"""Firebase Cloud Messaging module."""
1616

1717
from __future__ import annotations
18-
from typing import Callable, List, Optional, TypeVar
18+
from typing import Callable, List, Optional
1919
import concurrent.futures
2020
import json
2121
import warnings
22+
import asyncio
2223
import requests
2324
import httpx
24-
import asyncio
2525

26-
from google.auth import credentials, transport
26+
from google.auth import credentials
27+
from google.auth.transport import requests as auth_requests
2728
from googleapiclient import http
2829
from googleapiclient import _auth
2930

3031
import firebase_admin
31-
from firebase_admin import _http_client
32-
from firebase_admin import _messaging_encoder
33-
from firebase_admin import _messaging_utils
34-
from firebase_admin import _gapic_utils
35-
from firebase_admin import _utils
36-
from firebase_admin import exceptions
32+
from firebase_admin import (
33+
_http_client,
34+
_messaging_encoder,
35+
_messaging_utils,
36+
_gapic_utils,
37+
_utils,
38+
exceptions,
39+
App
40+
)
3741

3842

3943
_MESSAGING_ATTRIBUTE = '_messaging'
@@ -67,17 +71,16 @@
6771
'WebpushNotification',
6872
'WebpushNotificationAction',
6973

70-
'async_send_each'
7174
'send',
7275
'send_all',
7376
'send_multicast',
7477
'send_each',
78+
'send_each_async',
7579
'send_each_for_multicast',
7680
'subscribe_to_topic',
7781
'unsubscribe_from_topic',
78-
] # type: ignore
82+
]
7983

80-
TFirebaseError = TypeVar('TFirebaseError', bound=exceptions.FirebaseError)
8184

8285
AndroidConfig = _messaging_utils.AndroidConfig
8386
AndroidFCMOptions = _messaging_utils.AndroidFCMOptions
@@ -104,10 +107,10 @@
104107
UnregisteredError = _messaging_utils.UnregisteredError
105108

106109

107-
def _get_messaging_service(app) -> _MessagingService:
110+
def _get_messaging_service(app: Optional[App]) -> _MessagingService:
108111
return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService)
109112

110-
def send(message, dry_run=False, app=None):
113+
def send(message, dry_run=False, app: Optional[App] = None):
111114
"""Sends the given message via Firebase Cloud Messaging (FCM).
112115
113116
If the ``dry_run`` mode is enabled, the message will not be actually delivered to the
@@ -147,8 +150,8 @@ def send_each(messages, dry_run=False, app=None):
147150
"""
148151
return _get_messaging_service(app).send_each(messages, dry_run)
149152

150-
async def async_send_each(messages, dry_run=True, app: firebase_admin.App | None = None) -> BatchResponse:
151-
return await _get_messaging_service(app).async_send_each(messages, dry_run)
153+
async def send_each_async(messages, dry_run=True, app: Optional[App] = None) -> BatchResponse:
154+
return await _get_messaging_service(app).send_each_async(messages, dry_run)
152155

153156
def send_each_for_multicast(multicast_message, dry_run=False, app=None):
154157
"""Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM).
@@ -374,48 +377,53 @@ def exception(self):
374377
return self._exception
375378

376379
# Auth Flow
380+
# TODO: Remove comments
377381
# The aim here is to be able to get auth credentials right before the request is sent.
378382
# This is similar to what is done in transport.requests.AuthorizedSession().
379383
# We can then pass this in at the client level.
380-
class CustomGoogleAuth(httpx.Auth):
381-
def __init__(self, credentials: credentials.Credentials):
382-
self._credential = credentials
384+
385+
# Notes:
386+
# - This implementations does not cover timeouts on requests sent to refresh credentials.
387+
# - Uses HTTP/1 and a blocking credential for refreshing.
388+
class GoogleAuthCredentialFlow(httpx.Auth):
389+
"""Google Auth Credential Auth Flow"""
390+
def __init__(self, credential: credentials.Credentials):
391+
self._credential = credential
383392
self._max_refresh_attempts = 2
384393
self._refresh_status_codes = (401,)
385-
394+
386395
def apply_auth_headers(self, request: httpx.Request):
387396
# Build request used to refresh credentials if needed
388-
auth_request = transport.requests.Request() # type: ignore
389-
# This refreshes the credentials if needed and mutates the request headers to contain access token
390-
# and any other google auth headers
397+
auth_request = auth_requests.Request()
398+
# This refreshes the credentials if needed and mutates the request headers to
399+
# contain access token and any other google auth headers
391400
self._credential.before_request(auth_request, request.method, request.url, request.headers)
392401

393402

394403
def auth_flow(self, request: httpx.Request):
395404
# Keep original headers since `credentials.before_request` mutates the passed headers and we
396405
# want to keep the original in cause we need an auth retry.
397406
_original_headers = request.headers.copy()
398-
407+
399408
_credential_refresh_attempt = 0
400-
while (
401-
_credential_refresh_attempt < self._max_refresh_attempts
402-
):
409+
while _credential_refresh_attempt <= self._max_refresh_attempts:
403410
# copy original headers
404411
request.headers = _original_headers.copy()
405412
# mutates request headers
406413
self.apply_auth_headers(request)
407-
414+
408415
# Continue to perform the request
409416
# yield here dispatches the request and returns with the response
410417
response: httpx.Response = yield request
411-
412-
# We can check the result of the response and determine in we need to retry on refreshable status codes.
413-
# Current transport.requests.AuthorizedSession() only does this on 401 errors. We should do the same.
418+
419+
# We can check the result of the response and determine in we need to retry
420+
# on refreshable status codes. Current transport.requests.AuthorizedSession()
421+
# only does this on 401 errors. We should do the same.
414422
if response.status_code in self._refresh_status_codes:
415423
_credential_refresh_attempt += 1
416-
print(response.status_code, response.reason_phrase, _credential_refresh_attempt)
417424
else:
418-
break;
425+
break
426+
# Last yielded response is auto returned.
419427

420428

421429

@@ -453,7 +461,7 @@ def __init__(self, app) -> None:
453461
self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout)
454462
self._async_client = httpx.AsyncClient(
455463
http2=True,
456-
auth=CustomGoogleAuth(self._credential),
464+
auth=GoogleAuthCredentialFlow(self._credential),
457465
timeout=timeout,
458466
transport=HttpxRetryTransport()
459467
)
@@ -509,13 +517,13 @@ def send_data(data):
509517
message='Unknown error while making remote service calls: {0}'.format(error),
510518
cause=error)
511519

512-
async def async_send_each(self, messages: List[Message], dry_run: bool = True) -> BatchResponse:
520+
async def send_each_async(self, messages: List[Message], dry_run: bool = True) -> BatchResponse:
513521
"""Sends the given messages to FCM via the FCM v1 API."""
514522
if not isinstance(messages, list):
515523
raise ValueError('messages must be a list of messaging.Message instances.')
516524
if len(messages) > 1000:
517525
raise ValueError('messages must not contain more than 500 elements.')
518-
526+
519527
async def send_data(data):
520528
try:
521529
resp = await self._async_client.request(
@@ -661,7 +669,8 @@ def _handle_batch_error(self, error):
661669
"""Handles errors received from the googleapiclient while making batch requests."""
662670
return _gapic_utils.handle_platform_error_from_googleapiclient(
663671
error, _MessagingService._build_fcm_error_googleapiclient)
664-
672+
673+
# TODO: Remove comments
665674
# We should be careful to clean up the httpx clients.
666675
# Since we are using an async client we must also close in async. However we can sync wrap this.
667676
# The close method is called by the app on shutdown/clean-up of each service. We don't seem to
@@ -677,14 +686,16 @@ def _build_fcm_error_requests(cls, error, message, error_dict):
677686
return exc_type(message, cause=error, http_response=error.response) if exc_type else None
678687

679688
@classmethod
680-
def _build_fcm_error_httpx(cls, error: httpx.HTTPError, message, error_dict) -> Optional[exceptions.FirebaseError]:
689+
def _build_fcm_error_httpx(
690+
cls, error: httpx.HTTPError, message, error_dict
691+
) -> Optional[exceptions.FirebaseError]:
681692
"""Parses a httpx error response from the FCM API and creates a FCM-specific exception if
682693
appropriate."""
683694
exc_type = cls._build_fcm_error(error_dict)
684695
if isinstance(error, httpx.HTTPStatusError):
685-
return exc_type(message, cause=error, http_response=error.response) if exc_type else None
686-
else:
687-
return exc_type(message, cause=error) if exc_type else None
696+
return exc_type(
697+
message, cause=error, http_response=error.response) if exc_type else None
698+
return exc_type(message, cause=error) if exc_type else None
688699

689700

690701
@classmethod
@@ -706,42 +717,43 @@ def _build_fcm_error(cls, error_dict) -> Optional[Callable[..., exceptions.Fireb
706717
return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) if fcm_code else None
707718

708719

720+
# TODO: Remove comments
721+
# Notes:
722+
# This implementation currently only covers basic retires for pre-defined status errors
709723
class HttpxRetryTransport(httpx.AsyncBaseTransport):
724+
"""HTTPX transport with retry logic."""
710725
# We could also support passing kwargs here
711-
def __init__(self) -> None:
726+
def __init__(self, **kwargs) -> None:
727+
# Hardcoded settings for now
712728
self._retryable_status_codes = (500, 503,)
713729
self._max_retry_count = 4
714730

715-
# We should use a full AsyncHTTPTransport under the hood since that is
716-
# fully implemented. We could consider making this class extend a
717-
# AsyncHTTPTransport instead and use the parent class's methods to handle
718-
# requests. We sould also ensure that that transport's internal retry is
719-
# not enabled.
720-
self._wrapped_transport = httpx.AsyncHTTPTransport(retries=0, http2=True)
721-
722-
# Checklist:
723-
# - Do we want to disable built in retries
724-
# - Can we dispatch the same request multiple times? Is there any side effects?
725-
726-
# Two types of retries
727-
# - Status code (500s, redirect)
728-
# - Error code (read, connect, other)
729-
# - more ???
730-
731+
# - We use a full AsyncHTTPTransport under the hood to make use of it's
732+
# fully implemented `handle_async_request()`.
733+
# - We could consider making the `HttpxRetryTransport`` class extend a
734+
# `AsyncHTTPTransport` instead and use the parent class's methods to handle
735+
# requests.
736+
# - We should also ensure that that transport's internal retry is
737+
# not enabled.
738+
transport_kwargs = kwargs.copy()
739+
transport_kwargs.update({'retries': 0, 'http2': True})
740+
self._wrapped_transport = httpx.AsyncHTTPTransport(**transport_kwargs)
741+
742+
731743
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
732744
_retry_count = 0
733-
745+
734746
while True:
735747
# Dispatch request
748+
# Let exceptions pass through for now
736749
response = await self._wrapped_transport.handle_async_request(request)
737-
750+
738751
# Check if request is retryable
739752
if response.status_code in self._retryable_status_codes:
740753
_retry_count += 1
741-
742-
# Figure out how we want to handle 0 here
754+
755+
# Return if retries exhausted
743756
if _retry_count > self._max_retry_count:
744757
return response
745758
else:
746759
return response
747-
# break;

integration/conftest.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,8 @@ def api_key(request):
8282
# yield loop
8383
# loop.close()
8484

85-
#
8685
def pytest_collection_modifyitems(items):
8786
pytest_asyncio_tests = (item for item in items if is_async_test(item))
8887
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
8988
for async_test in pytest_asyncio_tests:
90-
async_test.add_marker(session_scope_marker, append=False)
89+
async_test.add_marker(session_scope_marker, append=False)

integration/test_messaging.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
"""Integration tests for firebase_admin.messaging module."""
1616

17-
import asyncio
1817
import re
1918
from datetime import datetime
2019

@@ -224,7 +223,7 @@ def test_unsubscribe():
224223
assert resp.success_count + resp.failure_count == 1
225224

226225
@pytest.mark.asyncio
227-
async def test_async_send_each():
226+
async def test_send_each_async():
228227
messages = [
229228
messaging.Message(
230229
topic='foo-bar', notification=messaging.Notification('Title', 'Body')),
@@ -234,7 +233,7 @@ async def test_async_send_each():
234233
token='not-a-token', notification=messaging.Notification('Title', 'Body')),
235234
]
236235

237-
batch_response = await messaging.async_send_each(messages, dry_run=True)
236+
batch_response = await messaging.send_each_async(messages, dry_run=True)
238237

239238
assert batch_response.success_count == 2
240239
assert batch_response.failure_count == 1
@@ -257,7 +256,7 @@ async def test_async_send_each():
257256

258257

259258
# @pytest.mark.asyncio
260-
# async def test_async_send_each_error():
259+
# async def test_send_each_async_error():
261260
# messages = [
262261
# messaging.Message(
263262
# topic='foo-bar', notification=messaging.Notification('Title', 'Body')),
@@ -267,7 +266,7 @@ async def test_async_send_each():
267266
# token='not-a-token', notification=messaging.Notification('Title', 'Body')),
268267
# ]
269268

270-
# batch_response = await messaging.async_send_each(messages, dry_run=True)
269+
# batch_response = await messaging.send_each_async(messages, dry_run=True)
271270

272271
# assert batch_response.success_count == 2
273272
# assert batch_response.failure_count == 1
@@ -289,13 +288,13 @@ async def test_async_send_each():
289288
# assert response.message_id is None
290289

291290
@pytest.mark.asyncio
292-
async def test_async_send_each_500():
291+
async def test_send_each_async_500():
293292
messages = []
294293
for msg_number in range(500):
295294
topic = 'foo-bar-{0}'.format(msg_number % 10)
296295
messages.append(messaging.Message(topic=topic))
297296

298-
batch_response = await messaging.async_send_each(messages, dry_run=True)
297+
batch_response = await messaging.send_each_async(messages, dry_run=True)
299298

300299
assert batch_response.success_count == 500
301300
assert batch_response.failure_count == 0
@@ -304,4 +303,3 @@ async def test_async_send_each_500():
304303
assert response.success is True
305304
assert response.exception is None
306305
assert re.match('^projects/.*/messages/.*$', response.message_id)
307-

0 commit comments

Comments
 (0)