1
1
package com .avsystem .commons
2
2
package jetty .rpc
3
3
4
- import java .nio .charset .StandardCharsets
5
4
import com .avsystem .commons .rpc .StandardRPCFramework
6
5
import com .avsystem .commons .serialization .json .{JsonStringInput , JsonStringOutput , RawJson }
7
6
import com .avsystem .commons .serialization .{GenCodec , HasGenCodec }
8
7
import com .typesafe .scalalogging .LazyLogging
9
-
10
- import javax .servlet .http .{HttpServletRequest , HttpServletResponse }
11
- import org .eclipse .jetty .client .HttpClient
12
- import org .eclipse .jetty .client .api .Result
13
- import org .eclipse .jetty .client .util .{BufferingResponseListener , StringContentProvider , StringRequestContent }
8
+ import jakarta .servlet .http .{HttpServlet , HttpServletRequest , HttpServletResponse }
9
+ import org .eclipse .jetty .client .{BufferingResponseListener , HttpClient , Result , StringRequestContent }
10
+ import org .eclipse .jetty .ee10 .servlet .ServletContextHandler
14
11
import org .eclipse .jetty .http .{HttpMethod , HttpStatus , MimeTypes }
15
- import org .eclipse .jetty .server .handler .AbstractHandler
16
- import org .eclipse .jetty .server .{Handler , Request }
12
+ import org .eclipse .jetty .server .Handler
17
13
18
- import scala .concurrent .duration ._
14
+ import java .nio .charset .StandardCharsets
15
+ import java .util .concurrent .atomic .AtomicBoolean
16
+ import scala .concurrent .duration .*
17
+ import scala .util .Using
19
18
20
19
object JettyRPCFramework extends StandardRPCFramework with LazyLogging {
21
20
class RawValue (val s : String ) extends AnyVal
@@ -89,30 +88,40 @@ object JettyRPCFramework extends StandardRPCFramework with LazyLogging {
89
88
request(HttpMethod .PUT , call)
90
89
}
91
90
92
- class RPCHandler (rootRpc : RawRPC , contextTimeout : FiniteDuration ) extends AbstractHandler {
93
- override def handle (target : String , baseRequest : Request , request : HttpServletRequest , response : HttpServletResponse ): Unit = {
94
- baseRequest.setHandled(true )
95
-
96
- val content = Iterator .continually(request.getReader.readLine())
97
- .takeWhile(_ != null )
98
- .mkString(" \n " )
99
-
100
- val call = read[Call ](new RawValue (content))
91
+ class RPCHandler (rootRpc : RawRPC , contextTimeout : FiniteDuration ) extends HttpServlet {
92
+ override def service (request : HttpServletRequest , response : HttpServletResponse ): Unit = {
93
+ // readRequest must execute in request thread but we want exceptions to be handled uniformly, hence the Try
94
+ val content =
95
+ Using (request.getReader)(reader =>
96
+ Iterator .continually(reader.readLine()).takeWhile(_ != null ).mkString(" \n " )
97
+ )
98
+ val call = content.map(content => read[Call ](new RawValue (content)))
101
99
102
100
HttpMethod .fromString(request.getMethod) match {
103
101
case HttpMethod .POST =>
104
- val async = request.startAsync().setup(_.setTimeout(contextTimeout.toMillis))
105
- handlePost(call).andThenNow {
102
+ val asyncContext = request.startAsync().setup(_.setTimeout(contextTimeout.toMillis))
103
+ val completed = new AtomicBoolean (false )
104
+ // Need to protect asyncContext from being completed twice because after a timeout the
105
+ // servlet may recycle the same context instance between subsequent requests (not cool)
106
+ // https://stackoverflow.com/a/27744537
107
+ def completeWith (code : => Unit ): Unit =
108
+ if (! completed.getAndSet(true )) {
109
+ code
110
+ asyncContext.complete()
111
+ }
112
+ Future .fromTry(call).flatMapNow(handlePost).onCompleteNow {
106
113
case Success (responseContent) =>
107
- response.setContentType(MimeTypes .Type .APPLICATION_JSON .asString())
108
- response.setCharacterEncoding(StandardCharsets .UTF_8 .name())
109
- response.getWriter.write(responseContent.s)
114
+ completeWith {
115
+ response.setContentType(MimeTypes .Type .APPLICATION_JSON .asString())
116
+ response.setCharacterEncoding(StandardCharsets .UTF_8 .name())
117
+ response.getWriter.write(responseContent.s)
118
+ }
110
119
case Failure (t) =>
111
- response.sendError(HttpStatus .INTERNAL_SERVER_ERROR_500 , t.getMessage)
120
+ completeWith( response.sendError(HttpStatus .INTERNAL_SERVER_ERROR_500 , t.getMessage) )
112
121
logger.error(" Failed to handle RPC call" , t)
113
- }.andThenNow { case _ => async.complete() }
122
+ }
114
123
case HttpMethod .PUT =>
115
- handlePut( call)
124
+ call.map(handlePut).get
116
125
case _ =>
117
126
throw new IllegalArgumentException (s " Request HTTP method is ${request.getMethod}, only POST or PUT are supported " )
118
127
}
@@ -132,11 +141,12 @@ object JettyRPCFramework extends StandardRPCFramework with LazyLogging {
132
141
invoke(call)(_.fire)
133
142
}
134
143
135
- def newHandler [T ](impl : T , contextTimeout : FiniteDuration = 30 .seconds)(
136
- implicit asRawRPC : AsRawRPC [T ]): Handler =
137
- new RPCHandler (asRawRPC.asRaw(impl), contextTimeout)
144
+ def newServlet [T : AsRawRPC ](impl : T , contextTimeout : FiniteDuration = 30 .seconds): HttpServlet =
145
+ new RPCHandler (AsRawRPC [T ].asRaw(impl), contextTimeout)
146
+
147
+ def newHandler [T : AsRawRPC ](impl : T , contextTimeout : FiniteDuration = 30 .seconds): Handler =
148
+ new ServletContextHandler ().setup(_.addServlet(newServlet(impl, contextTimeout), " /*" ))
138
149
139
- def newClient [T ](httpClient : HttpClient , uri : String , maxResponseLength : Int = 2 * 1024 * 1024 )(
140
- implicit asRealRPC : AsRealRPC [T ]): T =
141
- asRealRPC.asReal(new RPCClient (httpClient, uri, maxResponseLength).rawRPC)
150
+ def newClient [T : AsRealRPC ](httpClient : HttpClient , uri : String , maxResponseLength : Int = 2 * 1024 * 1024 ): T =
151
+ AsRealRPC [T ].asReal(new RPCClient (httpClient, uri, maxResponseLength).rawRPC)
142
152
}
0 commit comments