diff --git a/src/example/java/io/github/sashirestela/cleverclient/example/WebSocketExample.java b/src/example/java/io/github/sashirestela/cleverclient/example/WebSocketExample.java new file mode 100644 index 0000000..344c3f5 --- /dev/null +++ b/src/example/java/io/github/sashirestela/cleverclient/example/WebSocketExample.java @@ -0,0 +1,35 @@ +package io.github.sashirestela.cleverclient.example; + +import io.github.sashirestela.cleverclient.websocket.JavaHttpWebSocketAdapter; +import io.github.sashirestela.cleverclient.websocket.WebSocketAdapter; + +import java.util.Map; + +public class WebSocketExample { + + protected WebSocketAdapter webSocketAdapter; + + public WebSocketExample() { + this.webSocketAdapter = new JavaHttpWebSocketAdapter(); + } + + public void run() { + final String BASE_URL = "wss://s13970.blr1.piesocket.com/v3/1?api_key=" + System.getenv("PIESOCKET_API_KEY") + + "¬ify_self=1"; + + webSocketAdapter.onOpen(() -> System.out.println("Connected")); + webSocketAdapter.onMessage(message -> System.out.println("Received: " + message)); + webSocketAdapter.onClose((code, message) -> System.out.println("Closed")); + + webSocketAdapter.connect(BASE_URL, Map.of()).join(); + webSocketAdapter.send("Hello World!").join(); + webSocketAdapter.send("Welcome to the Jungle!").join(); + webSocketAdapter.close(); + } + + public static void main(String[] args) { + var example = new WebSocketExample(); + example.run(); + } + +} diff --git a/src/example/java/io/github/sashirestela/cleverclient/example/WebSocketExampleOkHttp.java b/src/example/java/io/github/sashirestela/cleverclient/example/WebSocketExampleOkHttp.java new file mode 100644 index 0000000..8592033 --- /dev/null +++ b/src/example/java/io/github/sashirestela/cleverclient/example/WebSocketExampleOkHttp.java @@ -0,0 +1,16 @@ +package io.github.sashirestela.cleverclient.example; + +import io.github.sashirestela.cleverclient.websocket.OkHttpWebSocketAdapter; + +public class WebSocketExampleOkHttp extends WebSocketExample { + + public WebSocketExampleOkHttp() { + this.webSocketAdapter = new OkHttpWebSocketAdapter(); + } + + public static void main(String[] args) { + var example = new WebSocketExampleOkHttp(); + example.run(); + } + +} diff --git a/src/main/java/io/github/sashirestela/cleverclient/CleverClient.java b/src/main/java/io/github/sashirestela/cleverclient/CleverClient.java index ba759e7..1f6fc70 100644 --- a/src/main/java/io/github/sashirestela/cleverclient/CleverClient.java +++ b/src/main/java/io/github/sashirestela/cleverclient/CleverClient.java @@ -8,6 +8,7 @@ import io.github.sashirestela.cleverclient.http.HttpResponseData; import io.github.sashirestela.cleverclient.support.Configurator; import io.github.sashirestela.cleverclient.util.CommonUtil; +import io.github.sashirestela.cleverclient.websocket.WebSocketAdapter; import lombok.Builder; import lombok.Getter; import lombok.NonNull; @@ -37,6 +38,7 @@ public class CleverClient { private final UnaryOperator requestInterceptor; private final UnaryOperator responseInterceptor; private final HttpClientAdapter clientAdapter; + private final WebSocketAdapter webSockewAdapter; private final HttpProcessor httpProcessor; /** @@ -49,6 +51,8 @@ public class CleverClient { * @param responseInterceptor Function to modify the response once it has been received. * @param clientAdapter Component to call http services. If none is passed the * JavaHttpClientAdapter will be used. Optional. + * @param webSocketAdapter Component to do web socket interactions. If none is passed the + * JavaHttpWebSocketAdapter will be used. Optional. * @param endsOfStream Texts used to mark the final of streams when handling server sent * events (SSE). Optional. * @param objectMapper Provides Json conversions either to and from objects. Optional. @@ -57,8 +61,8 @@ public class CleverClient { @SuppressWarnings("java:S107") public CleverClient(@NonNull String baseUrl, @Singular Map headers, Consumer bodyInspector, UnaryOperator requestInterceptor, UnaryOperator responseInterceptor, - HttpClientAdapter clientAdapter, @Singular("endOfStream") List endsOfStream, - ObjectMapper objectMapper) { + HttpClientAdapter clientAdapter, WebSocketAdapter webSocketAdapter, + @Singular("endOfStream") List endsOfStream, ObjectMapper objectMapper) { this.baseUrl = baseUrl; this.headers = Optional.ofNullable(headers).orElse(Map.of()); this.bodyInspector = bodyInspector; @@ -67,6 +71,7 @@ public CleverClient(@NonNull String baseUrl, @Singular Map heade this.clientAdapter = Optional.ofNullable(clientAdapter).orElse(new JavaHttpClientAdapter()); this.clientAdapter.setRequestInterceptor(this.requestInterceptor); this.clientAdapter.setResponseInterceptor(this.responseInterceptor); + this.webSockewAdapter = webSocketAdapter; this.httpProcessor = HttpProcessor.builder() .baseUrl(this.baseUrl) diff --git a/src/main/java/io/github/sashirestela/cleverclient/websocket/Action.java b/src/main/java/io/github/sashirestela/cleverclient/websocket/Action.java new file mode 100644 index 0000000..8a75959 --- /dev/null +++ b/src/main/java/io/github/sashirestela/cleverclient/websocket/Action.java @@ -0,0 +1,8 @@ +package io.github.sashirestela.cleverclient.websocket; + +@FunctionalInterface +public interface Action { + + void execute(); + +} diff --git a/src/main/java/io/github/sashirestela/cleverclient/websocket/JavaHttpWebSocketAdapter.java b/src/main/java/io/github/sashirestela/cleverclient/websocket/JavaHttpWebSocketAdapter.java new file mode 100644 index 0000000..761c0fb --- /dev/null +++ b/src/main/java/io/github/sashirestela/cleverclient/websocket/JavaHttpWebSocketAdapter.java @@ -0,0 +1,172 @@ +package io.github.sashirestela.cleverclient.websocket; + +import io.github.sashirestela.cleverclient.support.CleverClientException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.WebSocket; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +public class JavaHttpWebSocketAdapter implements WebSocketAdapter { + + private static final Logger logger = LoggerFactory.getLogger(JavaHttpWebSocketAdapter.class); + private HttpClient httpClient; + private WebSocket webSocket; + private Consumer messageCallback; + private Action openCallback; + private BiConsumer closeCallback; + private Consumer errorCallback; + private final StringBuilder dataBuffer = new StringBuilder(); + private CompletableFuture sendFuture; + private CompletableFuture closeFuture; + + public JavaHttpWebSocketAdapter(HttpClient httpClient) { + this.httpClient = httpClient; + logger.debug("Created WebSocketAdapter with custom HttpClient"); + } + + public JavaHttpWebSocketAdapter() { + this.httpClient = HttpClient.newHttpClient(); + logger.debug("Created WebSocketAdapter with default HttpClient"); + } + + @Override + @SuppressWarnings("java:S3776") + public CompletableFuture connect(String url, Map headers) { + logger.info("Connecting to WebSocket URL: {}", url); + logger.debug("Connection headers: {}", headers); + + WebSocket.Builder builder = this.httpClient.newWebSocketBuilder(); + headers.forEach(builder::header); + + CompletableFuture connectFuture = new CompletableFuture<>(); + + builder.buildAsync(URI.create(url), new WebSocket.Listener() { + + @Override + public void onOpen(WebSocket webSocket) { + JavaHttpWebSocketAdapter.this.webSocket = webSocket; + logger.info("WebSocket connection established"); + if (openCallback != null) { + openCallback.execute(); + } + connectFuture.complete(null); + webSocket.request(1); + } + + @Override + public CompletionStage onText(WebSocket webSocket, CharSequence data, boolean last) { + logger.trace("Received text data chunk, last={}", last); + dataBuffer.append(data); + if (last) { + if (messageCallback != null) { + String completeMessage = dataBuffer.toString(); + logger.debug("Received message: {}", completeMessage); + messageCallback.accept(completeMessage); + } + dataBuffer.setLength(0); + if (sendFuture != null) { + sendFuture.complete(null); + sendFuture = null; + } + } + webSocket.request(1); + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletionStage onClose(WebSocket webSocket, int statusCode, String reason) { + logger.info("WebSocket closing with code: {}, reason: {}", statusCode, reason); + if (closeCallback != null) { + closeCallback.accept(statusCode, reason); + } + if (closeFuture != null) { + closeFuture.complete(null); + } + return CompletableFuture.completedFuture(null); + } + + @Override + public void onError(WebSocket webSocket, Throwable error) { + logger.error("WebSocket error occurred", error); + if (errorCallback != null) { + errorCallback.accept(error); + } + if (sendFuture != null) { + sendFuture.completeExceptionally(error); + } + if (closeFuture != null) { + closeFuture.completeExceptionally(error); + } + connectFuture.completeExceptionally(error); + } + + }); + + return connectFuture; + } + + @Override + public CompletableFuture send(String message) { + if (webSocket == null) { + logger.error("Attempt to send message before WebSocket connection is established"); + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new CleverClientException("WebSocket is not connected")); + return future; + } + + logger.debug("Sending message: {}", message); + sendFuture = new CompletableFuture<>(); + webSocket.sendText(message, true); + return sendFuture; + } + + @Override + public void close() { + if (webSocket != null) { + logger.info("Initiating WebSocket close"); + closeFuture = new CompletableFuture<>(); + webSocket.sendClose(WebSocket.NORMAL_CLOSURE, "Closing connection"); + try { + closeFuture.join(); + logger.debug("WebSocket close completed normally"); + } catch (Exception e) { + logger.error("Error during WebSocket close", e); + if (errorCallback != null) { + errorCallback.accept(e); + } + } + } + } + + @Override + public void onMessage(Consumer callback) { + logger.trace("Registering message callback"); + this.messageCallback = callback; + } + + @Override + public void onOpen(Action callback) { + logger.trace("Registering open callback"); + this.openCallback = callback; + } + + @Override + public void onClose(BiConsumer callback) { + logger.trace("Registering close callback"); + this.closeCallback = callback; + } + + @Override + public void onError(Consumer callback) { + logger.trace("Registering error callback"); + this.errorCallback = callback; + } + +} diff --git a/src/main/java/io/github/sashirestela/cleverclient/websocket/OkHttpWebSocketAdapter.java b/src/main/java/io/github/sashirestela/cleverclient/websocket/OkHttpWebSocketAdapter.java new file mode 100644 index 0000000..3f98fdb --- /dev/null +++ b/src/main/java/io/github/sashirestela/cleverclient/websocket/OkHttpWebSocketAdapter.java @@ -0,0 +1,138 @@ +package io.github.sashirestela.cleverclient.websocket; + +import io.github.sashirestela.cleverclient.support.CleverClientException; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.WebSocket; +import okhttp3.WebSocketListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +public class OkHttpWebSocketAdapter implements WebSocketAdapter { + + private static final Logger logger = LoggerFactory.getLogger(OkHttpWebSocketAdapter.class); + private OkHttpClient okHttpClient; + private WebSocket webSocket; + private Consumer messageCallback; + private Action openCallback; + private BiConsumer closeCallback; + private Consumer errorCallback; + + public OkHttpWebSocketAdapter(OkHttpClient okHttpClient) { + this.okHttpClient = okHttpClient; + logger.debug("Created WebSocketAdapter with custom OkHttpClient"); + } + + public OkHttpWebSocketAdapter() { + this.okHttpClient = new OkHttpClient(); + logger.debug("Created WebSocketAdapter with default OkHttpClient"); + } + + @Override + public CompletableFuture connect(String url, Map headers) { + logger.info("Connecting to WebSocket URL: {}", url); + logger.debug("Connection headers: {}", headers); + + Request.Builder requestBuilder = new Request.Builder().url(url); + headers.forEach(requestBuilder::addHeader); + + CompletableFuture connectFuture = new CompletableFuture<>(); + this.webSocket = okHttpClient.newWebSocket(requestBuilder.build(), new WebSocketListener() { + + @Override + public void onOpen(WebSocket webSocket, Response response) { + logger.info("WebSocket connection established with response code: {}", response.code()); + if (openCallback != null) { + openCallback.execute(); + } + connectFuture.complete(null); + } + + @Override + public void onMessage(WebSocket webSocket, String text) { + logger.debug("Received message: {}", text); + if (messageCallback != null) { + messageCallback.accept(text); + } + } + + @Override + public void onClosing(WebSocket webSocket, int code, String reason) { + logger.info("WebSocket closing with code: {}, reason: {}", code, reason); + if (closeCallback != null) { + closeCallback.accept(code, reason); + } + } + + @Override + public void onFailure(WebSocket webSocket, Throwable t, Response response) { + String responseCode = response != null ? String.valueOf(response.code()) : "unknown"; + logger.error("WebSocket error occurred. Response code: {}", responseCode, t); + if (errorCallback != null) { + errorCallback.accept(t); + } + connectFuture.completeExceptionally(t); + } + + }); + return connectFuture; + } + + @Override + public CompletableFuture send(String message) { + logger.debug("Sending message: {}", message); + boolean success = webSocket.send(message); + CompletableFuture future = new CompletableFuture<>(); + if (success) { + logger.trace("Message sent successfully"); + future.complete(null); + } else { + String errorMsg = "Failed to send message"; + logger.error(errorMsg); + future.completeExceptionally(new CleverClientException(errorMsg)); + } + return future; + } + + @Override + public void close() { + if (webSocket != null) { + logger.info("Initiating WebSocket close"); + webSocket.close(1000, "Closing connection"); + okHttpClient.dispatcher().executorService().shutdown(); + okHttpClient.connectionPool().evictAll(); + logger.debug("WebSocket resources cleaned up"); + } + } + + @Override + public void onMessage(Consumer callback) { + logger.trace("Registering message callback"); + this.messageCallback = callback; + } + + @Override + public void onOpen(Action callback) { + logger.trace("Registering open callback"); + this.openCallback = callback; + } + + @Override + public void onClose(BiConsumer callback) { + logger.trace("Registering close callback"); + this.closeCallback = callback; + } + + @Override + public void onError(Consumer errorCallback) { + logger.trace("Registering error callback"); + this.errorCallback = errorCallback; + } + +} diff --git a/src/main/java/io/github/sashirestela/cleverclient/websocket/WebSocketAdapter.java b/src/main/java/io/github/sashirestela/cleverclient/websocket/WebSocketAdapter.java new file mode 100644 index 0000000..071c5f3 --- /dev/null +++ b/src/main/java/io/github/sashirestela/cleverclient/websocket/WebSocketAdapter.java @@ -0,0 +1,24 @@ +package io.github.sashirestela.cleverclient.websocket; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +public interface WebSocketAdapter { + + CompletableFuture connect(String url, Map headers); + + CompletableFuture send(String message); + + void close(); + + void onMessage(Consumer callback); + + void onOpen(Action callback); + + void onClose(BiConsumer callback); + + void onError(Consumer callback); + +} diff --git a/src/test/java/io/github/sashirestela/cleverclient/websocket/JavaHttpWebSocketAdapterTest.java b/src/test/java/io/github/sashirestela/cleverclient/websocket/JavaHttpWebSocketAdapterTest.java new file mode 100644 index 0000000..ec1b196 --- /dev/null +++ b/src/test/java/io/github/sashirestela/cleverclient/websocket/JavaHttpWebSocketAdapterTest.java @@ -0,0 +1,169 @@ +package io.github.sashirestela.cleverclient.websocket; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.WebSocket; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +@ExtendWith(MockitoExtension.class) +@SuppressWarnings("unchecked") +class JavaHttpWebSocketAdapterTest { + + @Mock + private HttpClient mockHttpClient; + + @Mock + private WebSocket.Builder mockWebSocketBuilder; + + @Mock + private WebSocket mockWebSocket; + + @Captor + private ArgumentCaptor listenerCaptor; + + private JavaHttpWebSocketAdapter adapter; + + @BeforeEach + void setUp() { + lenient().when(mockHttpClient.newWebSocketBuilder()).thenReturn(mockWebSocketBuilder); + lenient().when(mockWebSocketBuilder.buildAsync(any(URI.class), any(WebSocket.Listener.class))) + .thenReturn(CompletableFuture.completedFuture(mockWebSocket)); + + adapter = new JavaHttpWebSocketAdapter(mockHttpClient); + } + + @Test + void testConnectSuccess() { + Action mockOpenAction = mock(Action.class); + adapter.onOpen(mockOpenAction); + + CompletableFuture connectFuture = adapter.connect("ws://test", Map.of("key", "value")); + + verify(mockWebSocketBuilder).buildAsync(any(URI.class), listenerCaptor.capture()); + WebSocket.Listener listener = listenerCaptor.getValue(); + listener.onOpen(mockWebSocket); + + assertTrue(connectFuture.isDone()); + assertDoesNotThrow(() -> connectFuture.get()); + verify(mockOpenAction).execute(); + } + + @Test + void testOnMessage() { + Consumer mockMessageCallback = mock(Consumer.class); + adapter.onMessage(mockMessageCallback); + + adapter.connect("ws://test", Map.of()); + + verify(mockWebSocketBuilder).buildAsync(any(URI.class), listenerCaptor.capture()); + WebSocket.Listener listener = listenerCaptor.getValue(); + + listener.onText(mockWebSocket, "Hello ", false); + listener.onText(mockWebSocket, "World", true); + + verify(mockMessageCallback).accept("Hello World"); + } + + @Test + void testSendSuccess() { + adapter.connect("ws://test", Map.of()); + + verify(mockWebSocketBuilder).buildAsync(any(URI.class), listenerCaptor.capture()); + WebSocket.Listener listener = listenerCaptor.getValue(); + + listener.onOpen(mockWebSocket); + + CompletableFuture sendFuture = adapter.send("test"); + + verify(mockWebSocket).sendText("test", true); + + listener.onText(mockWebSocket, "response", true); + + assertTrue(sendFuture.isDone()); + assertDoesNotThrow(() -> sendFuture.get()); + } + + @Test + void testSendBeforeConnection() { + CompletableFuture sendFuture = adapter.send("test"); + + assertTrue(sendFuture.isCompletedExceptionally()); + assertThrows(Exception.class, sendFuture::get, "Expected exception not thrown"); + verify(mockWebSocket, never()).sendText(anyString(), anyBoolean()); + } + + @Test + void testCloseCallback() { + BiConsumer mockCloseCallback = mock(BiConsumer.class); + adapter.onClose(mockCloseCallback); + + adapter.connect("ws://test", Map.of()); + + verify(mockWebSocketBuilder).buildAsync(any(URI.class), listenerCaptor.capture()); + WebSocket.Listener listener = listenerCaptor.getValue(); + + listener.onClose(mockWebSocket, 1000, "Normal closure"); + + verify(mockCloseCallback).accept(1000, "Normal closure"); + } + + @Test + void testErrorCallback() { + Consumer mockErrorCallback = mock(Consumer.class); + adapter.onError(mockErrorCallback); + + adapter.connect("ws://test", Map.of()); + + verify(mockWebSocketBuilder).buildAsync(any(URI.class), listenerCaptor.capture()); + WebSocket.Listener listener = listenerCaptor.getValue(); + + Throwable testError = new RuntimeException("Test error"); + listener.onError(mockWebSocket, testError); + + verify(mockErrorCallback).accept(testError); + } + + @Test + void testClose() { + adapter.connect("ws://test", Map.of()); + verify(mockWebSocketBuilder).buildAsync(any(URI.class), listenerCaptor.capture()); + WebSocket.Listener listener = listenerCaptor.getValue(); + + listener.onOpen(mockWebSocket); + + doAnswer(invocation -> { + int statusCode = invocation.getArgument(0); + String reason = invocation.getArgument(1); + listener.onClose(mockWebSocket, statusCode, reason); + return null; + }).when(mockWebSocket).sendClose(anyInt(), anyString()); + + assertDoesNotThrow(() -> adapter.close()); + verify(mockWebSocket).sendClose(WebSocket.NORMAL_CLOSURE, "Closing connection"); + } + +} diff --git a/src/test/java/io/github/sashirestela/cleverclient/websocket/OkHttpWebSocketAdapterTest.java b/src/test/java/io/github/sashirestela/cleverclient/websocket/OkHttpWebSocketAdapterTest.java new file mode 100644 index 0000000..f7bfb93 --- /dev/null +++ b/src/test/java/io/github/sashirestela/cleverclient/websocket/OkHttpWebSocketAdapterTest.java @@ -0,0 +1,169 @@ +package io.github.sashirestela.cleverclient.websocket; + +import okhttp3.ConnectionPool; +import okhttp3.Dispatcher; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.WebSocket; +import okhttp3.WebSocketListener; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +@SuppressWarnings("unchecked") +class OkHttpWebSocketAdapterTest { + + @Mock + private OkHttpClient mockOkHttpClient; + + @Mock + private WebSocket mockWebSocket; + + @Mock + private Response mockResponse; + + @Captor + private ArgumentCaptor listenerCaptor; + + private OkHttpWebSocketAdapter adapter; + + @BeforeEach + void setUp() { + adapter = new OkHttpWebSocketAdapter(mockOkHttpClient); + } + + @Test + void testConnectSuccess() { + Action mockOpenAction = mock(Action.class); + adapter.onOpen(mockOpenAction); + + CompletableFuture connectFuture = adapter.connect("ws://test", Map.of()); + + verify(mockOkHttpClient).newWebSocket(any(Request.class), listenerCaptor.capture()); + WebSocketListener listener = listenerCaptor.getValue(); + when(mockResponse.code()).thenReturn(101); + listener.onOpen(mockWebSocket, mockResponse); + + assertTrue(connectFuture.isDone()); + assertDoesNotThrow(() -> connectFuture.get()); + verify(mockOpenAction).execute(); + } + + @Test + void testOnMessage() { + Consumer mockMessageCallback = mock(Consumer.class); + adapter.onMessage(mockMessageCallback); + adapter.connect("ws://test", Map.of()); + + verify(mockOkHttpClient).newWebSocket(any(Request.class), listenerCaptor.capture()); + WebSocketListener listener = listenerCaptor.getValue(); + listener.onMessage(mockWebSocket, "test message"); + + verify(mockMessageCallback).accept("test message"); + } + + @Test + void testSendSuccess() { + when(mockOkHttpClient.newWebSocket(any(Request.class), any(WebSocketListener.class))) + .thenReturn(mockWebSocket); + adapter.connect("ws://test", Map.of()); + verify(mockOkHttpClient).newWebSocket(any(Request.class), listenerCaptor.capture()); + WebSocketListener listener = listenerCaptor.getValue(); + listener.onOpen(mockWebSocket, mockResponse); + + when(mockWebSocket.send("test")).thenReturn(true); + CompletableFuture sendFuture = adapter.send("test"); + + assertTrue(sendFuture.isDone()); + assertDoesNotThrow(() -> sendFuture.get()); + verify(mockWebSocket).send("test"); + } + + @Test + void testSendFailure() { + when(mockOkHttpClient.newWebSocket(any(Request.class), any(WebSocketListener.class))) + .thenReturn(mockWebSocket); + adapter.connect("ws://test", Map.of()); + verify(mockOkHttpClient).newWebSocket(any(Request.class), listenerCaptor.capture()); + WebSocketListener listener = listenerCaptor.getValue(); + listener.onOpen(mockWebSocket, mockResponse); + + when(mockWebSocket.send("test")).thenReturn(false); + CompletableFuture sendFuture = adapter.send("test"); + + assertTrue(sendFuture.isCompletedExceptionally()); + } + + @Test + void testCloseCallback() { + BiConsumer mockCloseCallback = mock(BiConsumer.class); + adapter.onClose(mockCloseCallback); + adapter.connect("ws://test", Map.of()); + + verify(mockOkHttpClient).newWebSocket(any(Request.class), listenerCaptor.capture()); + WebSocketListener listener = listenerCaptor.getValue(); + listener.onClosing(mockWebSocket, 1000, "Normal closure"); + + verify(mockCloseCallback).accept(1000, "Normal closure"); + } + + @Test + void testErrorCallback() { + Consumer mockErrorCallback = mock(Consumer.class); + adapter.onError(mockErrorCallback); + adapter.connect("ws://test", Map.of()); + + verify(mockOkHttpClient).newWebSocket(any(Request.class), listenerCaptor.capture()); + WebSocketListener listener = listenerCaptor.getValue(); + Throwable testError = new RuntimeException("Test error"); + listener.onFailure(mockWebSocket, testError, mockResponse); + + verify(mockErrorCallback).accept(testError); + } + + @Test + void testClose() { + Dispatcher mockDispatcher = mock(Dispatcher.class); + ExecutorService mockExecutorService = mock(ExecutorService.class); + ConnectionPool mockConnectionPool = mock(ConnectionPool.class); + + when(mockOkHttpClient.newWebSocket(any(Request.class), any(WebSocketListener.class))) + .thenReturn(mockWebSocket); + + when(mockOkHttpClient.dispatcher()).thenReturn(mockDispatcher); + when(mockDispatcher.executorService()).thenReturn(mockExecutorService); + when(mockOkHttpClient.connectionPool()).thenReturn(mockConnectionPool); + + adapter.connect("ws://test", Map.of()); + verify(mockOkHttpClient).newWebSocket(any(Request.class), listenerCaptor.capture()); + WebSocketListener listener = listenerCaptor.getValue(); + + listener.onOpen(mockWebSocket, mockResponse); + + adapter.close(); + + verify(mockWebSocket).close(1000, "Closing connection"); + verify(mockExecutorService).shutdown(); + verify(mockConnectionPool).evictAll(); + } + +}