Skip to content

Commit 7df8090

Browse files
committed
add async interface
1 parent 1c97c6d commit 7df8090

File tree

4 files changed

+167
-2
lines changed

4 files changed

+167
-2
lines changed

pylsp_jsonrpc/endpoint.py

+35
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import uuid
66
import sys
7+
import asyncio
78

89
from concurrent import futures
910
from .exceptions import (JsonRpcException, JsonRpcRequestCancelled,
@@ -12,6 +13,7 @@
1213
log = logging.getLogger(__name__)
1314
JSONRPC_VERSION = '2.0'
1415
CANCEL_METHOD = '$/cancelRequest'
16+
EXIT_METHOD = 'exit'
1517

1618

1719
class Endpoint:
@@ -35,9 +37,24 @@ def __init__(self, dispatcher, consumer, id_generator=lambda: str(uuid.uuid4()),
3537
self._client_request_futures = {}
3638
self._server_request_futures = {}
3739
self._executor_service = futures.ThreadPoolExecutor(max_workers=max_workers)
40+
self._cancelledRequests = set()
41+
self._messageQueue = None
42+
self._consume_task = None
43+
44+
def init_async(self):
45+
self._messageQueue = asyncio.Queue()
46+
self._consume_task = asyncio.create_task(self.consume_task())
47+
48+
async def consume_task(self):
49+
while True:
50+
message = await self._messageQueue.get()
51+
await asyncio.to_thread(self.consume, message)
52+
self._messageQueue.task_done()
3853

3954
def shutdown(self):
4055
self._executor_service.shutdown()
56+
if self._consume_task is not None:
57+
self._consume_task.cancel()
4158

4259
def notify(self, method, params=None):
4360
"""Send a JSON RPC notification to the client.
@@ -94,6 +111,21 @@ def callback(future):
94111
future.set_exception(JsonRpcRequestCancelled())
95112
return callback
96113

114+
async def consume_async(self, message):
115+
"""Consume a JSON RPC message from the client and put it into a queue.
116+
117+
Args:
118+
message (dict): The JSON RPC message sent by the client
119+
"""
120+
if message['method'] == CANCEL_METHOD:
121+
self._cancelledRequests.add(message.get('params')['id'])
122+
123+
# The exit message needs to be handled directly since the stream cannot be closed asynchronously
124+
if message['method'] == EXIT_METHOD:
125+
self.consume(message)
126+
else:
127+
await self._messageQueue.put(message)
128+
97129
def consume(self, message):
98130
"""Consume a JSON RPC message from the client.
99131
@@ -182,6 +214,9 @@ def _handle_request(self, msg_id, method, params):
182214
except KeyError as e:
183215
raise JsonRpcMethodNotFound.of(method) from e
184216

217+
if msg_id in self._cancelledRequests:
218+
raise JsonRpcRequestCancelled()
219+
185220
handler_result = handler(params)
186221

187222
if callable(handler_result):

pylsp_jsonrpc/streams.py

+25
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import logging
55
import threading
6+
import asyncio
67

78
try:
89
import ujson as json
@@ -65,6 +66,30 @@ def _read_message(self):
6566
# Grab the body
6667
return self._rfile.read(content_length)
6768

69+
async def listen_async(self, message_consumer):
70+
"""Blocking call to listen for messages on the rfile.
71+
72+
Args:
73+
message_consumer (fn): function that is passed each message as it is read off the socket.
74+
"""
75+
76+
while not self._rfile.closed:
77+
try:
78+
request_str = await asyncio.to_thread(self._read_message)
79+
except ValueError:
80+
if self._rfile.closed:
81+
return
82+
log.exception("Failed to read from rfile")
83+
84+
if request_str is None:
85+
break
86+
87+
try:
88+
await message_consumer(json.loads(request_str.decode('utf-8')))
89+
except ValueError:
90+
log.exception("Failed to parse JSON message %s", request_str)
91+
continue
92+
6893
@staticmethod
6994
def _content_length(line):
7095
"""Extract the content length from an input line."""

test/test_endpoint.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pylsp_jsonrpc.endpoint import Endpoint
1111

1212
MSG_ID = 'id'
13-
13+
EXIT_METHOD = 'exit'
1414

1515
@pytest.fixture()
1616
def dispatcher():
@@ -319,6 +319,60 @@ def test_consume_request_cancel_unknown(endpoint):
319319
})
320320

321321

322+
@pytest.mark.asyncio
323+
async def test_consume_async_request_cancel(endpoint, dispatcher, consumer):
324+
def async_handler():
325+
time.sleep(1)
326+
handler = mock.Mock(return_value=async_handler)
327+
dispatcher['methodName'] = handler
328+
329+
endpoint.init_async()
330+
331+
await endpoint.consume_async({
332+
'jsonrpc': '2.0',
333+
'method': 'methodName',
334+
'params': {'key': 'value'}
335+
})
336+
await endpoint.consume_async({
337+
'jsonrpc': '2.0',
338+
'id': MSG_ID,
339+
'method': 'methodName',
340+
'params': {'key': 'value'}
341+
})
342+
await endpoint.consume_async({
343+
'jsonrpc': '2.0',
344+
'method': '$/cancelRequest',
345+
'params': {'id': MSG_ID}
346+
})
347+
348+
await endpoint._messageQueue.join()
349+
350+
consumer.assert_called_once_with({
351+
'jsonrpc': '2.0',
352+
'id': MSG_ID,
353+
'error': exceptions.JsonRpcRequestCancelled().to_dict()
354+
})
355+
356+
endpoint.shutdown()
357+
358+
359+
@pytest.mark.asyncio
360+
async def test_consume_async_exit(endpoint, dispatcher, consumer):
361+
# verify that exit is still called synchronously
362+
handler = mock.Mock()
363+
dispatcher[EXIT_METHOD] = handler
364+
365+
endpoint.init_async()
366+
367+
await endpoint.consume_async({
368+
'jsonrpc': '2.0',
369+
'method': EXIT_METHOD
370+
})
371+
372+
handler.assert_called_once_with(None)
373+
374+
endpoint.shutdown()
375+
322376
def assert_consumer_error(consumer_mock, exception):
323377
"""Assert that the consumer mock has had once call with the given error message and code.
324378

test/test_streams.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,53 @@ def test_reader_bad_json(rfile, reader):
7676
consumer.assert_not_called()
7777

7878

79+
@pytest.mark.asyncio
80+
async def test_reader_async(rfile, reader):
81+
rfile.write(
82+
b'Content-Length: 49\r\n'
83+
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
84+
b'\r\n'
85+
b'{"id": "hello", "method": "method", "params": {}}'
86+
)
87+
rfile.seek(0)
88+
89+
consumer = mock.AsyncMock()
90+
await reader.listen_async(consumer)
91+
92+
consumer.assert_called_once_with({
93+
'id': 'hello',
94+
'method': 'method',
95+
'params': {}
96+
})
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_reader_bad_message_async(rfile, reader):
101+
rfile.write(b'Hello world')
102+
rfile.seek(0)
103+
104+
# Ensure the listener doesn't throw
105+
consumer = mock.AsyncMock()
106+
await reader.listen_async(consumer)
107+
consumer.assert_not_called()
108+
109+
110+
@pytest.mark.asyncio
111+
async def test_reader_bad_json_async(rfile, reader):
112+
rfile.write(
113+
b'Content-Length: 8\r\n'
114+
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
115+
b'\r\n'
116+
b'{hello}}'
117+
)
118+
rfile.seek(0)
119+
120+
# Ensure the listener doesn't throw
121+
consumer = mock.AsyncMock()
122+
await reader.listen_async(consumer)
123+
consumer.assert_not_called()
124+
125+
79126
def test_writer(wfile, writer):
80127
writer.write({
81128
'id': 'hello',
@@ -124,5 +171,9 @@ def test_writer_bad_message(wfile, writer):
124171
b'Content-Length: 10\r\n'
125172
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
126173
b'\r\n'
127-
b'1546322461'
174+
b'1546322461',
175+
b'Content-Length: 10\r\n'
176+
b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n'
177+
b'\r\n'
178+
b'1546300861'
128179
]

0 commit comments

Comments
 (0)