Skip to content

Commit 8d44c0f

Browse files
committed
Adds a function to manually open the channel
1 parent ecb66b1 commit 8d44c0f

File tree

3 files changed

+35
-8
lines changed

3 files changed

+35
-8
lines changed

fs2/src/jsonrpclib/fs2/FS2Channel.scala

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@ import jsonrpclib.internals.MessageDispatcher
1414
import jsonrpclib.internals._
1515

1616
import scala.util.Try
17+
import _root_.fs2.concurrent.SignallingRef
1718

1819
trait FS2Channel[F[_]] extends Channel[F] {
1920
def withEndpoint(endpoint: Endpoint[F])(implicit F: Functor[F]): Resource[F, Unit] =
2021
Resource.make(mountEndpoint(endpoint))(_ => unmountEndpoint(endpoint.method))
2122

2223
def withEndpoints(endpoint: Endpoint[F], rest: Endpoint[F]*)(implicit F: Monad[F]): Resource[F, Unit] =
2324
(endpoint :: rest.toList).traverse_(withEndpoint)
25+
26+
def open: Resource[F, Unit]
2427
}
2528

2629
object FS2Channel {
@@ -42,16 +45,22 @@ object FS2Channel {
4245
val endpointsMap = startingEndpoints.map(ep => ep.method -> ep).toMap
4346
for {
4447
supervisor <- Stream.resource(Supervisor[F])
45-
ref <- Ref[F].of(State[F](Map.empty, endpointsMap, 0)).toStream
46-
impl = new Impl(payloadSink, ref, supervisor)
47-
_ <- Stream(()).concurrently(payloadStream.evalMap(impl.handleReceivedPayload))
48+
ref <- Ref[F].of(State[F](Map.empty, endpointsMap, 0, false)).toStream
49+
isOpen <- SignallingRef[F].of(false).toStream
50+
impl = new Impl(payloadSink, ref, isOpen, supervisor)
51+
_ <- Stream(()).concurrently {
52+
// Gatekeeping the pull until the channel is actually marked as open
53+
val wait = isOpen.waitUntil(identity)
54+
payloadStream.evalTap(_ => wait).evalMap(impl.handleReceivedPayload)
55+
}
4856
} yield impl
4957
}
5058

5159
private case class State[F[_]](
5260
pendingCalls: Map[CallId, OutputMessage => F[Unit]],
5361
endpoints: Map[String, Endpoint[F]],
54-
counter: Long
62+
counter: Long,
63+
isOpen: Boolean
5564
) {
5665
def nextCallId: (State[F], CallId) = (this.copy(counter = counter + 1), CallId.NumberId(counter))
5766
def storePendingCall(callId: CallId, handle: OutputMessage => F[Unit]): State[F] =
@@ -67,11 +76,15 @@ object FS2Channel {
6776
}
6877
def removeEndpoint(method: String): State[F] =
6978
copy(endpoints = endpoints.removed(method))
79+
80+
def open: State[F] = copy(isOpen = true)
81+
def close: State[F] = copy(isOpen = false)
7082
}
7183

7284
private class Impl[F[_]](
7385
private val sink: Payload => F[Unit],
7486
private val state: Ref[F, FS2Channel.State[F]],
87+
private val isOpen: SignallingRef[F, Boolean],
7588
supervisor: Supervisor[F]
7689
)(implicit F: Concurrent[F])
7790
extends MessageDispatcher[F]
@@ -88,6 +101,8 @@ object FS2Channel {
88101

89102
def unmountEndpoint(method: String): F[Unit] = state.update(_.removeEndpoint(method))
90103

104+
def open: Resource[F, Unit] = Resource.make[F, Unit](isOpen.set(true))(_ => isOpen.set(false))
105+
91106
protected def background[A](fa: F[A]): F[Unit] = supervisor.supervise(fa).void
92107
protected def reportError(params: Option[Payload], error: ProtocolError, method: String): F[Unit] = ???
93108
protected def getEndpoint(method: String): F[Option[Endpoint[F]]] = state.get.map(_.endpoints.get(method))

fs2/src/jsonrpclib/fs2/package.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@ package jsonrpclib
33
import _root_.fs2.Stream
44
import cats.MonadThrow
55
import cats.Monad
6+
import cats.effect.kernel.Resource
7+
import cats.effect.kernel.MonadCancel
68

79
package object fs2 {
810

911
private[jsonrpclib] implicit class EffectOps[F[_], A](private val fa: F[A]) extends AnyVal {
1012
def toStream: Stream[F, A] = Stream.eval(fa)
1113
}
1214

15+
private[jsonrpclib] implicit class ResourceOps[F[_], A](private val fa: Resource[F, A]) extends AnyVal {
16+
def asStream(implicit F: MonadCancel[F, Throwable]): Stream[F, A] = Stream.resource(fa)
17+
}
18+
1319
implicit def catsMonadic[F[_]: MonadThrow]: Monadic[F] = new Monadic[F] {
1420
def doFlatMap[A, B](fa: F[A])(f: A => F[B]): F[B] = Monad[F].flatMap(fa)(f)
1521

fs2/test/src/jsonrpclib/fs2/FS2ChannelSpec.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ object FS2ChannelSpec extends SimpleIOSuite {
2121
}
2222

2323
def testRes(name: TestName)(run: Stream[IO, Expectations]): Unit =
24-
test(name)(run.compile.lastOrError)
24+
test(name)(run.compile.lastOrError.timeout(10.second))
2525

2626
testRes("Round trip") {
2727
val endpoint: Endpoint[IO] = Endpoint[IO]("inc").simple((int: IntWrapper) => IO(IntWrapper(int.int + 1)))
@@ -31,8 +31,10 @@ object FS2ChannelSpec extends SimpleIOSuite {
3131
stdin <- Queue.bounded[IO, Payload](10).toStream
3232
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
3333
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), stdin.offer)
34-
_ <- Stream.resource(serverSideChannel.withEndpoint(endpoint))
34+
_ <- serverSideChannel.withEndpoint(endpoint).asStream
3535
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
36+
_ <- serverSideChannel.open.asStream
37+
_ <- clientSideChannel.open.asStream
3638
result <- remoteFunction(IntWrapper(1)).toStream
3739
} yield {
3840
expect.same(result, IntWrapper(2))
@@ -44,9 +46,11 @@ object FS2ChannelSpec extends SimpleIOSuite {
4446
for {
4547
stdout <- Queue.bounded[IO, Payload](10).toStream
4648
stdin <- Queue.bounded[IO, Payload](10).toStream
47-
_ <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
49+
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), stdout.offer)
4850
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), stdin.offer)
4951
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
52+
_ <- serverSideChannel.open.asStream
53+
_ <- clientSideChannel.open.asStream
5054
result <- remoteFunction(IntWrapper(1)).attempt.toStream
5155
} yield {
5256
expect.same(result, Left(ErrorPayload(-32601, "Method inc not found", None)))
@@ -65,8 +69,10 @@ object FS2ChannelSpec extends SimpleIOSuite {
6569
stdin <- Queue.bounded[IO, Payload](10).toStream
6670
serverSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdin), payload => stdout.offer(payload))
6771
clientSideChannel <- FS2Channel[IO](Stream.fromQueueUnterminated(stdout), payload => stdin.offer(payload))
68-
_ <- Stream.resource(serverSideChannel.withEndpoint(endpoint))
72+
_ <- serverSideChannel.withEndpoint(endpoint).asStream
6973
remoteFunction = clientSideChannel.simpleStub[IntWrapper, IntWrapper]("inc")
74+
_ <- serverSideChannel.open.asStream
75+
_ <- clientSideChannel.open.asStream
7076
timedResults <- (1 to 10).toList.map(IntWrapper(_)).parTraverse(remoteFunction).timed.toStream
7177
} yield {
7278
val (time, results) = timedResults

0 commit comments

Comments
 (0)