Skip to content

Commit d857a13

Browse files
wenxuwan文徐
andauthored
feat: Fix concurrent subscriptions (#1222)
* seperate interface and implement * fix panic when close tracedispatcher * Restore rlog/log.go * Delete default.go * fix consumer panic * change any to interface * Optimize UnCompress --------- Co-authored-by: 文徐 <[email protected]>
1 parent ab6584b commit d857a13

File tree

4 files changed

+150
-22
lines changed

4 files changed

+150
-22
lines changed

consumer/push_consumer.go

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ type pushConsumer struct {
6767
queueMaxSpanFlowControlTimes int
6868
consumeFunc utils.Set
6969
submitToConsume func(*processQueue, *primitive.MessageQueue)
70-
subscribedTopic map[string]string
70+
subscribedTopic sync.Map
7171
interceptor primitive.Interceptor
7272
queueLock *QueueLock
7373
done chan struct{}
7474
closeOnce sync.Once
75-
crCh map[string]chan struct{}
75+
crCh sync.Map
7676
}
7777

7878
func NewPushConsumer(opts ...Option) (*pushConsumer, error) {
@@ -113,11 +113,9 @@ func NewPushConsumer(opts ...Option) (*pushConsumer, error) {
113113

114114
p := &pushConsumer{
115115
defaultConsumer: dc,
116-
subscribedTopic: make(map[string]string, 0),
117116
queueLock: newQueueLock(),
118117
done: make(chan struct{}, 1),
119118
consumeFunc: utils.NewSet(),
120-
crCh: make(map[string]chan struct{}),
121119
}
122120
dc.mqChanged = p.messageQueueChanged
123121
if p.consumeOrderly {
@@ -165,7 +163,7 @@ func (pc *pushConsumer) Start() error {
165163
}
166164

167165
retryTopic := internal.GetRetryTopic(pc.consumerGroup)
168-
pc.crCh[retryTopic] = make(chan struct{}, pc.defaultConsumer.option.ConsumeGoroutineNums)
166+
pc.crCh.Store(retryTopic, make(chan struct{}, pc.defaultConsumer.option.ConsumeGoroutineNums))
169167

170168
go func() {
171169
// todo start clean msg expired
@@ -236,13 +234,20 @@ func (pc *pushConsumer) Start() error {
236234
}
237235

238236
pc.client.UpdateTopicRouteInfo()
239-
for k := range pc.subscribedTopic {
237+
pc.subscribedTopic.Range(func(k, v interface{}) bool {
240238
_, exist := pc.topicSubscribeInfoTable.Load(k)
241239
if !exist {
242240
pc.Shutdown()
243-
return fmt.Errorf("the topic=%s route info not found, it may not exist", k)
241+
err = fmt.Errorf("the topic=%s route info not found, it may not exist", k)
242+
return false
244243
}
244+
return true
245+
})
246+
247+
if err != nil {
248+
return err
245249
}
250+
246251
pc.client.CheckClientInBroker()
247252
pc.client.SendHeartbeatToAllBrokerWithLock()
248253
go pc.client.RebalanceImmediately()
@@ -298,12 +303,10 @@ func (pc *pushConsumer) Subscribe(topic string, selector MessageSelector,
298303
if pc.option.Namespace != "" {
299304
topic = pc.option.Namespace + "%" + topic
300305
}
301-
if _, ok := pc.crCh[topic]; !ok {
302-
pc.crCh[topic] = make(chan struct{}, pc.defaultConsumer.option.ConsumeGoroutineNums)
303-
}
306+
pc.crCh.LoadOrStore(topic, make(chan struct{}, pc.defaultConsumer.option.ConsumeGoroutineNums))
304307
data := buildSubscriptionData(topic, selector)
305308
pc.subscriptionDataTable.Store(topic, data)
306-
pc.subscribedTopic[topic] = ""
309+
pc.subscribedTopic.LoadOrStore(topic, "")
307310

308311
pc.consumeFunc.Add(&PushConsumerCallback{
309312
f: f,
@@ -550,8 +553,12 @@ func (pc *pushConsumer) validate() error {
550553
// TODO FQA
551554
return fmt.Errorf("consumerGroup can't equal [%s], please specify another one", internal.DefaultConsumerGroup)
552555
}
553-
554-
if len(pc.subscribedTopic) == 0 {
556+
noSubscribedTopic := true
557+
pc.subscribedTopic.Range(func(key, value interface{}) bool {
558+
noSubscribedTopic = false
559+
return false
560+
})
561+
if noSubscribedTopic {
555562
rlog.Warning("not subscribe any topic yet", map[string]interface{}{
556563
rlog.LogKeyConsumerGroup: pc.consumerGroup,
557564
})
@@ -1089,9 +1096,7 @@ func (pc *pushConsumer) consumeMessageConcurrently(pq *processQueue, mq *primiti
10891096

10901097
limiter := pc.option.Limiter
10911098
limiterOn := limiter != nil
1092-
if _, ok := pc.crCh[mq.Topic]; !ok {
1093-
pc.crCh[mq.Topic] = make(chan struct{}, pc.defaultConsumer.option.ConsumeGoroutineNums)
1094-
}
1099+
pc.crCh.LoadOrStore(mq.Topic, make(chan struct{}, pc.defaultConsumer.option.ConsumeGoroutineNums))
10951100

10961101
for count := 0; count < len(msgs); count++ {
10971102
var subMsgs []*primitive.MessageExt
@@ -1107,8 +1112,10 @@ func (pc *pushConsumer) consumeMessageConcurrently(pq *processQueue, mq *primiti
11071112
if limiterOn {
11081113
limiter(utils.WithoutNamespace(mq.Topic))
11091114
}
1110-
pc.crCh[mq.Topic] <- struct{}{}
1111-
1115+
ch, _ := pc.crCh.Load(mq.Topic)
1116+
if channel, ok := ch.(chan struct{}); ok {
1117+
channel <- struct{}{}
1118+
}
11121119
go primitive.WithRecover(func() {
11131120
defer func() {
11141121
if err := recover(); err != nil {
@@ -1121,7 +1128,10 @@ func (pc *pushConsumer) consumeMessageConcurrently(pq *processQueue, mq *primiti
11211128
rlog.LogKeyConsumerGroup: pc.consumerGroup,
11221129
})
11231130
}
1124-
<-pc.crCh[mq.Topic]
1131+
ch, _ := pc.crCh.Load(mq.Topic)
1132+
if channel, ok := ch.(chan struct{}); ok {
1133+
<-channel
1134+
}
11251135
}()
11261136
RETRY:
11271137
if pq.IsDroppd() {

internal/utils/compression.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import (
2121
"bytes"
2222
"compress/zlib"
2323
"github.com/apache/rocketmq-client-go/v2/errors"
24-
"io/ioutil"
24+
"io"
2525
"sync"
2626
)
2727

@@ -79,9 +79,12 @@ func UnCompress(data []byte) []byte {
7979
return data
8080
}
8181
defer r.Close()
82-
retData, err := ioutil.ReadAll(r)
82+
83+
// Use a buffer with reasonable initial size to avoid frequent reallocations
84+
buf := bytes.NewBuffer(make([]byte, 0, len(data)*2))
85+
_, err = io.Copy(buf, r)
8386
if err != nil {
8487
return data
8588
}
86-
return retData
89+
return buf.Bytes()
8790
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package utils
2+
3+
import (
4+
"bytes"
5+
"compress/zlib"
6+
"io/ioutil"
7+
"math/rand"
8+
"strconv"
9+
"testing"
10+
)
11+
12+
func generateTestData(size int) []byte {
13+
data := make([]byte, size)
14+
rand.Read(data)
15+
return data
16+
}
17+
18+
func compressTestData(data []byte) []byte {
19+
var buf bytes.Buffer
20+
writer, _ := zlib.NewWriterLevel(&buf, zlib.BestCompression)
21+
writer.Write(data)
22+
writer.Close()
23+
return buf.Bytes()
24+
}
25+
26+
func UnCompressOriginal(data []byte) []byte {
27+
rdata := bytes.NewReader(data)
28+
r, err := zlib.NewReader(rdata)
29+
if err != nil {
30+
return data
31+
}
32+
defer r.Close()
33+
retData, err := ioutil.ReadAll(r)
34+
if err != nil {
35+
return data
36+
}
37+
return retData
38+
}
39+
40+
var testDataSizes = []int{1024, 64 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024, 4 * 1024 * 1024}
41+
42+
func BenchmarkUnCompress(b *testing.B) {
43+
for _, size := range testDataSizes {
44+
data := generateTestData(size)
45+
compressed := compressTestData(data)
46+
47+
b.Run("New_"+formatSize(size), func(b *testing.B) {
48+
b.ResetTimer()
49+
b.ReportAllocs()
50+
for i := 0; i < b.N; i++ {
51+
result := UnCompress(compressed)
52+
_ = result
53+
}
54+
})
55+
56+
b.Run("Original_"+formatSize(size), func(b *testing.B) {
57+
b.ResetTimer()
58+
b.ReportAllocs()
59+
for i := 0; i < b.N; i++ {
60+
result := UnCompressOriginal(compressed)
61+
_ = result
62+
}
63+
})
64+
}
65+
}
66+
67+
func BenchmarkMemoryUsage(b *testing.B) {
68+
// 测试大内存使用情况
69+
largeData := generateTestData(4 * 1024 * 1024) // 4MB
70+
compressed := compressTestData(largeData)
71+
72+
b.Run("New_Memory", func(b *testing.B) {
73+
b.ResetTimer()
74+
b.ReportAllocs()
75+
for i := 0; i < b.N; i++ {
76+
result := UnCompress(compressed)
77+
_ = result
78+
}
79+
})
80+
81+
b.Run("Original_Memory", func(b *testing.B) {
82+
b.ResetTimer()
83+
b.ReportAllocs()
84+
for i := 0; i < b.N; i++ {
85+
result := UnCompressOriginal(compressed)
86+
_ = result
87+
}
88+
})
89+
}
90+
91+
func formatSize(bytes int) string {
92+
if bytes < 1024 {
93+
return strconv.Itoa(bytes) + "B"
94+
} else if bytes < 1024*1024 {
95+
return strconv.Itoa(bytes/1024) + "KB"
96+
} else if bytes < 1024*1024*1024 {
97+
return strconv.Itoa(bytes/(1024*1024)) + "MB"
98+
} else {
99+
return strconv.Itoa(bytes/(1024*1024*1024)) + "GB"
100+
}
101+
}

internal/utils/set.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"bytes"
2222
"encoding/json"
2323
"sort"
24+
"sync"
2425
)
2526

2627
type UniqueItem interface {
@@ -34,6 +35,7 @@ func (str StringUnique) UniqueID() string {
3435
}
3536

3637
type Set struct {
38+
mux sync.RWMutex
3739
items map[string]UniqueItem
3840
}
3941

@@ -44,29 +46,41 @@ func NewSet() Set {
4446
}
4547

4648
func (s *Set) Items() map[string]UniqueItem {
49+
s.mux.RLock()
50+
defer s.mux.RUnlock()
4751
return s.items
4852
}
4953

5054
func (s *Set) Add(v UniqueItem) {
55+
s.mux.Lock()
56+
defer s.mux.Unlock()
5157
s.items[v.UniqueID()] = v
5258
}
5359

5460
func (s *Set) AddKV(k, v string) {
61+
s.mux.Lock()
62+
defer s.mux.Unlock()
5563
s.items[k] = StringUnique(v)
5664
}
5765

5866
func (s *Set) Contains(k string) (UniqueItem, bool) {
67+
s.mux.RLock()
68+
defer s.mux.RUnlock()
5969
v, ok := s.items[k]
6070
return v, ok
6171
}
6272

6373
func (s *Set) Len() int {
74+
s.mux.RLock()
75+
defer s.mux.RUnlock()
6476
return len(s.items)
6577
}
6678

6779
var _ json.Marshaler = &Set{}
6880

6981
func (s *Set) MarshalJSON() ([]byte, error) {
82+
s.mux.RLock()
83+
defer s.mux.RUnlock()
7084
if len(s.items) == 0 {
7185
return []byte("[]"), nil
7286
}

0 commit comments

Comments
 (0)