Skip to content

Commit 0fece01

Browse files
author
Michal Charvát
committed
fix: Fix bug in the publisher confirmations
1 parent 0ef2a8a commit 0fece01

File tree

5 files changed

+138
-77
lines changed

5 files changed

+138
-77
lines changed

core/src/main/scala/com/avast/clients/rabbitmq/DefaultRabbitMQClientFactory.scala

+27-27
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package com.avast.clients.rabbitmq
22

33
import cats.effect._
4-
import cats.effect.concurrent.{Deferred, Ref}
54
import cats.implicits.{catsSyntaxFlatMapOps, toFunctorOps, toTraverseOps}
65
import com.avast.bytes.Bytes
76
import com.avast.clients.rabbitmq.DefaultRabbitMQClientFactory.startConsumingQueue
@@ -270,41 +269,42 @@ private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Tim
270269
producerConfig.properties.confirms match {
271270
case Some(PublisherConfirmsConfig(true, sendAttempts)) =>
272271
prepareProducer(producerConfig, connection) { (defaultProperties, channel, logger) =>
273-
Ref.of(Map.empty[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]])
274-
.map {
275-
new PublishConfirmsRabbitMQProducer[F, A](
276-
producerConfig.name,
277-
producerConfig.exchange,
278-
channel,
279-
defaultProperties,
280-
_,
281-
sendAttempts,
282-
producerConfig.reportUnroutable,
283-
producerConfig.sizeLimitBytes,
284-
blocker,
285-
logger,
286-
monitor)
287-
}
272+
F.pure {
273+
new PublishConfirmsRabbitMQProducer[F, A](
274+
producerConfig.name,
275+
producerConfig.exchange,
276+
channel,
277+
defaultProperties,
278+
sendAttempts,
279+
producerConfig.reportUnroutable,
280+
producerConfig.sizeLimitBytes,
281+
blocker,
282+
logger,
283+
monitor
284+
)
285+
}
288286
}
289287
case _ =>
290288
prepareProducer(producerConfig, connection) { (defaultProperties, channel, logger) =>
291289
F.pure {
292-
new DefaultRabbitMQProducer[F, A](producerConfig.name,
293-
producerConfig.exchange,
294-
channel,
295-
defaultProperties,
296-
producerConfig.reportUnroutable,
297-
producerConfig.sizeLimitBytes,
298-
blocker,
299-
logger,
300-
monitor)
290+
new DefaultRabbitMQProducer[F, A](
291+
producerConfig.name,
292+
producerConfig.exchange,
293+
channel,
294+
defaultProperties,
295+
producerConfig.reportUnroutable,
296+
producerConfig.sizeLimitBytes,
297+
blocker,
298+
logger,
299+
monitor
300+
)
301301
}
302302
}
303303
}
304304
}
305305

306306
private def prepareProducer[T: ClassTag, A: ProductConverter](producerConfig: ProducerConfig, connection: RabbitMQConnection[F])(
307-
createProducer: (MessageProperties, ServerChannel, ImplicitContextLogger[F]) => F[T]) = {
307+
createProducer: (MessageProperties, ServerChannel, ImplicitContextLogger[F]) => F[T]) = {
308308
val logger: ImplicitContextLogger[F] = ImplicitContextLogger.createLogger[F, T]
309309

310310
connection
@@ -319,7 +319,7 @@ private[rabbitmq] class DefaultRabbitMQClientFactory[F[_]: ConcurrentEffect: Tim
319319
contentType = producerConfig.properties.contentType,
320320
contentEncoding = producerConfig.properties.contentEncoding,
321321
priority = producerConfig.properties.priority.map(Integer.valueOf)
322-
)
322+
)
323323
createProducer(defaultProperties, channel, logger)
324324
}
325325
}

core/src/main/scala/com/avast/clients/rabbitmq/publisher/BaseRabbitMQProducer.scala

+8-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import com.avast.clients.rabbitmq.JavaConverters._
99
import com.avast.clients.rabbitmq.api.CorrelationIdStrategy.FromPropertiesOrRandomNew
1010
import com.avast.clients.rabbitmq.api._
1111
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
12-
import com.avast.clients.rabbitmq.{CorrelationId, ProductConverter, ServerChannel, startAndForget}
12+
import com.avast.clients.rabbitmq.{startAndForget, CorrelationId, ProductConverter, ServerChannel}
1313
import com.avast.metrics.scalaeffectapi.Monitor
1414
import com.rabbitmq.client.AMQP.BasicProperties
1515
import com.rabbitmq.client.{AlreadyClosedException, ReturnListener}
@@ -63,17 +63,21 @@ abstract class BaseRabbitMQProducer[F[_], A: ProductConverter](name: String,
6363
}
6464
}
6565

66-
protected def basicSend(routingKey: String, body: Bytes, properties: MessageProperties)(implicit correlationId: CorrelationId): F[Unit] = {
66+
protected def basicSend(routingKey: String, body: Bytes, properties: MessageProperties, preSendAction: Long => Unit = (_: Long) => ())(
67+
implicit correlationId: CorrelationId): F[Long] = {
6768
for {
6869
_ <- logger.debug(s"Sending message with ${body.size()} B to exchange $exchangeName with routing key '$routingKey' and $properties")
69-
_ <- blocker.delay {
70+
sequenceNumber <- blocker.delay {
7071
sendLock.synchronized {
7172
// see https://www.rabbitmq.com/api-guide.html#channel-threads
73+
val sequenceNumber = channel.getNextPublishSeqNo
74+
preSendAction(sequenceNumber)
7275
channel.basicPublish(exchangeName, routingKey, properties.asAMQP, body.toByteArray)
76+
sequenceNumber
7377
}
7478
}
7579
_ <- sentMeter.mark
76-
} yield ()
80+
} yield sequenceNumber
7781
}
7882

7983
private def processErrors(from: F[Unit], routingKey: String)(implicit correlationId: CorrelationId): F[Unit] = {

core/src/main/scala/com/avast/clients/rabbitmq/publisher/DefaultRabbitMQProducer.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.avast.clients.rabbitmq.publisher
22

33
import cats.effect.{Blocker, ConcurrentEffect, ContextShift}
4+
import cats.implicits.toFunctorOps
45
import com.avast.bytes.Bytes
56
import com.avast.clients.rabbitmq.api.MessageProperties
67
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
@@ -26,5 +27,5 @@ class DefaultRabbitMQProducer[F[_], A: ProductConverter](name: String,
2627
logger,
2728
monitor) {
2829
override def sendMessage(routingKey: String, body: Bytes, properties: MessageProperties)(implicit correlationId: CorrelationId): F[Unit] =
29-
basicSend(routingKey, body, properties)
30+
basicSend(routingKey, body, properties).void
3031
}

core/src/main/scala/com/avast/clients/rabbitmq/publisher/PublishConfirmsRabbitMQProducer.scala

+20-23
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
package com.avast.clients.rabbitmq.publisher
22

3-
import cats.effect.concurrent.{Deferred, Ref}
3+
import cats.effect.concurrent.Deferred
44
import cats.effect.{Blocker, ConcurrentEffect, ContextShift}
55
import cats.syntax.flatMap._
66
import cats.syntax.functor._
77
import com.avast.bytes.Bytes
88
import com.avast.clients.rabbitmq.api.{MaxAttemptsReached, MessageProperties, NotAcknowledgedPublish}
99
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
10-
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer.SentMessages
11-
import com.avast.clients.rabbitmq.{CorrelationId, ProductConverter, ServerChannel, startAndForget}
10+
import com.avast.clients.rabbitmq.{startAndForget, CorrelationId, ProductConverter, ServerChannel}
1211
import com.avast.metrics.scalaeffectapi.Monitor
1312
import com.rabbitmq.client.ConfirmListener
1413

14+
import java.util.concurrent.ConcurrentHashMap
15+
import scala.collection.JavaConverters._
1516
class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
1617
exchangeName: String,
1718
channel: ServerChannel,
1819
defaultProperties: MessageProperties,
19-
sentMessages: SentMessages[F],
2020
sendAttempts: Int,
2121
reportUnroutable: Boolean,
2222
sizeLimitBytes: Option[Int],
@@ -39,6 +39,10 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
3939
private val acked = monitor.meter("acked")
4040
private val nacked = monitor.meter("nacked")
4141

42+
private[rabbitmq] val confirmationCallbacks = {
43+
new ConcurrentHashMap[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]]().asScala
44+
}
45+
4246
override def sendMessage(routingKey: String, body: Bytes, properties: MessageProperties)(implicit correlationId: CorrelationId): F[Unit] =
4347
sendWithAck(routingKey, body, properties, 1)
4448

@@ -48,34 +52,29 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
4852
if (attemptCount > sendAttempts) {
4953
F.raiseError(MaxAttemptsReached("Exhausted max number of attempts"))
5054
} else {
51-
val messageId = channel.getNextPublishSeqNo
5255
for {
53-
defer <- Deferred.apply[F, Either[NotAcknowledgedPublish, Unit]]
54-
_ <- sentMessages.update(_ + (messageId -> defer))
55-
_ <- basicSend(routingKey, body, properties)
56-
result <- defer.get
56+
confirmationCallback <- Deferred.apply[F, Either[NotAcknowledgedPublish, Unit]]
57+
sequenceNumber <- basicSend(routingKey, body, properties, (sequenceNumber: Long) => {
58+
confirmationCallbacks += sequenceNumber -> confirmationCallback
59+
})
60+
result <- confirmationCallback.get
61+
_ <- F.delay(confirmationCallbacks -= sequenceNumber)
5762
_ <- result match {
5863
case Left(err) =>
5964
val sendResult = if (sendAttempts > 1) {
60-
clearProcessedMessage(messageId) >> sendWithAck(routingKey, body, properties, attemptCount + 1)
65+
sendWithAck(routingKey, body, properties, attemptCount + 1)
6166
} else {
6267
F.raiseError(err)
6368
}
64-
6569
nacked.mark >> sendResult
6670
case Right(_) =>
67-
acked.mark >> clearProcessedMessage(messageId)
71+
acked.mark
6872
}
6973
} yield ()
7074
}
7175
}
7276

73-
private def clearProcessedMessage(messageId: Long): F[Unit] = {
74-
sentMessages.update(_ - messageId)
75-
}
76-
7777
private object DefaultConfirmListener extends ConfirmListener {
78-
import cats.syntax.foldable._
7978

8079
override def handleAck(deliveryTag: Long, multiple: Boolean): Unit = {
8180
startAndForget {
@@ -92,12 +91,10 @@ class PublishConfirmsRabbitMQProducer[F[_], A: ProductConverter](name: String,
9291
}
9392

9493
private def completeDefer(deliveryTag: Long, result: Either[NotAcknowledgedPublish, Unit]): F[Unit] = {
95-
sentMessages.get.flatMap(_.get(deliveryTag).traverse_(_.complete(result)))
94+
confirmationCallbacks.get(deliveryTag) match {
95+
case Some(callback) => callback.complete(result)
96+
case None => logger.plainWarn("Received confirmation for unknown delivery tag. That is unexpected state.")
97+
}
9698
}
9799
}
98-
99-
}
100-
101-
object PublishConfirmsRabbitMQProducer {
102-
type SentMessages[F[_]] = Ref[F, Map[Long, Deferred[F, Either[NotAcknowledgedPublish, Unit]]]]
103100
}
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,33 @@
11
package com.avast.clients.rabbitmq
22

3-
import cats.effect.concurrent.{Deferred, Ref}
4-
import cats.syntax.parallel._
3+
import cats.implicits.catsSyntaxParallelAp
54
import com.avast.bytes.Bytes
65
import com.avast.clients.rabbitmq.api.{MessageProperties, NotAcknowledgedPublish}
76
import com.avast.clients.rabbitmq.logging.ImplicitContextLogger
87
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer
9-
import com.avast.clients.rabbitmq.publisher.PublishConfirmsRabbitMQProducer.SentMessages
108
import com.avast.metrics.scalaeffectapi.Monitor
119
import com.rabbitmq.client.impl.recovery.AutorecoveringChannel
1210
import monix.eval.Task
1311
import monix.execution.Scheduler.Implicits.global
12+
import org.junit.runner.manipulation.InvalidOrderingException
1413
import org.mockito.Matchers
1514
import org.mockito.Matchers.any
1615
import org.mockito.Mockito.{times, verify, when}
1716

17+
import scala.concurrent.Await
18+
import scala.concurrent.duration.DurationInt
1819
import scala.util.Random
1920

2021
class PublisherConfirmsRabbitMQProducerTest extends TestBase {
21-
test("message is acked after one retry") {
22+
23+
test("Message is acked after one retry") {
2224
val exchangeName = Random.nextString(10)
2325
val routingKey = Random.nextString(10)
2426
val seqNumber = 1L
2527
val seqNumber2 = 2L
2628

2729
val channel = mock[AutorecoveringChannel]
28-
val ref = Ref.of[Task, Map[Long, Deferred[Task, Either[NotAcknowledgedPublish, Unit]]]](Map.empty).await
29-
val updatedState1 = updateMessageState(ref, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber)))
30-
val updatedState2 = updateMessageState(ref, seqNumber2)(Right())
30+
when(channel.getNextPublishSeqNo).thenReturn(seqNumber, seqNumber2)
3131

3232
val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes](
3333
name = "test",
@@ -39,15 +39,21 @@ class PublisherConfirmsRabbitMQProducerTest extends TestBase {
3939
sizeLimitBytes = None,
4040
blocker = TestBase.testBlocker,
4141
logger = ImplicitContextLogger.createLogger,
42-
sentMessages = ref,
4342
sendAttempts = 2
4443
)
45-
when(channel.getNextPublishSeqNo).thenReturn(seqNumber, seqNumber2)
4644

47-
producer.send(routingKey, Bytes.copyFrom(Array.fill(499)(32.toByte))).parProduct(updatedState1.parProduct(updatedState2)).await
45+
val body = Bytes.copyFrom(Array.fill(499)(32.toByte))
46+
47+
val publishTask = producer.send(routingKey, body).runToFuture
48+
49+
updateMessageState(producer, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber))).parProduct {
50+
updateMessageState(producer, seqNumber2)(Right())
51+
}.await
52+
53+
Await.result(publishTask, 10.seconds)
4854

4955
verify(channel, times(2))
50-
.basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(Bytes.copyFrom(Array.fill(499)(32.toByte)).toByteArray))
56+
.basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(body.toByteArray))
5157
}
5258

5359
test("Message not acked returned if number of attempts exhausted") {
@@ -56,8 +62,7 @@ class PublisherConfirmsRabbitMQProducerTest extends TestBase {
5662
val seqNumber = 1L
5763

5864
val channel = mock[AutorecoveringChannel]
59-
val ref = Ref.of[Task, Map[Long, Deferred[Task, Either[NotAcknowledgedPublish, Unit]]]](Map.empty).await
60-
val updatedState = updateMessageState(ref, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber)))
65+
when(channel.getNextPublishSeqNo).thenReturn(seqNumber)
6166

6267
val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes](
6368
name = "test",
@@ -69,22 +74,76 @@ class PublisherConfirmsRabbitMQProducerTest extends TestBase {
6974
sizeLimitBytes = None,
7075
blocker = TestBase.testBlocker,
7176
logger = ImplicitContextLogger.createLogger,
72-
sentMessages = ref,
7377
sendAttempts = 1
7478
)
75-
when(channel.getNextPublishSeqNo).thenReturn(seqNumber)
79+
80+
val body = Bytes.copyFrom(Array.fill(499)(32.toByte))
81+
82+
val publishTask = producer.send(routingKey, body).runToFuture
7683

7784
assertThrows[NotAcknowledgedPublish] {
78-
producer.send(routingKey, Bytes.copyFrom(Array.fill(499)(32.toByte))).parProduct(updatedState).await
85+
updateMessageState(producer, seqNumber)(Left(NotAcknowledgedPublish("abcd", messageId = seqNumber))).await
86+
Await.result(publishTask, 10.seconds)
7987
}
8088

81-
verify(channel).basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(Bytes.copyFrom(Array.fill(499)(32.toByte)).toByteArray))
89+
verify(channel).basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(body.toByteArray))
90+
}
91+
92+
test("Multiple messages are fully acked") {
93+
val exchangeName = Random.nextString(10)
94+
val routingKey = Random.nextString(10)
95+
96+
val channel = mock[AutorecoveringChannel]
97+
98+
val seqNumbers = 1 to 500
99+
val iterator = seqNumbers.iterator
100+
when(channel.getNextPublishSeqNo).thenAnswer(_ => { iterator.next() })
101+
102+
val producer = new PublishConfirmsRabbitMQProducer[Task, Bytes](
103+
name = "test",
104+
exchangeName = exchangeName,
105+
channel = channel,
106+
monitor = Monitor.noOp(),
107+
defaultProperties = MessageProperties.empty,
108+
reportUnroutable = false,
109+
sizeLimitBytes = None,
110+
blocker = TestBase.testBlocker,
111+
logger = ImplicitContextLogger.createLogger,
112+
sendAttempts = 2
113+
)
114+
115+
val body = Bytes.copyFrom(Array.fill(499)(32.toByte))
116+
117+
val publishTasks = Task.parSequenceUnordered {
118+
seqNumbers.map { _ =>
119+
producer.send(routingKey, body)
120+
}
121+
}.runToFuture
122+
123+
Task
124+
.parSequenceUnordered(seqNumbers.map { seqNumber =>
125+
updateMessageState(producer, seqNumber)(Right())
126+
})
127+
.await(15.seconds)
128+
129+
Await.result(publishTasks, 15.seconds)
130+
131+
verify(channel, times(seqNumbers.length))
132+
.basicPublish(Matchers.eq(exchangeName), Matchers.eq(routingKey), any(), Matchers.eq(body.toByteArray))
82133
}
83134

84-
private def updateMessageState(ref: SentMessages[Task], messageId: Long)(result: Either[NotAcknowledgedPublish, Unit]): Task[Unit] = {
85-
ref.get.flatMap(map => map.get(messageId) match {
86-
case Some(value) => value.complete(result)
87-
case None => updateMessageState(ref, messageId)(result)
88-
})
135+
private def updateMessageState(producer: PublishConfirmsRabbitMQProducer[Task, Bytes], messageId: Long, attempt: Int = 1)(
136+
result: Either[NotAcknowledgedPublish, Unit]): Task[Unit] = {
137+
Task
138+
.delay(producer.confirmationCallbacks.get(messageId))
139+
.flatMap {
140+
case Some(value) => value.complete(result)
141+
case None =>
142+
if (attempt < 90) {
143+
Task.sleep(100.millis) >> updateMessageState(producer, messageId, attempt + 1)(result)
144+
} else {
145+
throw new InvalidOrderingException(s"The message ID $messageId is not present in the list of callbacks")
146+
}
147+
}
89148
}
90149
}

0 commit comments

Comments
 (0)