@@ -124,6 +124,9 @@ abstract static class AcceptNode extends PythonUnaryBuiltinNode {
124
124
@ Specialization
125
125
@ TruffleBoundary
126
126
Object accept (PSocket socket ) {
127
+ if (socket .getServerSocket () == null ) {
128
+ throw raiseOSError (null , OSErrorEnum .EINVAL );
129
+ }
127
130
try {
128
131
SocketChannel acceptSocket = SocketUtils .accept (this , socket );
129
132
if (acceptSocket == null ) {
@@ -214,6 +217,7 @@ private static void doConnect(PSocket socket, Object[] hostAndPort) throws IOExc
214
217
InetSocketAddress socketAddress = new InetSocketAddress ((String ) hostAndPort [0 ], (Integer ) hostAndPort [1 ]);
215
218
SocketChannel channel = SocketChannel .open ();
216
219
channel .connect (socketAddress );
220
+ channel .configureBlocking (socket .isBlocking ());
217
221
socket .setSocket (channel );
218
222
}
219
223
}
@@ -338,6 +342,9 @@ Object listen(PSocket socket, PNone backlog) {
338
342
abstract static class RecvNode extends PythonTernaryClinicBuiltinNode {
339
343
@ Specialization
340
344
Object recv (VirtualFrame frame , PSocket socket , int bufsize , int flags ) {
345
+ if (socket .getSocket () == null ) {
346
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
347
+ }
341
348
ByteBuffer readBytes = PythonUtils .allocateByteBuffer (bufsize );
342
349
try {
343
350
int length = SocketUtils .recv (this , socket , readBytes );
@@ -384,6 +391,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PMemoryView buffer, Object f
384
391
@ CachedLibrary (limit = "getCallSiteInlineCacheMaxDepth()" ) PythonObjectLibrary lib ,
385
392
@ Cached ("create(__LEN__)" ) LookupAndCallUnaryNode callLen ,
386
393
@ Cached ("create(__SETITEM__)" ) LookupAndCallTernaryNode setItem ) {
394
+ if (socket .getSocket () == null ) {
395
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
396
+ }
387
397
int bufferLen = lib .asSizeWithState (callLen .executeObject (frame , buffer ), PArguments .getThreadState (frame ));
388
398
byte [] targetBuffer = new byte [bufferLen ];
389
399
ByteBuffer byteBuffer = PythonUtils .wrapByteBuffer (targetBuffer );
@@ -410,6 +420,9 @@ Object recvInto(VirtualFrame frame, PSocket socket, PByteArray buffer, Object fl
410
420
@ Cached ("createBinaryProfile()" ) ConditionProfile byteStorage ,
411
421
@ Cached SequenceStorageNodes .LenNode lenNode ,
412
422
@ Cached ("createSetItem()" ) SequenceStorageNodes .SetItemNode setItem ) {
423
+ if (socket .getSocket () == null ) {
424
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
425
+ }
413
426
SequenceStorage storage = buffer .getSequenceStorage ();
414
427
int bufferLen = lenNode .execute (storage );
415
428
if (byteStorage .profile (storage instanceof ByteSequenceStorage )) {
@@ -470,17 +483,14 @@ Object send(VirtualFrame frame, PSocket socket, PBytes bytes, Object flags,
470
483
@ Cached SequenceStorageNodes .ToByteArrayNode toBytes ) {
471
484
// TODO: do not ignore flags
472
485
if (socket .getSocket () == null ) {
473
- throw raise (OSError );
474
- }
475
-
476
- if (!socket .isOpen ()) {
477
- throw raise (OSError );
486
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
478
487
}
479
-
480
488
int written ;
481
489
ByteBuffer buffer = PythonUtils .wrapByteBuffer (toBytes .execute (bytes .getSequenceStorage ()));
482
490
try {
483
491
written = SocketUtils .send (this , socket , buffer );
492
+ } catch (NotYetConnectedException e ) {
493
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
484
494
} catch (IOException e ) {
485
495
throw raise (OSError );
486
496
}
@@ -500,6 +510,9 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
500
510
@ Cached SequenceStorageNodes .ToByteArrayNode toBytes ,
501
511
@ Cached ConditionProfile hasTimeoutProfile ) {
502
512
// TODO: do not ignore flags
513
+ if (socket .getSocket () == null ) {
514
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
515
+ }
503
516
ByteBuffer buffer = PythonUtils .wrapByteBuffer (toBytes .execute (bytes .getSequenceStorage ()));
504
517
long timeoutMillis = socket .getTimeoutInMilliseconds ();
505
518
TimeoutHelper timeoutHelper = null ;
@@ -513,6 +526,8 @@ Object sendAll(VirtualFrame frame, PSocket socket, PBytesLike bytes, Object flag
513
526
int written ;
514
527
try {
515
528
written = SocketUtils .send (this , socket , buffer , timeoutMillis );
529
+ } catch (NotYetConnectedException e ) {
530
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
516
531
} catch (IOException e ) {
517
532
throw raise (OSError );
518
533
}
@@ -606,24 +621,29 @@ Object setTimeout(PSocket socket, Object secondsObj,
606
621
@ GenerateNodeFactory
607
622
abstract static class shutdownNode extends PythonBinaryBuiltinNode {
608
623
@ Specialization
609
- @ TruffleBoundary
610
- Object family (PSocket socket , int how ) {
611
- if (socket .getSocket () != null ) {
612
- try {
613
- if (how == 0 || how == 2 ) {
614
- socket .getSocket ().shutdownInput ();
615
- }
616
- if (how == 1 || how == 2 ) {
617
- socket .getSocket ().shutdownOutput ();
618
- }
619
- } catch (IOException e ) {
620
- throw raise (OSError );
621
- }
622
- } else {
624
+ Object family (VirtualFrame frame , PSocket socket , int how ) {
625
+ if (socket .getSocket () == null ) {
626
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
627
+ }
628
+ try {
629
+ shutdown (socket , how );
630
+ } catch (NotYetConnectedException e ) {
631
+ throw raiseOSError (frame , OSErrorEnum .ENOTCONN );
632
+ } catch (IOException e ) {
623
633
throw raise (OSError );
624
634
}
625
635
return PNone .NO_VALUE ;
626
636
}
637
+
638
+ @ TruffleBoundary
639
+ private static void shutdown (PSocket socket , int how ) throws IOException {
640
+ if (how == 0 || how == 2 ) {
641
+ socket .getSocket ().shutdownInput ();
642
+ }
643
+ if (how == 1 || how == 2 ) {
644
+ socket .getSocket ().shutdownOutput ();
645
+ }
646
+ }
627
647
}
628
648
629
649
// family
0 commit comments