Skip to content

Create caches for document encoders #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modules/core/src/main/scala/jsonrpclib/Monadic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object Monadic {
implicit class MonadicOps[F[_], A](private val fa: F[A]) extends AnyVal {
def flatMap[B](f: A => F[B])(implicit m: Monadic[F]): F[B] = m.doFlatMap(fa)(f)
def map[B](f: A => B)(implicit m: Monadic[F]): F[B] = m.doMap(fa)(f)
def attempt[B](implicit m: Monadic[F]): F[Either[Throwable, A]] = m.doAttempt(fa)
def attempt(implicit m: Monadic[F]): F[Either[Throwable, A]] = m.doAttempt(fa)
def void(implicit m: Monadic[F]): F[Unit] = m.doVoid(fa)
}
implicit class MonadicOpsPure[A](private val a: A) extends AnyVal {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import cats.effect._
import fs2.io._
import io.circe.generic.semiauto._
import io.circe.Codec
import io.circe.Decoder
import io.circe.Encoder
import jsonrpclib.fs2._
import jsonrpclib.CallId
import jsonrpclib.Endpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import io.circe.Encoder
import jsonrpclib._
import jsonrpclib.fs2._
import test._
import test.TestServerOperation.GreetError
import weaver._

import scala.concurrent.duration._
Expand Down Expand Up @@ -78,4 +79,47 @@ object TestClientSpec extends SimpleIOSuite {
expect.same(result.payload.message, "Hello Bob")
}
}

testRes("server returns known error") {
implicit val greetInputDecoder: Decoder[GreetInput] = CirceJsonCodec.fromSchema
implicit val greetOutputEncoder: Encoder[GreetOutput] = CirceJsonCodec.fromSchema
implicit val greetErrorEncoder: Encoder[GreetError] = CirceJsonCodec.fromSchema
implicit val errEncoder: ErrorEncoder[GreetError] =
err => ErrorPayload(-1, "error", Some(Payload(greetErrorEncoder(err))))

val endpoint: Endpoint[IO] =
Endpoint[IO]("greet").apply[GreetInput, GreetError, GreetOutput](in =>
IO.pure(Left(GreetError.notWelcomeError(NotWelcomeError(s"${in.name} is not welcome"))))
)

for {
clientSideChannel <- setup(endpoint)
clientStub = ClientStub(TestServer, clientSideChannel)
result <- clientStub.greet("Bob").attempt.toStream
} yield {
matches(result) { case Left(t: NotWelcomeError) =>
expect.same(t.msg, s"Bob is not welcome")
}
}
}

testRes("server returns unknown error") {
implicit val greetInputDecoder: Decoder[GreetInput] = CirceJsonCodec.fromSchema
implicit val greetOutputEncoder: Encoder[GreetOutput] = CirceJsonCodec.fromSchema

val endpoint: Endpoint[IO] =
Endpoint[IO]("greet").simple[GreetInput, GreetOutput](_ => IO.raiseError(new RuntimeException("boom!")))

for {
clientSideChannel <- setup(endpoint)
clientStub = ClientStub(TestServer, clientSideChannel)
result <- clientStub.greet("Bob").attempt.toStream
} yield {
matches(result) { case Left(t: ErrorPayload) =>
expect.same(t.code, 0) &&
expect.same(t.message, "boom!") &&
expect.same(t.data, None)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package jsonrpclib.smithy4sinterop

import io.circe.{Decoder => CirceDecoder, _}
import smithy4s.codecs.PayloadPath
import smithy4s.schema.CachedSchemaCompiler
import smithy4s.Document
import smithy4s.Document.{Encoder => _, _}
import smithy4s.Schema

private[smithy4sinterop] class CirceDecoderImpl extends CachedSchemaCompiler[CirceDecoder] {
val decoder: CachedSchemaCompiler.DerivingImpl[Decoder] = Document.Decoder

type Cache = decoder.Cache
def createCache(): Cache = decoder.createCache()

def fromSchema[A](schema: Schema[A], cache: Cache): CirceDecoder[A] =
c => {
c.as[Json]
.map(fromJson(_))
.flatMap { d =>
decoder
.fromSchema(schema, cache)
.decode(d)
.left
.map(e =>
DecodingFailure(DecodingFailure.Reason.CustomReason(e.getMessage), c.history ++ toCursorOps(e.path))
)
}
}

def fromSchema[A](schema: Schema[A]): CirceDecoder[A] = fromSchema(schema, createCache())

private def toCursorOps(path: PayloadPath): List[CursorOp] =
path.segments.map {
case PayloadPath.Segment.Label(name) => CursorOp.DownField(name)
case PayloadPath.Segment.Index(i) => CursorOp.DownN(i)
}

private def fromJson(json: Json): Document = json.fold(
jsonNull = DNull,
jsonBoolean = DBoolean(_),
jsonNumber = n => DNumber(n.toBigDecimal.get),
jsonString = DString(_),
jsonArray = arr => DArray(arr.map(fromJson)),
jsonObject = obj => DObject(obj.toMap.view.mapValues(fromJson).toMap)
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package jsonrpclib.smithy4sinterop

import io.circe.{Encoder => CirceEncoder, _}
import smithy4s.schema.CachedSchemaCompiler
import smithy4s.Document
import smithy4s.Document._
import smithy4s.Schema

private[smithy4sinterop] class CirceEncoderImpl extends CachedSchemaCompiler[CirceEncoder] {
val encoder: CachedSchemaCompiler.DerivingImpl[Encoder] = Document.Encoder

type Cache = encoder.Cache
def createCache(): Cache = encoder.createCache()

def fromSchema[A](schema: Schema[A], cache: Cache): CirceEncoder[A] =
a => documentToJson(encoder.fromSchema(schema, cache).encode(a))

def fromSchema[A](schema: Schema[A]): CirceEncoder[A] = fromSchema(schema, createCache())

private val documentToJson: Document => Json = {
case DNull => Json.Null
case DString(value) => Json.fromString(value)
case DBoolean(value) => Json.fromBoolean(value)
case DNumber(value) => Json.fromBigDecimal(value)
case DArray(values) => Json.fromValues(values.map(documentToJson))
case DObject(entries) => Json.fromFields(entries.view.mapValues(documentToJson))
}
}
Original file line number Diff line number Diff line change
@@ -1,55 +1,33 @@
package jsonrpclib.smithy4sinterop

import io.circe._
import smithy4s.codecs.PayloadPath
import smithy4s.Document
import smithy4s.Document.{Decoder => _, _}
import smithy4s.schema.CachedSchemaCompiler
import smithy4s.Schema

object CirceJsonCodec {

object Encoder extends CirceEncoderImpl
object Decoder extends CirceDecoderImpl

object Codec extends CachedSchemaCompiler[Codec] {
type Cache = (Encoder.Cache, Decoder.Cache)
def createCache(): Cache = (Encoder.createCache(), Decoder.createCache())

def fromSchema[A](schema: Schema[A]): Codec[A] =
io.circe.Codec.from(Decoder.fromSchema(schema), Encoder.fromSchema(schema))

def fromSchema[A](schema: Schema[A], cache: Cache): Codec[A] =
io.circe.Codec.from(
Decoder.fromSchema(schema, cache._2),
Encoder.fromSchema(schema, cache._1)
)
}

/** Creates a Circe `Codec[A]` from a Smithy4s `Schema[A]`.
*
* This enables encoding values of type `A` to JSON and decoding JSON back into `A`, using the structure defined by
* the Smithy schema.
*/
def fromSchema[A](implicit schema: Schema[A]): Codec[A] = Codec.from(
c => {
c.as[Json]
.map(fromJson)
.flatMap { d =>
Document
.decode[A](d)
.left
.map(e =>
DecodingFailure(DecodingFailure.Reason.CustomReason(e.getMessage), c.history ++ toCursorOps(e.path))
)
}
},
a => documentToJson(Document.encode(a))
)

private def toCursorOps(path: PayloadPath): List[CursorOp] =
path.segments.map {
case PayloadPath.Segment.Label(name) => CursorOp.DownField(name)
case PayloadPath.Segment.Index(i) => CursorOp.DownN(i)
}

private val documentToJson: Document => Json = {
case DNull => Json.Null
case DString(value) => Json.fromString(value)
case DBoolean(value) => Json.fromBoolean(value)
case DNumber(value) => Json.fromBigDecimal(value)
case DArray(values) => Json.fromValues(values.map(documentToJson))
case DObject(entries) => Json.fromFields(entries.view.mapValues(documentToJson))
}

private def fromJson(json: Json): Document = json.fold(
jsonNull = DNull,
jsonBoolean = DBoolean(_),
jsonNumber = n => DNumber(n.toBigDecimal.get),
jsonString = DString(_),
jsonArray = arr => DArray(arr.map(fromJson)),
jsonObject = obj => DObject(obj.toMap.view.mapValues(fromJson).toMap)
)
def fromSchema[A](implicit schema: Schema[A]): Codec[A] =
Codec.fromSchema(schema)
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package jsonrpclib.smithy4sinterop

import io.circe.Codec
import io.circe.HCursor
import jsonrpclib.Channel
import jsonrpclib.ErrorPayload
import jsonrpclib.Monadic
import jsonrpclib.Monadic.syntax._
import jsonrpclib.ProtocolError
import smithy4s.~>
import smithy4s.schema._
import smithy4s.Service
Expand Down Expand Up @@ -30,12 +34,13 @@ object ClientStub {
private class ClientStub[Alg[_[_, _, _, _, _]], F[_]: Monadic](val service: Service[Alg], channel: Channel[F]) {

def compile: service.Impl[F] = {
val codecCache = CirceJsonCodec.Codec.createCache()
val interpreter = new service.FunctorEndpointCompiler[F] {
def apply[I, E, O, SI, SO](e: service.Endpoint[I, E, O, SI, SO]): I => F[O] = {
val shapeId = e.id
val spec = EndpointSpec.fromHints(e.hints).toRight(NotJsonRPCEndpoint(shapeId)).toTry.get

jsonRPCStub(e, spec)
jsonRPCStub(e, spec, codecCache)
}
}

Expand All @@ -44,18 +49,42 @@ private class ClientStub[Alg[_[_, _, _, _, _]], F[_]: Monadic](val service: Serv

def jsonRPCStub[I, E, O, SI, SO](
smithy4sEndpoint: service.Endpoint[I, E, O, SI, SO],
endpointSpec: EndpointSpec
endpointSpec: EndpointSpec,
codecCache: CirceJsonCodec.Codec.Cache
): I => F[O] = {

implicit val inputCodec: Codec[I] = CirceJsonCodec.fromSchema(smithy4sEndpoint.input)
implicit val outputCodec: Codec[O] = CirceJsonCodec.fromSchema(smithy4sEndpoint.output)
implicit val inputCodec: Codec[I] = CirceJsonCodec.Codec.fromSchema(smithy4sEndpoint.input, codecCache)
implicit val outputCodec: Codec[O] = CirceJsonCodec.Codec.fromSchema(smithy4sEndpoint.output, codecCache)

def errorResponse(throwable: Throwable, errorCodec: Codec[E]): F[E] = {
throwable match {
case ErrorPayload(_, _, Some(payload)) =>
errorCodec.decodeJson(payload.data) match {
case Left(err) => ProtocolError.ParseError(err.getMessage).raiseError
case Right(error) => error.pure
}
case e: Throwable => e.raiseError
}
}

endpointSpec match {
case EndpointSpec.Notification(methodName) =>
val coerce = coerceUnit[O](smithy4sEndpoint.output)
channel.notificationStub[I](methodName).andThen(f => Monadic[F].doFlatMap(f)(_ => coerce))
case EndpointSpec.Request(methodName) =>
channel.simpleStub[I, O](methodName)
smithy4sEndpoint.error match {
case None => channel.simpleStub[I, O](methodName)
case Some(errorSchema) =>
val errorCodec = CirceJsonCodec.Codec.fromSchema(errorSchema.schema, codecCache)
val stub = channel.simpleStub[I, O](methodName)
(in: I) =>
stub.apply(in).attempt.flatMap {
case Right(success) => success.pure
case Left(error) =>
errorResponse(error, errorCodec)
.flatMap(e => errorSchema.unliftError(e).raiseError)
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ object ServerEndpoints {
)(implicit service: Service[Alg], F: Monadic[F]): List[Endpoint[F]] = {
val transformedService = JsonRpcTransformations.apply(service)
val interpreter: transformedService.FunctorInterpreter[F] = transformedService.toPolyFunction(impl)
val codecCache = CirceJsonCodec.Codec.createCache()
transformedService.endpoints.toList.flatMap { smithy4sEndpoint =>
EndpointSpec
.fromHints(smithy4sEndpoint.hints)
.map { endpointSpec =>
jsonRPCEndpoint(smithy4sEndpoint, endpointSpec, interpreter)
jsonRPCEndpoint(smithy4sEndpoint, endpointSpec, interpreter, codecCache)
}
.toList
}
Expand All @@ -55,17 +56,19 @@ object ServerEndpoints {
* JSON-RPC method name and interaction hints
* @param impl
* Interpreter that executes the Smithy operation in `F`
* @param codecCache
* Coche for the schema to codec compilation results
* @return
* A JSON-RPC-compatible `Endpoint[F]`
*/
private def jsonRPCEndpoint[F[_]: Monadic, Op[_, _, _, _, _], I, E, O, SI, SO](
smithy4sEndpoint: Smithy4sEndpoint[Op, I, E, O, SI, SO],
endpointSpec: EndpointSpec,
impl: FunctorInterpreter[Op, F]
impl: FunctorInterpreter[Op, F],
codecCache: CirceJsonCodec.Codec.Cache
): Endpoint[F] = {

implicit val inputCodec: Codec[I] = CirceJsonCodec.fromSchema(smithy4sEndpoint.input)
implicit val outputCodec: Codec[O] = CirceJsonCodec.fromSchema(smithy4sEndpoint.output)
implicit val inputCodec: Codec[I] = CirceJsonCodec.Codec.fromSchema(smithy4sEndpoint.input, codecCache)
implicit val outputCodec: Codec[O] = CirceJsonCodec.Codec.fromSchema(smithy4sEndpoint.output, codecCache)

def errorResponse(throwable: Throwable): F[E] = throwable match {
case smithy4sEndpoint.Error((_, e)) => e.pure
Expand All @@ -86,7 +89,7 @@ object ServerEndpoints {
impl(op)
}
case Some(errorSchema) =>
implicit val errorCodec: ErrorEncoder[E] = errorCodecFromSchema(errorSchema)
implicit val errorCodec: ErrorEncoder[E] = errorCodecFromSchema(errorSchema, codecCache)
Endpoint[F](methodName).apply[I, E, O] { (input: I) =>
val op = smithy4sEndpoint.wrap(input)
impl(op).attempt.flatMap {
Expand All @@ -98,8 +101,8 @@ object ServerEndpoints {
}
}

private def errorCodecFromSchema[A](s: ErrorSchema[A]): ErrorEncoder[A] = {
val circeCodec = CirceJsonCodec.fromSchema(s.schema)
private def errorCodecFromSchema[A](s: ErrorSchema[A], cache: CirceJsonCodec.Codec.Cache): ErrorEncoder[A] = {
val circeCodec = CirceJsonCodec.Codec.fromSchema(s.schema, cache)
(a: A) =>
ErrorPayload(
0,
Expand Down