diff --git a/lualib-src/lua-socket.c b/lualib-src/lua-socket.c index 5ac0b359c..9d12ab523 100644 --- a/lualib-src/lua-socket.c +++ b/lualib-src/lua-socket.c @@ -461,6 +461,14 @@ lclose(lua_State *L) { return 0; } +static int +lshutdown(lua_State *L) { + int id = luaL_checkinteger(L,1); + struct skynet_context * ctx = lua_touserdata(L, lua_upvalueindex(1)); + skynet_socket_shutdown(ctx, id); + return 0; +} + static int llisten(lua_State *L) { const char * host = luaL_checkstring(L,1); @@ -639,6 +647,7 @@ luaopen_socketdriver(lua_State *L) { luaL_Reg l2[] = { { "connect", lconnect }, { "close", lclose }, + { "shutdown", lshutdown }, { "listen", llisten }, { "send", lsend }, { "lsend", lsendlow }, diff --git a/lualib/socket.lua b/lualib/socket.lua index 25a6099fb..1fd9d60b3 100644 --- a/lualib/socket.lua +++ b/lualib/socket.lua @@ -118,7 +118,7 @@ socket_message[5] = function(id, _, err) s.connecting = err end s.connected = false - driver.close(id) + driver.shutdown(id) wakeup(s) end @@ -210,18 +210,22 @@ function socket.start(id, func) return connect(id, func) end -function socket.shutdown(id) +local function close_fd(id, func) local s = socket_pool[id] if s then if s.buffer then driver.clear(s.buffer,buffer_pool) end if s.connected then - driver.close(id) + func(id) end end end +function socket.shutdown(id) + close_fd(id, driver.shutdown) +end + function socket.close(id) local s = socket_pool[id] if s == nil then @@ -232,7 +236,7 @@ function socket.close(id) -- notice: call socket.close in __gc should be carefully, -- because skynet.wait never return in __gc, so driver.clear may not be called if s.co then - -- reading this socket on another coroutine, so don't shutdown (clear the buffer) immediatel + -- reading this socket on another coroutine, so don't shutdown (clear the buffer) immediately -- wait reading coroutine read the buffer. assert(not s.closing) s.closing = coroutine.running() @@ -242,7 +246,7 @@ function socket.close(id) end s.connected = false end - socket.shutdown(id) + close_fd(id) -- clear the buffer (already close fd) assert(s.lock_set == nil or next(s.lock_set) == nil) socket_pool[id] = nil end diff --git a/skynet-src/skynet_socket.c b/skynet-src/skynet_socket.c index f6f8ec569..2f3aec849 100644 --- a/skynet-src/skynet_socket.c +++ b/skynet-src/skynet_socket.c @@ -159,6 +159,12 @@ skynet_socket_close(struct skynet_context *ctx, int id) { socket_server_close(SOCKET_SERVER, source, id); } +void +skynet_socket_shutdown(struct skynet_context *ctx, int id) { + uint32_t source = skynet_context_handle(ctx); + socket_server_shutdown(SOCKET_SERVER, source, id); +} + void skynet_socket_start(struct skynet_context *ctx, int id) { uint32_t source = skynet_context_handle(ctx); diff --git a/skynet-src/skynet_socket.h b/skynet-src/skynet_socket.h index bcdc137cd..55b98595d 100644 --- a/skynet-src/skynet_socket.h +++ b/skynet-src/skynet_socket.h @@ -29,6 +29,7 @@ int skynet_socket_listen(struct skynet_context *ctx, const char *host, int port, int skynet_socket_connect(struct skynet_context *ctx, const char *host, int port); int skynet_socket_bind(struct skynet_context *ctx, int fd); void skynet_socket_close(struct skynet_context *ctx, int id); +void skynet_socket_shutdown(struct skynet_context *ctx, int id); void skynet_socket_start(struct skynet_context *ctx, int id); void skynet_socket_nodelay(struct skynet_context *ctx, int id); diff --git a/skynet-src/socket_server.c b/skynet-src/socket_server.c index 6ee4d03ae..2e26e24f1 100644 --- a/skynet-src/socket_server.c +++ b/skynet-src/socket_server.c @@ -119,6 +119,7 @@ struct request_setudp { struct request_close { int id; + int shutdown; uintptr_t opaque; }; @@ -787,7 +788,7 @@ close_socket(struct socket_server *ss, struct request_close *request, struct soc if (type != -1) return type; } - if (send_buffer_empty(s)) { + if (request->shutdown || send_buffer_empty(s)) { force_close(ss,s,result); result->id = id; result->opaque = request->opaque; @@ -1366,6 +1367,17 @@ void socket_server_close(struct socket_server *ss, uintptr_t opaque, int id) { struct request_package request; request.u.close.id = id; + request.u.close.shutdown = 0; + request.u.close.opaque = opaque; + send_request(ss, &request, 'K', sizeof(request.u.close)); +} + + +void +socket_server_shutdown(struct socket_server *ss, uintptr_t opaque, int id) { + struct request_package request; + request.u.close.id = id; + request.u.close.shutdown = 1; request.u.close.opaque = opaque; send_request(ss, &request, 'K', sizeof(request.u.close)); } diff --git a/skynet-src/socket_server.h b/skynet-src/socket_server.h index b6f0f5fbd..41ef9ccae 100644 --- a/skynet-src/socket_server.h +++ b/skynet-src/socket_server.h @@ -26,6 +26,7 @@ int socket_server_poll(struct socket_server *, struct socket_message *result, in void socket_server_exit(struct socket_server *); void socket_server_close(struct socket_server *, uintptr_t opaque, int id); +void socket_server_shutdown(struct socket_server *, uintptr_t opaque, int id); void socket_server_start(struct socket_server *, uintptr_t opaque, int id); // return -1 when error