Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions slack_sdk/socket_mode/builtin/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _parse_handshake_response(sock: ssl.SSLSocket) -> Tuple[Optional[int], dict,
if len(elements) > 2:
status = int(elements[1])
else:
elements = line.split(":")
elements = line.split(":", 1)
if len(elements) == 2:
headers[elements[0].strip().lower()] = elements[1].strip()
if line is None or len(line.strip()) == 0:
Expand Down Expand Up @@ -337,7 +337,7 @@ def _fetch_messages(
)
else:
# This pattern is unexpected but set data with the expected length anyway
_append_message(current_header, current_data[:current_data_length]) # type: ignore[call-arg, arg-type]
_append_message(messages, current_header, current_data[:current_data_length])
return messages

# work in progress with the current_header/current_data
Expand Down
8 changes: 5 additions & 3 deletions slack_sdk/socket_mode/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,17 @@ def disconnect(self) -> None:
raise NotImplementedError()

def connect_to_new_endpoint(self, force: bool = False):
acquired = False
try:
self.connect_operation_lock.acquire(blocking=True, timeout=5)
if force or not self.is_connected():
acquired = self.connect_operation_lock.acquire(blocking=True, timeout=5)
if force or (acquired and not self.is_connected()):
self.logger.info("Connecting to a new endpoint...")
self.wss_uri = self.issue_new_wss_url()
self.connect()
self.logger.info("Connected to a new endpoint...")
finally:
self.connect_operation_lock.release()
if acquired:
self.connect_operation_lock.release()

def close(self) -> None:
self.closed = True
Expand Down
2 changes: 1 addition & 1 deletion slack_sdk/socket_mode/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def from_dict(cls, message: dict) -> Optional["SocketModeRequest"]:
return None

def to_dict(self) -> dict:
d = {"envelope_id": self.envelope_id}
d = {"type": self.type, "envelope_id": self.envelope_id}
if self.payload is not None:
d["payload"] = self.payload # type: ignore[assignment]
return d
3 changes: 2 additions & 1 deletion slack_sdk/socket_mode/websocket_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""websocket-client bassd Socket Mode client
"""websocket-client based Socket Mode client

* https://docs.slack.dev/apis/events-api/using-socket-mode/
* https://docs.slack.dev/tools/python-slack-sdk/socket-mode/
Expand Down Expand Up @@ -229,6 +229,7 @@ def close(self) -> None:
self.closed = True
self.auto_reconnect_enabled = False
self.disconnect()
self.current_session_runner.shutdown()
self.current_app_monitor.shutdown()
self.message_processor.shutdown()
self.message_workers.shutdown()
Expand Down
2 changes: 1 addition & 1 deletion slack_sdk/socket_mode/websockets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""websockets bassd Socket Mode client
"""websockets based Socket Mode client

* https://docs.slack.dev/apis/events-api/using-socket-mode/
* https://docs.slack.dev/tools/python-slack-sdk/socket-mode/
Expand Down
25 changes: 25 additions & 0 deletions tests/slack_sdk/socket_mode/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,28 @@ def test(self):
req = SocketModeRequest.from_dict(body)
self.assertIsNotNone(req)
self.assertEqual(req.envelope_id, "1d3c79ab-0ffb-41f3-a080-d19e85f53649")

def test_to_dict(self):
req = SocketModeRequest(
type="slash_commands",
envelope_id="abc-123",
payload={"text": "hello"},
)
self.assertDictEqual(
req.to_dict(), {"type": "slash_commands", "envelope_id": "abc-123", "payload": {"text": "hello"}}
)

def test_to_dict_from_dict_round_trip(self):
expected = SocketModeRequest(
type="slash_commands",
envelope_id="1d3c79ab-0ffb-41f3-a080-d19e85f53649",
payload={"token": "xxx", "team_id": "T111", "command": "/hello"},
accepts_response_payload=True,
retry_attempt=2,
retry_reason="timeout",
)
actual = SocketModeRequest.from_dict(expected.to_dict())
self.assertIsNotNone(actual)
self.assertEqual(actual.type, expected.type)
self.assertEqual(actual.envelope_id, expected.envelope_id)
self.assertEqual(actual.payload, expected.payload)
67 changes: 67 additions & 0 deletions tests/slack_sdk/socket_mode/test_socket_mode_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
import ssl
import unittest
from threading import Lock
from unittest.mock import MagicMock, patch

from slack_sdk.socket_mode.builtin.internals import _parse_handshake_response
from slack_sdk.socket_mode.client import BaseSocketModeClient


class TestSocketModeClient(unittest.TestCase):
logger = logging.getLogger(__name__)

def test_connect_to_new_endpoint_does_not_release_lock_on_acquisition_timeout(self):
client = BaseSocketModeClient.__new__(BaseSocketModeClient)
client.logger = self.logger
lock_mock = MagicMock(spec=Lock())
lock_mock.acquire.return_value = False
client.connect_operation_lock = lock_mock

client.connect_to_new_endpoint()

client.connect_operation_lock.release.assert_not_called()

def test_connect_to_new_endpoint_releases_lock_on_successful_acquisition(self):
client = BaseSocketModeClient.__new__(BaseSocketModeClient)
client.logger = self.logger
client.connect_operation_lock = Lock()

with patch.object(client, client.is_connected.__name__, return_value=True):
client.connect_to_new_endpoint()

acquired = client.connect_operation_lock.acquire(blocking=False)
self.assertTrue(acquired)
client.connect_operation_lock.release()

def test_parse_handshake_response_preserves_colons_in_header_values(self):
lines = [
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Location: https://example.com:8080/path",
"",
]
with patch(
"slack_sdk.socket_mode.builtin.internals._read_http_response_line",
side_effect=lines,
):
status, headers, _ = _parse_handshake_response(MagicMock(spec=ssl.SSLSocket))

self.assertEqual(status, 101)
self.assertEqual(headers["upgrade"], "websocket")
self.assertEqual(headers["location"], "https://example.com:8080/path")

def test_parse_handshake_response_parses_standard_headers(self):
lines = [
"HTTP/1.1 200 OK",
"Content-Type: text/html",
"",
]
with patch(
"slack_sdk.socket_mode.builtin.internals._read_http_response_line",
side_effect=lines,
):
status, headers, _ = _parse_handshake_response(MagicMock(spec=ssl.SSLSocket))

self.assertEqual(status, 200)
self.assertEqual(headers["content-type"], "text/html")