15
15
"""Firebase Cloud Messaging module."""
16
16
17
17
from __future__ import annotations
18
- from typing import Callable , List , Optional , TypeVar
18
+ from typing import Callable , List , Optional
19
19
import concurrent .futures
20
20
import json
21
21
import warnings
22
+ import asyncio
22
23
import requests
23
24
import httpx
24
- import asyncio
25
25
26
- from google .auth import credentials , transport
26
+ from google .auth import credentials
27
+ from google .auth .transport import requests as auth_requests
27
28
from googleapiclient import http
28
29
from googleapiclient import _auth
29
30
30
31
import firebase_admin
31
- from firebase_admin import _http_client
32
- from firebase_admin import _messaging_encoder
33
- from firebase_admin import _messaging_utils
34
- from firebase_admin import _gapic_utils
35
- from firebase_admin import _utils
36
- from firebase_admin import exceptions
32
+ from firebase_admin import (
33
+ _http_client ,
34
+ _messaging_encoder ,
35
+ _messaging_utils ,
36
+ _gapic_utils ,
37
+ _utils ,
38
+ exceptions ,
39
+ App
40
+ )
37
41
38
42
39
43
_MESSAGING_ATTRIBUTE = '_messaging'
67
71
'WebpushNotification' ,
68
72
'WebpushNotificationAction' ,
69
73
70
- 'async_send_each'
71
74
'send' ,
72
75
'send_all' ,
73
76
'send_multicast' ,
74
77
'send_each' ,
78
+ 'send_each_async' ,
75
79
'send_each_for_multicast' ,
76
80
'subscribe_to_topic' ,
77
81
'unsubscribe_from_topic' ,
78
- ] # type: ignore
82
+ ]
79
83
80
- TFirebaseError = TypeVar ('TFirebaseError' , bound = exceptions .FirebaseError )
81
84
82
85
AndroidConfig = _messaging_utils .AndroidConfig
83
86
AndroidFCMOptions = _messaging_utils .AndroidFCMOptions
104
107
UnregisteredError = _messaging_utils .UnregisteredError
105
108
106
109
107
- def _get_messaging_service (app ) -> _MessagingService :
110
+ def _get_messaging_service (app : Optional [ App ] ) -> _MessagingService :
108
111
return _utils .get_app_service (app , _MESSAGING_ATTRIBUTE , _MessagingService )
109
112
110
- def send (message , dry_run = False , app = None ):
113
+ def send (message , dry_run = False , app : Optional [ App ] = None ):
111
114
"""Sends the given message via Firebase Cloud Messaging (FCM).
112
115
113
116
If the ``dry_run`` mode is enabled, the message will not be actually delivered to the
@@ -147,8 +150,8 @@ def send_each(messages, dry_run=False, app=None):
147
150
"""
148
151
return _get_messaging_service (app ).send_each (messages , dry_run )
149
152
150
- async def async_send_each (messages , dry_run = True , app : firebase_admin . App | None = None ) -> BatchResponse :
151
- return await _get_messaging_service (app ).async_send_each (messages , dry_run )
153
+ async def send_each_async (messages , dry_run = True , app : Optional [ App ] = None ) -> BatchResponse :
154
+ return await _get_messaging_service (app ).send_each_async (messages , dry_run )
152
155
153
156
def send_each_for_multicast (multicast_message , dry_run = False , app = None ):
154
157
"""Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM).
@@ -374,48 +377,53 @@ def exception(self):
374
377
return self ._exception
375
378
376
379
# Auth Flow
380
+ # TODO: Remove comments
377
381
# The aim here is to be able to get auth credentials right before the request is sent.
378
382
# This is similar to what is done in transport.requests.AuthorizedSession().
379
383
# We can then pass this in at the client level.
380
- class CustomGoogleAuth (httpx .Auth ):
381
- def __init__ (self , credentials : credentials .Credentials ):
382
- self ._credential = credentials
384
+
385
+ # Notes:
386
+ # - This implementations does not cover timeouts on requests sent to refresh credentials.
387
+ # - Uses HTTP/1 and a blocking credential for refreshing.
388
+ class GoogleAuthCredentialFlow (httpx .Auth ):
389
+ """Google Auth Credential Auth Flow"""
390
+ def __init__ (self , credential : credentials .Credentials ):
391
+ self ._credential = credential
383
392
self ._max_refresh_attempts = 2
384
393
self ._refresh_status_codes = (401 ,)
385
-
394
+
386
395
def apply_auth_headers (self , request : httpx .Request ):
387
396
# Build request used to refresh credentials if needed
388
- auth_request = transport . requests . Request () # type: ignore
389
- # This refreshes the credentials if needed and mutates the request headers to contain access token
390
- # and any other google auth headers
397
+ auth_request = auth_requests . Request ()
398
+ # This refreshes the credentials if needed and mutates the request headers to
399
+ # contain access token and any other google auth headers
391
400
self ._credential .before_request (auth_request , request .method , request .url , request .headers )
392
401
393
402
394
403
def auth_flow (self , request : httpx .Request ):
395
404
# Keep original headers since `credentials.before_request` mutates the passed headers and we
396
405
# want to keep the original in cause we need an auth retry.
397
406
_original_headers = request .headers .copy ()
398
-
407
+
399
408
_credential_refresh_attempt = 0
400
- while (
401
- _credential_refresh_attempt < self ._max_refresh_attempts
402
- ):
409
+ while _credential_refresh_attempt <= self ._max_refresh_attempts :
403
410
# copy original headers
404
411
request .headers = _original_headers .copy ()
405
412
# mutates request headers
406
413
self .apply_auth_headers (request )
407
-
414
+
408
415
# Continue to perform the request
409
416
# yield here dispatches the request and returns with the response
410
417
response : httpx .Response = yield request
411
-
412
- # We can check the result of the response and determine in we need to retry on refreshable status codes.
413
- # Current transport.requests.AuthorizedSession() only does this on 401 errors. We should do the same.
418
+
419
+ # We can check the result of the response and determine in we need to retry
420
+ # on refreshable status codes. Current transport.requests.AuthorizedSession()
421
+ # only does this on 401 errors. We should do the same.
414
422
if response .status_code in self ._refresh_status_codes :
415
423
_credential_refresh_attempt += 1
416
- print (response .status_code , response .reason_phrase , _credential_refresh_attempt )
417
424
else :
418
- break ;
425
+ break
426
+ # Last yielded response is auto returned.
419
427
420
428
421
429
@@ -453,7 +461,7 @@ def __init__(self, app) -> None:
453
461
self ._client = _http_client .JsonHttpClient (credential = self ._credential , timeout = timeout )
454
462
self ._async_client = httpx .AsyncClient (
455
463
http2 = True ,
456
- auth = CustomGoogleAuth (self ._credential ),
464
+ auth = GoogleAuthCredentialFlow (self ._credential ),
457
465
timeout = timeout ,
458
466
transport = HttpxRetryTransport ()
459
467
)
@@ -509,13 +517,13 @@ def send_data(data):
509
517
message = 'Unknown error while making remote service calls: {0}' .format (error ),
510
518
cause = error )
511
519
512
- async def async_send_each (self , messages : List [Message ], dry_run : bool = True ) -> BatchResponse :
520
+ async def send_each_async (self , messages : List [Message ], dry_run : bool = True ) -> BatchResponse :
513
521
"""Sends the given messages to FCM via the FCM v1 API."""
514
522
if not isinstance (messages , list ):
515
523
raise ValueError ('messages must be a list of messaging.Message instances.' )
516
524
if len (messages ) > 1000 :
517
525
raise ValueError ('messages must not contain more than 500 elements.' )
518
-
526
+
519
527
async def send_data (data ):
520
528
try :
521
529
resp = await self ._async_client .request (
@@ -661,7 +669,8 @@ def _handle_batch_error(self, error):
661
669
"""Handles errors received from the googleapiclient while making batch requests."""
662
670
return _gapic_utils .handle_platform_error_from_googleapiclient (
663
671
error , _MessagingService ._build_fcm_error_googleapiclient )
664
-
672
+
673
+ # TODO: Remove comments
665
674
# We should be careful to clean up the httpx clients.
666
675
# Since we are using an async client we must also close in async. However we can sync wrap this.
667
676
# The close method is called by the app on shutdown/clean-up of each service. We don't seem to
@@ -677,14 +686,16 @@ def _build_fcm_error_requests(cls, error, message, error_dict):
677
686
return exc_type (message , cause = error , http_response = error .response ) if exc_type else None
678
687
679
688
@classmethod
680
- def _build_fcm_error_httpx (cls , error : httpx .HTTPError , message , error_dict ) -> Optional [exceptions .FirebaseError ]:
689
+ def _build_fcm_error_httpx (
690
+ cls , error : httpx .HTTPError , message , error_dict
691
+ ) -> Optional [exceptions .FirebaseError ]:
681
692
"""Parses a httpx error response from the FCM API and creates a FCM-specific exception if
682
693
appropriate."""
683
694
exc_type = cls ._build_fcm_error (error_dict )
684
695
if isinstance (error , httpx .HTTPStatusError ):
685
- return exc_type (message , cause = error , http_response = error . response ) if exc_type else None
686
- else :
687
- return exc_type (message , cause = error ) if exc_type else None
696
+ return exc_type (
697
+ message , cause = error , http_response = error . response ) if exc_type else None
698
+ return exc_type (message , cause = error ) if exc_type else None
688
699
689
700
690
701
@classmethod
@@ -706,42 +717,43 @@ def _build_fcm_error(cls, error_dict) -> Optional[Callable[..., exceptions.Fireb
706
717
return _MessagingService .FCM_ERROR_TYPES .get (fcm_code ) if fcm_code else None
707
718
708
719
720
+ # TODO: Remove comments
721
+ # Notes:
722
+ # This implementation currently only covers basic retires for pre-defined status errors
709
723
class HttpxRetryTransport (httpx .AsyncBaseTransport ):
724
+ """HTTPX transport with retry logic."""
710
725
# We could also support passing kwargs here
711
- def __init__ (self ) -> None :
726
+ def __init__ (self , ** kwargs ) -> None :
727
+ # Hardcoded settings for now
712
728
self ._retryable_status_codes = (500 , 503 ,)
713
729
self ._max_retry_count = 4
714
730
715
- # We should use a full AsyncHTTPTransport under the hood since that is
716
- # fully implemented. We could consider making this class extend a
717
- # AsyncHTTPTransport instead and use the parent class's methods to handle
718
- # requests. We sould also ensure that that transport's internal retry is
719
- # not enabled.
720
- self ._wrapped_transport = httpx .AsyncHTTPTransport (retries = 0 , http2 = True )
721
-
722
- # Checklist:
723
- # - Do we want to disable built in retries
724
- # - Can we dispatch the same request multiple times? Is there any side effects?
725
-
726
- # Two types of retries
727
- # - Status code (500s, redirect)
728
- # - Error code (read, connect, other)
729
- # - more ???
730
-
731
+ # - We use a full AsyncHTTPTransport under the hood to make use of it's
732
+ # fully implemented `handle_async_request()`.
733
+ # - We could consider making the `HttpxRetryTransport`` class extend a
734
+ # `AsyncHTTPTransport` instead and use the parent class's methods to handle
735
+ # requests.
736
+ # - We should also ensure that that transport's internal retry is
737
+ # not enabled.
738
+ transport_kwargs = kwargs .copy ()
739
+ transport_kwargs .update ({'retries' : 0 , 'http2' : True })
740
+ self ._wrapped_transport = httpx .AsyncHTTPTransport (** transport_kwargs )
741
+
742
+
731
743
async def handle_async_request (self , request : httpx .Request ) -> httpx .Response :
732
744
_retry_count = 0
733
-
745
+
734
746
while True :
735
747
# Dispatch request
748
+ # Let exceptions pass through for now
736
749
response = await self ._wrapped_transport .handle_async_request (request )
737
-
750
+
738
751
# Check if request is retryable
739
752
if response .status_code in self ._retryable_status_codes :
740
753
_retry_count += 1
741
-
742
- # Figure out how we want to handle 0 here
754
+
755
+ # Return if retries exhausted
743
756
if _retry_count > self ._max_retry_count :
744
757
return response
745
758
else :
746
759
return response
747
- # break;
0 commit comments