Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure Redis.send(..) methods are tail recursive #302

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions src/main/scala/com/redis/RedisClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package com.redis

import java.net.SocketException
import javax.net.ssl.SSLContext

import com.redis.serialization.Format

import scala.annotation.tailrec

object RedisClient {
sealed trait SortOrder
case object ASC extends SortOrder
Expand Down Expand Up @@ -35,7 +36,8 @@ abstract class Redis(batch: Mode) extends IO with Protocol {
var handlers: Vector[(String, () => Any)] = Vector.empty
val commandBuffer = collection.mutable.ListBuffer.empty[CommandToSend]

def send[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = try {
@tailrec
private def doSend[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A = try {
if (batch == BATCH) {
handlers :+= ((command, () => result))
commandBuffer += CommandToSend(command, args.map(format.apply))
Expand All @@ -46,14 +48,18 @@ abstract class Redis(batch: Mode) extends IO with Protocol {
}
} catch {
case e: RedisConnectionException =>
if (disconnect) send(command, args)(result)
if (disconnect) doSend(command, args)(result)
else throw e
case e: SocketException =>
if (disconnect) send(command, args)(result)
if (disconnect) doSend(command, args)(result)
else throw e
}

def send[A](command: String)(result: => A): A = try {
def send[A](command: String, args: Seq[Any])(result: => A)(implicit format: Format): A =
doSend(command, args)(result)

@tailrec
private def doSend[A](command: String)(result: => A): A = try {
if (batch == BATCH) {
handlers :+= ((command, () => result))
commandBuffer += CommandToSend(command, Seq.empty[Array[Byte]])
Expand All @@ -64,28 +70,33 @@ abstract class Redis(batch: Mode) extends IO with Protocol {
}
} catch {
case e: RedisConnectionException =>
if (disconnect) send(command)(result)
if (disconnect) doSend(command)(result)
else throw e
case e: SocketException =>
if (disconnect) send(command)(result)
if (disconnect) doSend(command)(result)
else throw e
}

def send[A](commands: List[CommandToSend])(result: => A): A = try {
def send[A](command: String)(result: => A): A = doSend(command)(result)

@tailrec
private def doSend[A](commands: List[CommandToSend])(result: => A): A = try {
val cs = commands.map { command =>
command.command.getBytes("UTF-8") +: command.args
}
write(Commands.multiMultiBulk(cs))
result
} catch {
case e: RedisConnectionException =>
if (disconnect) send(commands)(result)
if (disconnect) doSend(commands)(result)
else throw e
case e: SocketException =>
if (disconnect) send(commands)(result)
if (disconnect) doSend(commands)(result)
else throw e
}

def send[A](commands: List[CommandToSend])(result: => A): A = doSend(commands)(result)

def cmd(args: Seq[Array[Byte]]): Array[Byte] = Commands.multiBulk(args)

protected def flattenPairs(in: Iterable[Product2[Any, Any]]): List[Any] =
Expand Down
49 changes: 46 additions & 3 deletions src/test/scala/com/redis/RedisClientSpec.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package com.redis

import java.net.{ServerSocket, URI}
import com.redis.RedisClientSpec.DummyClientWithFaultyConnection

import java.net.{ServerSocket, URI}
import com.redis.api.ApiSpec
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers

import java.io.OutputStream
import scala.concurrent.Await
import scala.concurrent.duration._

Expand Down Expand Up @@ -71,7 +73,7 @@ class RedisClientSpec extends AnyFunSpec
r.close()
}}

// describe("test reconnect") {
describe("test reconnect") {
// it("should re-init after server restart") {
// val docker = new Docker(DefaultDockerClientConfig.createDefaultConfigBuilder().build()).client
//
Expand Down Expand Up @@ -104,5 +106,46 @@ class RedisClientSpec extends AnyFunSpec
//
// got shouldBe Some(value)
// }
// }

it("should not trigger a StackOverflowError in send(..) if Redis is down") {
val maxFailures = 10000 // Should be enough to trigger StackOverflowError
val r = new DummyClientWithFaultyConnection(maxFailures)
r.send("PING") {
/* PONG */
}
r.connected shouldBe true
}

}
}

object RedisClientSpec {

private class DummyClientWithFaultyConnection(maxFailures: Int) extends Redis(RedisClient.SINGLE) {

private var _connected = false
private var _failures = 0

override val host: String = null
override val port: Int = 0
override val timeout: Int = 0

override def onConnect(): Unit = ()

override def connected: Boolean = _connected

override def disconnect: Boolean = true

override def write_to_socket(data: Array[Byte])(op: OutputStream => Unit): Unit = ()

override def connect: Boolean =
if (_failures <= maxFailures) {
_failures += 1
throw RedisConnectionException("fail in order to trigger the reconnect")
} else {
_connected = true
true
}
}

}