Skip to content

Commit ebfc4f1

Browse files
authored
GH-1704: Interceptor Improvements
Resolves #1704 - Extract `ThreadStateProcessor` - Avoid all the if/else tests around calling the interceptor common methods - Add `afterRecord` to the record interceptor so sleuth can defer cleaning up until after the error handler. * Remove ThreadStateProcessor (TSP), rename CATSP to TSP. - confusing for interceptor implementors - no components ever call `setupThreadState()`; ARP and CEH build state during normal processing.
1 parent 0e0f615 commit ebfc4f1

File tree

7 files changed

+80
-73
lines changed

7 files changed

+80
-73
lines changed

spring-kafka/src/main/java/org/springframework/kafka/listener/BatchInterceptor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
*
3333
*/
3434
@FunctionalInterface
35-
public interface BatchInterceptor<K, V> extends BeforeAfterPollProcessor<K, V> {
35+
public interface BatchInterceptor<K, V> extends ThreadStateProcessor {
3636

3737
/**
3838
* Perform some action on the records or return a different one. If null is returned

spring-kafka/src/main/java/org/springframework/kafka/listener/CompositeBatchInterceptor.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ public void failure(ConsumerRecords<K, V> records, Exception exception, Consumer
7575
}
7676

7777
@Override
78-
public void beforePoll(Consumer<K, V> consumer) {
79-
this.delegates.forEach(del -> del.beforePoll(consumer));
78+
public void setupThreadState(Consumer<?, ?> consumer) {
79+
this.delegates.forEach(del -> del.setupThreadState(consumer));
8080
}
8181

8282
@Override
83-
public void clearThreadState(Consumer<K, V> consumer) {
83+
public void clearThreadState(Consumer<?, ?> consumer) {
8484
this.delegates.forEach(del -> del.clearThreadState(consumer));
8585
}
8686

spring-kafka/src/main/java/org/springframework/kafka/listener/CompositeRecordInterceptor.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,13 @@ public void failure(ConsumerRecord<K, V> record, Exception exception, Consumer<K
7878
}
7979

8080
@Override
81-
public void beforePoll(Consumer<K, V> consumer) {
82-
this.delegates.forEach(del -> del.beforePoll(consumer));
81+
public void setupThreadState(Consumer<?, ?> consumer) {
82+
this.delegates.forEach(del -> del.setupThreadState(consumer));
8383
}
8484

8585
@Override
86-
public void clearThreadState(Consumer<K, V> consumer) {
86+
public void clearThreadState(Consumer<?, ?> consumer) {
8787
this.delegates.forEach(del -> del.clearThreadState(consumer));
8888
}
89+
8990
}

spring-kafka/src/main/java/org/springframework/kafka/listener/KafkaMessageListenerContainer.java

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,8 @@ private final class ListenerConsumer implements SchedulingAwareRunnable, Consume
632632

633633
private final BatchInterceptor<K, V> commonBatchInterceptor = getBatchInterceptor();
634634

635+
private final ThreadStateProcessor pollThreadStateProcessor;
636+
635637
private final ConsumerSeekCallback seekCallback = new InitialOrIdleSeekCallback();
636638

637639
private final long maxPollInterval;
@@ -746,12 +748,14 @@ private final class ListenerConsumer implements SchedulingAwareRunnable, Consume
746748
this.batchListener = (BatchMessageListener<K, V>) listener;
747749
this.isBatchListener = true;
748750
this.wantsFullRecords = this.batchListener.wantsPollResult();
751+
this.pollThreadStateProcessor = setUpPollProcessor(true);
749752
}
750753
else if (listener instanceof MessageListener) {
751754
this.listener = (MessageListener<K, V>) listener;
752755
this.batchListener = null;
753756
this.isBatchListener = false;
754757
this.wantsFullRecords = false;
758+
this.pollThreadStateProcessor = setUpPollProcessor(false);
755759
}
756760
else {
757761
throw new IllegalArgumentException("Listener must be one of 'MessageListener', "
@@ -802,6 +806,19 @@ else if (listener instanceof MessageListener) {
802806
this.pausedPartitions = new HashSet<>();
803807
}
804808

809+
@Nullable
810+
private ThreadStateProcessor setUpPollProcessor(boolean batch) {
811+
if (batch) {
812+
if (this.commonBatchInterceptor != null) {
813+
return this.commonBatchInterceptor;
814+
}
815+
}
816+
else if (this.commonRecordInterceptor != null) {
817+
return this.commonRecordInterceptor;
818+
}
819+
return null;
820+
}
821+
805822
@Nullable
806823
private CommonErrorHandler determineCommonErrorHandler(@Nullable GenericErrorHandler<?> errHandler) {
807824
CommonErrorHandler common = getCommonErrorHandler();
@@ -1314,22 +1331,8 @@ private void invokeIfHaveRecords(@Nullable ConsumerRecords<K, V> records) {
13141331
}
13151332

13161333
private void clearThreadState() {
1317-
if (this.isBatchListener) {
1318-
interceptClearThreadState(this.commonBatchInterceptor);
1319-
}
1320-
else {
1321-
interceptClearThreadState(this.commonRecordInterceptor);
1322-
}
1323-
}
1324-
1325-
private void interceptClearThreadState(BeforeAfterPollProcessor<K, V> processor) {
1326-
if (processor != null) {
1327-
try {
1328-
processor.clearThreadState(this.consumer);
1329-
}
1330-
catch (Exception e) {
1331-
this.logger.error(e, "BeforeAfterPollProcessor.clearThreadState threw an exception");
1332-
}
1334+
if (this.pollThreadStateProcessor != null) {
1335+
this.pollThreadStateProcessor.clearThreadState(this.consumer);
13331336
}
13341337
}
13351338

@@ -1480,22 +1483,8 @@ private ConsumerRecords<K, V> pollConsumer() {
14801483
}
14811484

14821485
private void beforePoll() {
1483-
if (this.isBatchListener) {
1484-
interceptBeforePoll(this.commonBatchInterceptor);
1485-
}
1486-
else {
1487-
interceptBeforePoll(this.commonRecordInterceptor);
1488-
}
1489-
}
1490-
1491-
private void interceptBeforePoll(BeforeAfterPollProcessor<K, V> processor) {
1492-
if (processor != null) {
1493-
try {
1494-
processor.beforePoll(this.consumer);
1495-
}
1496-
catch (Exception e) {
1497-
this.logger.error(e, "BeforeAfterPollProcessor.beforePoll threw an exception");
1498-
}
1486+
if (this.pollThreadStateProcessor != null) {
1487+
this.pollThreadStateProcessor.setupThreadState(this.consumer);
14991488
}
15001489
}
15011490

@@ -2294,6 +2283,9 @@ private void invokeRecordListenerInTx(final ConsumerRecords<K, V> records) {
22942283
TransactionSupport.clearTransactionIdSuffix();
22952284
}
22962285
}
2286+
if (this.commonRecordInterceptor != null) {
2287+
this.commonRecordInterceptor.afterRecord(record, this.consumer);
2288+
}
22972289
if (this.nackSleep >= 0) {
22982290
handleNack(records, record);
22992291
break;
@@ -2374,6 +2366,9 @@ private void doInvokeWithRecords(final ConsumerRecords<K, V> records) {
23742366
}
23752367
this.logger.trace(() -> "Processing " + ListenerUtils.recordToString(record));
23762368
doInvokeRecordListener(record, iterator);
2369+
if (this.commonRecordInterceptor != null) {
2370+
this.commonRecordInterceptor.afterRecord(record, this.consumer);
2371+
}
23772372
if (this.nackSleep >= 0) {
23782373
handleNack(records, record);
23792374
break;

spring-kafka/src/main/java/org/springframework/kafka/listener/RecordInterceptor.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
*
3434
*/
3535
@FunctionalInterface
36-
public interface RecordInterceptor<K, V> extends BeforeAfterPollProcessor<K, V> {
36+
public interface RecordInterceptor<K, V> extends ThreadStateProcessor {
3737

3838
/**
3939
* Perform some action on the record or return a different one. If null is returned
@@ -81,4 +81,15 @@ default void success(ConsumerRecord<K, V> record, Consumer<K, V> consumer) {
8181
default void failure(ConsumerRecord<K, V> record, Exception exception, Consumer<K, V> consumer) {
8282
}
8383

84+
/**
85+
* Called when processing the record is complete either
86+
* {@link #success(ConsumerRecord, Consumer)} or
87+
* {@link #failure(ConsumerRecord, Exception, Consumer)}.
88+
* @param record the record.
89+
* @param consumer the consumer.
90+
* @since 2.8
91+
*/
92+
default void afterRecord(ConsumerRecord<K, V> record, Consumer<K, V> consumer) {
93+
}
94+
8495
}

spring-kafka/src/main/java/org/springframework/kafka/listener/BeforeAfterPollProcessor.java renamed to spring-kafka/src/main/java/org/springframework/kafka/listener/ThreadStateProcessor.java

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,42 +19,32 @@
1919
import org.apache.kafka.clients.consumer.Consumer;
2020

2121
/**
22-
* An interceptor for consumer poll operation.
22+
* A general interface for managing thread-bound resources when a {@link Consumer} is
23+
* available.
2324
*
24-
* @param <K> the key type.
25-
* @param <V> the value type.
26-
*
27-
* @author Gary Russell
2825
* @author Karol Dowbecki
29-
* @author Artem Bilan
26+
* @author Gary Russell
3027
* @since 2.8
3128
*
3229
*/
33-
public interface BeforeAfterPollProcessor<K, V> {
30+
public interface ThreadStateProcessor {
3431

3532
/**
36-
* Called before consumer is polled.
37-
* <p>
38-
* It can be used to set up thread-bound resources which will be available for the
39-
* entire duration of the consumer poll operation e.g. logging with MDC.
40-
* </p>
33+
* Call to set up thread-bound resources which will be available for the
34+
* entire duration of enclosed operation involving a {@link Consumer}.
4135
*
4236
* @param consumer the consumer.
4337
*/
44-
default void beforePoll(Consumer<K, V> consumer) {
38+
default void setupThreadState(Consumer<?, ?> consumer) {
4539
}
4640

4741
/**
48-
* Called after records were processed by listener and error handler.
49-
* <p>
50-
* It can be used to clear thread-bound resources which were set up in {@link #beforePoll(Consumer)}.
51-
* This is the last method called by the {@link MessageListenerContainer} before the next record
52-
* processing cycle starts.
53-
* </p>
42+
* Call to clear thread-bound resources which were set up in
43+
* {@link #setupThreadState(Consumer)}.
5444
*
5545
* @param consumer the consumer.
5646
*/
57-
default void clearThreadState(Consumer<K, V> consumer) {
47+
default void clearThreadState(Consumer<?, ?> consumer) {
5848
}
5949

6050
}

spring-kafka/src/test/java/org/springframework/kafka/listener/KafkaMessageListenerContainerTests.java

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3498,7 +3498,7 @@ void stopImmediately() throws InterruptedException {
34983498
}
34993499

35003500
@Test
3501-
@SuppressWarnings({"unchecked", "deprecated"})
3501+
@SuppressWarnings({"unchecked", "deprecation"})
35023502
public void testInvokeRecordInterceptorSuccess() throws Exception {
35033503
ConsumerFactory<Integer, String> cf = mock(ConsumerFactory.class);
35043504
Consumer<Integer, String> consumer = mock(Consumer.class);
@@ -3538,15 +3538,17 @@ public void onMessage(ConsumerRecord<Integer, String> data) {
35383538

35393539
CountDownLatch afterLatch = new CountDownLatch(1);
35403540
RecordInterceptor<Integer, String> recordInterceptor = spy(new RecordInterceptor<Integer, String>() {
3541+
35413542
@Override
35423543
public ConsumerRecord<Integer, String> intercept(ConsumerRecord<Integer, String> record) {
35433544
return record;
35443545
}
35453546

35463547
@Override
3547-
public void clearThreadState(Consumer<Integer, String> consumer) {
3548+
public void clearThreadState(Consumer<?, ?> consumer) {
35483549
afterLatch.countDown();
35493550
}
3551+
35503552
});
35513553

35523554
KafkaMessageListenerContainer<Integer, String> container =
@@ -3557,20 +3559,22 @@ public void clearThreadState(Consumer<Integer, String> consumer) {
35573559
assertThat(afterLatch.await(10, TimeUnit.SECONDS)).isTrue();
35583560

35593561
InOrder inOrder = inOrder(recordInterceptor, messageListener, consumer);
3560-
inOrder.verify(recordInterceptor).beforePoll(eq(consumer));
3562+
inOrder.verify(recordInterceptor).setupThreadState(eq(consumer));
35613563
inOrder.verify(consumer).poll(Duration.ofMillis(ContainerProperties.DEFAULT_POLL_TIMEOUT));
35623564
inOrder.verify(recordInterceptor).intercept(eq(firstRecord), eq(consumer));
35633565
inOrder.verify(messageListener).onMessage(eq(firstRecord));
35643566
inOrder.verify(recordInterceptor).success(eq(firstRecord), eq(consumer));
3567+
inOrder.verify(recordInterceptor).afterRecord(eq(firstRecord), eq(consumer));
35653568
inOrder.verify(recordInterceptor).intercept(eq(secondRecord), eq(consumer));
35663569
inOrder.verify(messageListener).onMessage(eq(secondRecord));
35673570
inOrder.verify(recordInterceptor).success(eq(secondRecord), eq(consumer));
3571+
inOrder.verify(recordInterceptor).afterRecord(eq(secondRecord), eq(consumer));
35683572
inOrder.verify(recordInterceptor).clearThreadState(eq(consumer));
35693573
container.stop();
35703574
}
35713575

35723576
@Test
3573-
@SuppressWarnings({"unchecked", "deprecated"})
3577+
@SuppressWarnings({"unchecked", "deprecation"})
35743578
public void testInvokeRecordInterceptorFailure() throws Exception {
35753579
ConsumerFactory<Integer, String> cf = mock(ConsumerFactory.class);
35763580
Consumer<Integer, String> consumer = mock(Consumer.class);
@@ -3608,15 +3612,17 @@ public void onMessage(ConsumerRecord<Integer, String> data) {
36083612

36093613
CountDownLatch afterLatch = new CountDownLatch(1);
36103614
RecordInterceptor<Integer, String> recordInterceptor = spy(new RecordInterceptor<Integer, String>() {
3615+
36113616
@Override
36123617
public ConsumerRecord<Integer, String> intercept(ConsumerRecord<Integer, String> record) {
36133618
return record;
36143619
}
36153620

36163621
@Override
3617-
public void clearThreadState(Consumer<Integer, String> consumer) {
3622+
public void clearThreadState(Consumer<?, ?> consumer) {
36183623
afterLatch.countDown();
36193624
}
3625+
36203626
});
36213627

36223628
KafkaMessageListenerContainer<Integer, String> container =
@@ -3627,11 +3633,12 @@ public void clearThreadState(Consumer<Integer, String> consumer) {
36273633
assertThat(afterLatch.await(10, TimeUnit.SECONDS)).isTrue();
36283634

36293635
InOrder inOrder = inOrder(recordInterceptor, messageListener, consumer);
3630-
inOrder.verify(recordInterceptor).beforePoll(eq(consumer));
3636+
inOrder.verify(recordInterceptor).setupThreadState(eq(consumer));
36313637
inOrder.verify(consumer).poll(Duration.ofMillis(ContainerProperties.DEFAULT_POLL_TIMEOUT));
36323638
inOrder.verify(recordInterceptor).intercept(eq(record), eq(consumer));
36333639
inOrder.verify(messageListener).onMessage(eq(record));
36343640
inOrder.verify(recordInterceptor).failure(eq(record), any(), eq(consumer));
3641+
inOrder.verify(recordInterceptor).afterRecord(eq(record), eq(consumer));
36353642
inOrder.verify(recordInterceptor).clearThreadState(eq(consumer));
36363643
container.stop();
36373644
}
@@ -3677,14 +3684,17 @@ public void onMessage(List<ConsumerRecord<Integer, String>> data) {
36773684
BatchInterceptor<Integer, String> batchInterceptor = spy(new BatchInterceptor<Integer, String>() {
36783685

36793686
@Override
3680-
public ConsumerRecords<Integer, String> intercept(ConsumerRecords<Integer, String> records, Consumer<Integer, String> consumer) {
3687+
public ConsumerRecords<Integer, String> intercept(ConsumerRecords<Integer, String> records,
3688+
Consumer<Integer, String> consumer) {
3689+
36813690
return records;
36823691
}
36833692

36843693
@Override
3685-
public void clearThreadState(Consumer<Integer, String> consumer) {
3694+
public void clearThreadState(Consumer<?, ?> consumer) {
36863695
afterLatch.countDown();
36873696
}
3697+
36883698
});
36893699

36903700
KafkaMessageListenerContainer<Integer, String> container =
@@ -3695,7 +3705,7 @@ public void clearThreadState(Consumer<Integer, String> consumer) {
36953705
assertThat(afterLatch.await(10, TimeUnit.SECONDS)).isTrue();
36963706

36973707
InOrder inOrder = inOrder(batchInterceptor, batchMessageListener, consumer);
3698-
inOrder.verify(batchInterceptor).beforePoll(eq(consumer));
3708+
inOrder.verify(batchInterceptor).setupThreadState(eq(consumer));
36993709
inOrder.verify(consumer).poll(Duration.ofMillis(ContainerProperties.DEFAULT_POLL_TIMEOUT));
37003710
inOrder.verify(batchInterceptor).intercept(eq(consumerRecords), eq(consumer));
37013711
inOrder.verify(batchMessageListener).onMessage(eq(List.of(firstRecord, secondRecord)));
@@ -3743,8 +3753,7 @@ public void onMessage(List<ConsumerRecord<Integer, String>> data) {
37433753
containerProps.setClientId("clientId");
37443754

37453755
CountDownLatch afterLatch = new CountDownLatch(1);
3746-
BatchInterceptor<Integer, String> batchInterceptor = spy(
3747-
new BatchInterceptor<Integer, String>() {
3756+
BatchInterceptor<Integer, String> batchInterceptor = spy(new BatchInterceptor<Integer, String>() {
37483757

37493758
@Override
37503759
public ConsumerRecords<Integer, String> intercept(ConsumerRecords<Integer, String> records,
@@ -3753,9 +3762,10 @@ public ConsumerRecords<Integer, String> intercept(ConsumerRecords<Integer, Strin
37533762
}
37543763

37553764
@Override
3756-
public void clearThreadState(Consumer<Integer, String> consumer) {
3765+
public void clearThreadState(Consumer<?, ?> consumer) {
37573766
afterLatch.countDown();
37583767
}
3768+
37593769
});
37603770

37613771
KafkaMessageListenerContainer<Integer, String> container =
@@ -3766,7 +3776,7 @@ public void clearThreadState(Consumer<Integer, String> consumer) {
37663776
assertThat(afterLatch.await(10, TimeUnit.SECONDS)).isTrue();
37673777

37683778
InOrder inOrder = inOrder(batchInterceptor, batchMessageListener, consumer);
3769-
inOrder.verify(batchInterceptor).beforePoll(eq(consumer));
3779+
inOrder.verify(batchInterceptor).setupThreadState(eq(consumer));
37703780
inOrder.verify(consumer).poll(Duration.ofMillis(ContainerProperties.DEFAULT_POLL_TIMEOUT));
37713781
inOrder.verify(batchInterceptor).intercept(eq(consumerRecords), eq(consumer));
37723782
inOrder.verify(batchMessageListener).onMessage(eq(List.of(firstRecord, secondRecord)));

0 commit comments

Comments
 (0)