11/*
2- * Copyright (c) 2017, 2024 , Oracle and/or its affiliates. All rights reserved.
2+ * Copyright (c) 2017, 2025 , Oracle and/or its affiliates. All rights reserved.
33 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44 *
55 * This code is free software; you can redistribute it and/or modify it
2323
2424package jdk .internal .net .http ;
2525
26+ import java .io .Closeable ;
2627import java .io .IOException ;
2728import java .io .InputStream ;
2829import java .io .OutputStream ;
4344import java .net .http .HttpClient ;
4445import java .net .http .HttpRequest ;
4546import java .net .http .HttpResponse ;
47+ import java .util .concurrent .atomic .AtomicReference ;
48+
4649import jdk .internal .net .http .websocket .RawChannel ;
47- import jdk .internal .net .http .websocket .WebSocketRequest ;
4850import org .testng .annotations .Test ;
4951import static java .net .http .HttpResponse .BodyHandlers .discarding ;
5052import static java .util .concurrent .TimeUnit .SECONDS ;
5759 */
5860public class RawChannelTest {
5961
62+ // can't use jdk.test.lib when injected in java.net.httpclient
63+ // Seed can be specified on the @run line with -Dseed=<seed>
64+ private static class RandomFactory {
65+ private static long getSeed () {
66+ long seed = Long .getLong ("seed" , new Random ().nextLong ());
67+ System .out .println ("Seed from RandomFactory = " +seed +"L" );
68+ return seed ;
69+ }
70+ public static Random getRandom () {
71+ return new Random (getSeed ());
72+ }
73+ }
74+
75+ private static final Random RANDOM = RandomFactory .getRandom ();
6076 private final AtomicLong clientWritten = new AtomicLong ();
6177 private final AtomicLong serverWritten = new AtomicLong ();
6278 private final AtomicLong clientRead = new AtomicLong ();
@@ -90,7 +106,8 @@ public void test() throws Exception {
90106 server .setReuseAddress (false );
91107 server .bind (new InetSocketAddress (InetAddress .getLoopbackAddress (), 0 ));
92108 int port = server .getLocalPort ();
93- new TestServer (server ).start ();
109+ TestServer testServer = new TestServer (server );
110+ testServer .start ();
94111
95112 final RawChannel chan = channelOf (port );
96113 print ("RawChannel is %s" , String .valueOf (chan ));
@@ -129,6 +146,7 @@ public void handle() {
129146 } catch (IOException e ) {
130147 outputCompleted .completeExceptionally (e );
131148 e .printStackTrace ();
149+ closeChannel (chan );
132150 }
133151 return ;
134152 }
@@ -145,6 +163,9 @@ public void handle() {
145163 chan .registerEvent (this );
146164 writeStall .countDown (); // signal send buffer is full
147165 } catch (IOException e ) {
166+ print ("OP_WRITE failed: " + e );
167+ outputCompleted .completeExceptionally (e );
168+ closeChannel (chan );
148169 throw new UncheckedIOException (e );
149170 }
150171 }
@@ -168,6 +189,7 @@ public void handle() {
168189 read = chan .read ();
169190 } catch (IOException e ) {
170191 inputCompleted .completeExceptionally (e );
192+ closeChannel (chan );
171193 e .printStackTrace ();
172194 }
173195 if (read == null ) {
@@ -179,7 +201,10 @@ public void handle() {
179201 try {
180202 chan .registerEvent (this );
181203 } catch (IOException e ) {
182- e .printStackTrace ();
204+ print ("OP_READ failed to register event: " + e );
205+ inputCompleted .completeExceptionally (e );
206+ closeChannel (chan );
207+ throw new UncheckedIOException (e );
183208 }
184209 readStall .countDown ();
185210 break ;
@@ -191,21 +216,33 @@ public void handle() {
191216 print ("OP_READ read %s bytes (%s total)" , total , clientRead .get ());
192217 }
193218 });
219+
194220 CompletableFuture .allOf (outputCompleted ,inputCompleted )
195221 .whenComplete ((r ,t ) -> {
196- try {
197- print ("closing channel" );
198- chan .close ();
199- } catch (IOException x ) {
200- x .printStackTrace ();
201- }
222+ closeChannel (chan );
202223 });
203224 exit .await (); // All done, we need to compare results:
204225 assertEquals (clientRead .get (), serverWritten .get ());
205226 assertEquals (serverRead .get (), clientWritten .get ());
227+ Throwable serverError = testServer .failed .get ();
228+ if (serverError != null ) {
229+ throw new AssertionError ("TestServer failed: "
230+ + serverError , serverError );
231+ }
206232 }
207233 }
208234
235+ private static void closeChannel (RawChannel chan ) {
236+ print ("closing channel" );
237+ try {
238+ chan .close ();
239+ } catch (IOException x ) {
240+ print ("Failed to close channel: " + x );
241+ x .printStackTrace ();
242+ }
243+ }
244+
245+
209246 private static RawChannel channelOf (int port ) throws Exception {
210247 URI uri = URI .create ("http://localhost:" + port + "/" );
211248 print ("raw channel to %s" , uri .toString ());
@@ -237,11 +274,24 @@ private static RawChannel channelOf(int port) throws Exception {
237274 private class TestServer extends Thread { // Powered by Slowpokes
238275
239276 private final ServerSocket server ;
277+ private final AtomicReference <Throwable > failed = new AtomicReference <>();
240278
241279 TestServer (ServerSocket server ) throws IOException {
242280 this .server = server ;
243281 }
244282
283+ private void fail (Closeable s , String actor , Throwable t ) {
284+ failed .compareAndSet (null , t );
285+ print ("Server %s got exception: %s" , actor , t );
286+ t .printStackTrace ();
287+ try {
288+ s .close ();
289+ } catch (Exception x ) {
290+ print ("Server %s failed to close socket: %s" , actor , t );
291+ }
292+
293+ }
294+
245295 @ Override
246296 public void run () {
247297 try (Socket s = server .accept ()) {
@@ -252,21 +302,23 @@ public void run() {
252302
253303 Thread reader = new Thread (() -> {
254304 try {
305+ print ("Server reader started" );
255306 long n = readSlowly (is );
256307 print ("Server read %s bytes" , n );
257308 s .shutdownInput ();
258309 } catch (Exception e ) {
259- e . printStackTrace ( );
310+ fail ( s , "reader" , e );
260311 }
261312 });
262313
263314 Thread writer = new Thread (() -> {
264315 try {
316+ print ("Server writer started" );
265317 long n = writeSlowly (os );
266318 print ("Server written %s bytes" , n );
267319 s .shutdownOutput ();
268320 } catch (Exception e ) {
269- e . printStackTrace ( );
321+ fail ( s , "writer" , e );
270322 }
271323 });
272324
@@ -276,7 +328,7 @@ public void run() {
276328 reader .join ();
277329 writer .join ();
278330 } catch (Exception e ) {
279- e . printStackTrace ( );
331+ fail ( server , "acceptor" , e );
280332 } finally {
281333 exit .countDown ();
282334 }
@@ -365,6 +417,8 @@ private static void print(String format, Object... args) {
365417 }
366418
367419 private static byte [] byteArrayOfSize (int bound ) {
368- return new byte [new Random ().nextInt (1 + bound )];
420+ // bound must be > 1; No need to check it,
421+ // nextInt will throw IllegalArgumentException if needed
422+ return new byte [RANDOM .nextInt (1 , bound + 1 )];
369423 }
370424}
0 commit comments