4
4
import logging
5
5
import uuid
6
6
import sys
7
+ import asyncio
7
8
8
9
from concurrent import futures
9
10
from .exceptions import (JsonRpcException , JsonRpcRequestCancelled ,
12
13
log = logging .getLogger (__name__ )
13
14
JSONRPC_VERSION = '2.0'
14
15
CANCEL_METHOD = '$/cancelRequest'
16
+ EXIT_METHOD = 'exit'
15
17
16
18
17
19
class Endpoint :
@@ -35,9 +37,25 @@ def __init__(self, dispatcher, consumer, id_generator=lambda: str(uuid.uuid4()),
35
37
self ._client_request_futures = {}
36
38
self ._server_request_futures = {}
37
39
self ._executor_service = futures .ThreadPoolExecutor (max_workers = max_workers )
40
+ self ._cancelledRequests = set ()
41
+ self ._messageQueue = None
42
+ self ._consume_task = None
43
+
44
+ def init_async (self ):
45
+ self ._messageQueue = asyncio .Queue ()
46
+ self ._consume_task = asyncio .create_task (self .consume_task ())
47
+
48
+ async def consume_task (self ):
49
+ loop = asyncio .get_running_loop ()
50
+ while loop .is_running ():
51
+ message = await self ._messageQueue .get ()
52
+ await asyncio .to_thread (self .consume , message )
53
+ self ._messageQueue .task_done ()
38
54
39
55
def shutdown (self ):
40
56
self ._executor_service .shutdown ()
57
+ if self ._consume_task is not None :
58
+ self ._consume_task .cancel ()
41
59
42
60
def notify (self , method , params = None ):
43
61
"""Send a JSON RPC notification to the client.
@@ -94,6 +112,21 @@ def callback(future):
94
112
future .set_exception (JsonRpcRequestCancelled ())
95
113
return callback
96
114
115
+ async def consume_async (self , message ):
116
+ """Consume a JSON RPC message from the client and put it into a queue.
117
+
118
+ Args:
119
+ message (dict): The JSON RPC message sent by the client
120
+ """
121
+ if message ['method' ] == CANCEL_METHOD :
122
+ self ._cancelledRequests .add (message .get ('params' )['id' ])
123
+
124
+ # The exit message needs to be handled directly since the stream cannot be closed asynchronously
125
+ if message ['method' ] == EXIT_METHOD :
126
+ self .consume (message )
127
+ else :
128
+ await self ._messageQueue .put (message )
129
+
97
130
def consume (self , message ):
98
131
"""Consume a JSON RPC message from the client.
99
132
@@ -182,6 +215,9 @@ def _handle_request(self, msg_id, method, params):
182
215
except KeyError as e :
183
216
raise JsonRpcMethodNotFound .of (method ) from e
184
217
218
+ if msg_id in self ._cancelledRequests :
219
+ raise JsonRpcRequestCancelled ()
220
+
185
221
handler_result = handler (params )
186
222
187
223
if callable (handler_result ):
0 commit comments