Skip to content

Commit f1c7017

Browse files
committed
Implement streaming callables
1 parent f516594 commit f1c7017

File tree

3 files changed

+192
-10
lines changed

3 files changed

+192
-10
lines changed

samples/https_flask/functions/main.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,13 @@ def world():
2323
@https_fn.on_request()
2424
def httpsflaskexample(request):
2525
return entrypoint(app, request)
26+
27+
@https_fn.on_call()
28+
def callableexample(request: https_fn.Request):
29+
return request.data
30+
31+
@https_fn.on_call()
32+
def streamingcallable(request: https_fn.Request):
33+
yield "Hello,"
34+
yield "world!"
35+
return request.data

src/firebase_functions/https_fn.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
import typing_extensions as _typing_extensions
2222
import enum as _enum
2323
import json as _json
24+
import inspect as _inspect
2425
import firebase_functions.private.util as _util
2526
import firebase_functions.core as _core
27+
import contextlib as _contextlib
2628
from functions_framework import logging as _logging
2729

2830
from firebase_functions.options import HttpsOptions, _GLOBAL_OPTIONS
@@ -351,6 +353,12 @@ class CallableRequest(_typing.Generic[_core.T]):
351353
_C1 = _typing.Callable[[Request], Response]
352354
_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]
353355

356+
class _IterWithReturn:
357+
def __init__(self, iterable):
358+
self.iterable = iterable
359+
360+
def __iter__(self):
361+
self.value = yield from self.iterable
354362

355363
def _on_call_handler(func: _C2, request: Request,
356364
enforce_app_check: bool) -> Response:
@@ -401,7 +409,19 @@ def _on_call_handler(func: _C2, request: Request,
401409
"Firebase-Instance-ID-Token"),
402410
)
403411
result = _core._with_init(func)(context)
404-
return _jsonify(result=result)
412+
if not _inspect.isgenerator(result):
413+
return _jsonify(result=result)
414+
415+
if request.headers.get("Accept") != "text/event-stream":
416+
vals = _IterWithReturn(result)
417+
for _ in vals:
418+
next
419+
return _jsonify(result=vals.value)
420+
421+
else:
422+
return Response(_sse_encode_generator(result), content_type="text/plain")
423+
424+
405425
# Disable broad exceptions lint since we want to handle all exceptions here
406426
# and wrap as an HttpsError.
407427
# pylint: disable=broad-except
@@ -413,6 +433,19 @@ def _on_call_handler(func: _C2, request: Request,
413433
return _make_response(_jsonify(error=err._as_dict()), status)
414434

415435

436+
def _sse_encode_generator(gen: _typing.Generator):
437+
iter = _IterWithReturn(gen)
438+
try:
439+
for chunk in iter:
440+
yield f"data: %s\n\n" % _json.dumps(obj={"message": chunk})
441+
yield f"data: %s\n\n" % _json.dumps(obj={"result": iter.value})
442+
except Exception as err:
443+
if not isinstance(err, HttpsError):
444+
_logging.error("Unhandled error: %s", err)
445+
err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
446+
yield f"error: %s\n\n" % _json.dumps(obj={"error": err._as_dict()})
447+
yield "END"
448+
416449
@_util.copy_func_kwargs(HttpsOptions)
417450
def on_request(**kwargs) -> _typing.Callable[[_C1], _C1]:
418451
"""

tests/test_https_fn.py

Lines changed: 148 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
"""
44

55
import unittest
6-
from unittest.mock import Mock
7-
from flask import Flask, Request
6+
from flask import Flask, Request, jsonify as _jsonify
87
from werkzeug.test import EnvironBuilder
98

109
from firebase_functions import core, https_fn
@@ -25,7 +24,9 @@ def init():
2524
nonlocal hello
2625
hello = "world"
2726

28-
func = Mock(__name__="example_func")
27+
@https_fn.on_request()
28+
def func(_):
29+
pass
2930

3031
with app.test_request_context("/"):
3132
environ = EnvironBuilder(
@@ -37,9 +38,8 @@ def init():
3738
},
3839
).get_environ()
3940
request = Request(environ)
40-
decorated_func = https_fn.on_request()(func)
4141

42-
decorated_func(request)
42+
func(request)
4343

4444
self.assertEqual(hello, "world")
4545

@@ -53,7 +53,9 @@ def init():
5353
nonlocal hello
5454
hello = "world"
5555

56-
func = Mock(__name__="example_func")
56+
@https_fn.on_call()
57+
def func(_):
58+
pass
5759

5860
with app.test_request_context("/"):
5961
environ = EnvironBuilder(
@@ -65,8 +67,145 @@ def init():
6567
},
6668
).get_environ()
6769
request = Request(environ)
68-
decorated_func = https_fn.on_call()(func)
69-
70-
decorated_func(request)
70+
func(request)
7171

7272
self.assertEqual("world", hello)
73+
74+
def test_callable_encoding(self):
75+
app = Flask(__name__)
76+
77+
@https_fn.on_call()
78+
def add(req: https_fn.CallableRequest[int]):
79+
return req.data + 1
80+
81+
with app.test_request_context("/"):
82+
environ = EnvironBuilder(
83+
method="POST",
84+
json={
85+
"data": 1
86+
}
87+
).get_environ()
88+
request = Request(environ)
89+
90+
response = add(request)
91+
self.assertEqual(response.status_code, 200)
92+
self.assertEqual(response.get_json(), { "result": 2 })
93+
94+
def test_callable_errors(self):
95+
app = Flask(__name__)
96+
97+
@https_fn.on_call()
98+
def throw_generic_error(req):
99+
raise Exception("Invalid type")
100+
101+
@https_fn.on_call()
102+
def throw_access_denied(req):
103+
raise https_fn.HttpsError(https_fn.FunctionsErrorCode.PERMISSION_DENIED, "Permission is denied")
104+
105+
with app.test_request_context("/"):
106+
environ = EnvironBuilder(
107+
method="POST",
108+
json={
109+
"data": None
110+
}
111+
).get_environ()
112+
request = Request(environ)
113+
114+
response = throw_generic_error(request)
115+
self.assertEqual(response.status_code, 500)
116+
self.assertEqual(response.get_json(), { "error": { "message": "INTERNAL", "status": "INTERNAL" } })
117+
118+
response = throw_access_denied(request)
119+
self.assertEqual(response.status_code, 403)
120+
self.assertEqual(response.get_json(), { "error": { "message": "Permission is denied", "status": "PERMISSION_DENIED" }})
121+
122+
def test_yielding_without_streaming(self):
123+
app = Flask(__name__)
124+
125+
@https_fn.on_call()
126+
def yielder(req: https_fn.CallableRequest[int]):
127+
yield from range(req.data)
128+
return "OK"
129+
130+
@https_fn.on_call()
131+
def yield_thrower(req: https_fn.CallableRequest[int]):
132+
yield from range(req.data)
133+
raise https_fn.HttpsError(https_fn.FunctionsErrorCode.PERMISSION_DENIED, "Can't read anymore")
134+
135+
with app.test_request_context("/"):
136+
environ = EnvironBuilder(
137+
method="POST",
138+
json={
139+
"data": 5
140+
}
141+
).get_environ()
142+
143+
request = Request(environ)
144+
response = yielder(request)
145+
146+
self.assertEqual(response.status_code, 200)
147+
self.assertEqual(response.get_json(), { "result": "OK" })
148+
149+
with app.test_request_context("/"):
150+
environ = EnvironBuilder(
151+
method="POST",
152+
json={
153+
"data": 3
154+
}
155+
).get_environ()
156+
157+
request = Request(environ)
158+
response = yield_thrower(request)
159+
160+
self.assertEqual(response.status_code, 403)
161+
self.assertEqual(response.get_json(), { "error": { "message": "Can't read anymore", "status": "PERMISSION_DENIED" }})
162+
163+
164+
def test_yielding_with_streaming(self):
165+
app = Flask(__name__)
166+
167+
@https_fn.on_call()
168+
def yielder(req: https_fn.CallableRequest[int]):
169+
yield from range(req.data)
170+
return "OK"
171+
172+
@https_fn.on_call()
173+
def yield_thrower(req: https_fn.CallableRequest[int]):
174+
yield from range(req.data)
175+
raise https_fn.HttpsError(https_fn.FunctionsErrorCode.INTERNAL, "Throwing")
176+
177+
with app.test_request_context("/"):
178+
environ = EnvironBuilder(
179+
method="POST",
180+
json={
181+
"data": 2
182+
},
183+
headers={
184+
"accept": "text/event-stream"
185+
}
186+
).get_environ()
187+
188+
request = Request(environ)
189+
response = yielder(request)
190+
191+
self.assertEqual(response.status_code, 200)
192+
chunks = list(response.response)
193+
self.assertEqual(chunks, ['data: {"message": 0}\n\n', 'data: {"message": 1}\n\n', 'data: {"result": "OK"}\n\n', "END"])
194+
195+
with app.test_request_context("/"):
196+
environ = EnvironBuilder(
197+
method="POST",
198+
json={
199+
"data": 2
200+
},
201+
headers={
202+
"accept": "text/event-stream"
203+
}
204+
).get_environ()
205+
206+
request = Request(environ)
207+
response = yield_thrower(request)
208+
209+
self.assertEqual(response.status_code, 200)
210+
chunks = list(response.response)
211+
self.assertEqual(chunks, ['data: {"message": 0}\n\n', 'data: {"message": 1}\n\n', 'error: {"error": {"status": "INTERNAL", "message": "Throwing"}}\n\n', "END"])

0 commit comments

Comments
 (0)