Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions core/capabilities/remote/messagecache/message_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ func (c *MessageCache[EventID, PeerID]) Delete(eventID EventID) {
delete(c.events, eventID)
}

// Peers returns a snapshot of peer IDs that have inserted a message for eventID.
// Returns nil when the event is unknown. Not safe for concurrent use; the
// caller is expected to hold any lock guarding this cache.
func (c *MessageCache[EventID, PeerID]) Peers(eventID EventID) map[PeerID]bool {
	event, found := c.events[eventID]
	if !found {
		return nil
	}

	// Copy the key set so callers can't observe later mutations of the cache.
	snapshot := make(map[PeerID]bool, len(event.peerMsgs))
	for id := range event.peerMsgs {
		snapshot[id] = true
	}
	return snapshot
}

// Return the number of events deleted.
// Scans all keys, which might be slow for large caches.
func (c *MessageCache[EventID, PeerID]) DeleteOlderThan(cutoffTimestamp int64) int {
Expand Down
91 changes: 70 additions & 21 deletions core/capabilities/remote/trigger_publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (p *triggerPublisher) Start(ctx context.Context) error {
}

p.wg.Add(1)
go p.registrationCleanupLoop()
go p.cacheCleanupLoop()
p.wg.Add(1)
go p.batchingLoop()
p.lggr.Info("TriggerPublisher started")
Expand Down Expand Up @@ -287,7 +287,7 @@ func (p *triggerPublisher) Receive(_ context.Context, msg *types.MessageBody) {
nowMs := time.Now().UnixMilli()
p.ackCache.Insert(key, sender, nowMs, msg.Payload)
minRequired := uint32(2*callerDon.F + 1)
ready, _ := p.ackCache.Ready(key, minRequired, nowMs-cfg.remoteConfig.RegistrationExpiry.Milliseconds(), false)
ready, _ := p.ackCache.Ready(key, minRequired, nowMs-cfg.remoteConfig.MessageExpiry.Milliseconds(), false)
if !ready {
p.lggr.Debugw("not ready to ACK trigger event yet", "triggerEventId", triggerEventID, "minRequired", minRequired)
return
Expand All @@ -298,22 +298,22 @@ func (p *triggerPublisher) Receive(_ context.Context, msg *types.MessageBody) {
p.lggr.Debugw("ACKing trigger event", "triggerEventId", triggerEventID)
err = cfg.underlying.AckEvent(ctx, p.capabilityID, triggerEventID, p.capMethodName)
if err != nil {
p.lggr.Errorf("failed to AckEvent on underlying trigger capability (eventID = %s, capabilityID: %s): %v",
triggerEventID, p.capabilityID, err)
p.lggr.Errorw("failed to AckEvent on underlying trigger capability",
"eventID", triggerEventID, "capabilityID", p.capabilityID, "err", err)
}
default:
p.lggr.Errorw("received message with unknown method",
"method", SanitizeLogString(msg.Method), "sender", sender)
}
}

func (p *triggerPublisher) registrationCleanupLoop() {
func (p *triggerPublisher) cacheCleanupLoop() {
defer p.wg.Done()

// Get initial config for ticker setup
firstCfg := p.cfg.Load()
if firstCfg == nil {
p.lggr.Errorw("registrationCleanupLoop started but config not set")
p.lggr.Errorw("cacheCleanupLoop started but config not set")
return
}
cleanupInterval := firstCfg.remoteConfig.MessageExpiry
Expand Down Expand Up @@ -348,7 +348,13 @@ func (p *triggerPublisher) registrationCleanupLoop() {
p.messageCache.Delete(key)
}
}

deleted := p.ackCache.DeleteOlderThan(now - cfg.remoteConfig.MessageExpiry.Milliseconds())
p.mu.Unlock()

if deleted > 0 {
p.lggr.Debugw("cleaned expired AckCache entries", "deleted", deleted)
}
}
}
}
Expand Down Expand Up @@ -434,23 +440,66 @@ func (p *triggerPublisher) sendBatch(resp *batchedResponse) {
resp.workflowIDs = nil
resp.triggerIDs = nil
}
msg := &types.MessageBody{
CapabilityId: p.capabilityID,
CapabilityDonId: cfg.capDonInfo.ID,
CallerDonId: resp.callerDonID,
Method: types.MethodTriggerEvent,
Payload: resp.rawResponse,
Metadata: &types.MessageBody_TriggerEventMetadata{
TriggerEventMetadata: &types.TriggerEventMetadata{
WorkflowIds: workflowBatch,
TriggerIds: triggerBatch,
TriggerEventId: resp.triggerEventID,
},
},
CapabilityMethod: p.capMethodName,

ackSnapshot := make(map[string]map[p2ptypes.PeerID]bool)
p.mu.RLock()
for _, triggerID := range triggerBatch {
key := ackKey{
callerDonID: resp.callerDonID,
triggerEventID: resp.triggerEventID,
triggerID: triggerID,
}
ackSnapshot[triggerID] = p.ackCache.Peers(key)
}
// NOTE: send to all nodes by default, introduce different strategies later (KS-76)
p.mu.RUnlock()

for _, peerID := range cfg.workflowDONs[resp.callerDonID].Members {
var missingTriggerIDs []string
var missingWorkflowIDs []string

// determine which triggerIDs / workflowIDs have not yet ACKd this trigger event
for i, triggerID := range triggerBatch {
peers := ackSnapshot[triggerID]
if peers == nil || !peers[peerID] {
missingTriggerIDs = append(missingTriggerIDs, triggerID)
missingWorkflowIDs = append(missingWorkflowIDs, workflowBatch[i])
}
}

if len(missingTriggerIDs) == 0 {
p.lggr.Debugw("skipping trigger event send; all triggerIDs already ACKed by peer",
"peerID", peerID,
"callerDonID", resp.callerDonID,
"triggerEventID", resp.triggerEventID,
"triggerIDs", triggerBatch,
)
continue
}

p.lggr.Debugw("sending trigger event to peer",
"peerID", peerID,
"callerDonID", resp.callerDonID,
"triggerEventID", resp.triggerEventID,
"workflowIDs", missingWorkflowIDs,
"triggerIDs", missingTriggerIDs,
)

msg := &types.MessageBody{
CapabilityId: p.capabilityID,
CapabilityDonId: cfg.capDonInfo.ID,
CallerDonId: resp.callerDonID,
Method: types.MethodTriggerEvent,
Payload: resp.rawResponse,
CapabilityMethod: p.capMethodName,
Metadata: &types.MessageBody_TriggerEventMetadata{
TriggerEventMetadata: &types.TriggerEventMetadata{
WorkflowIds: missingWorkflowIDs,
TriggerIds: missingTriggerIDs,
TriggerEventId: resp.triggerEventID,
},
},
}

err := p.dispatcher.Send(peerID, msg)
if err != nil {
p.lggr.Errorw("failed to send trigger event", "peerID", peerID, "err", err)
Expand Down
166 changes: 166 additions & 0 deletions core/capabilities/remote/trigger_publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,172 @@ func TestTriggerPublisher_MultipleTriggersSameWorkflow(t *testing.T) {
require.NoError(t, publisher.Close())
}

// TestTriggerPublisher_ResendBehavior_MultiTriggerBatch verifies that the
// publisher's per-peer ACK tracking trims already-acknowledged triggerIDs out
// of resent batched trigger-event messages:
//   - no ACKs      -> both triggerIDs sent to both peers
//   - partial ACK  -> only the un-ACKed triggerIDs are resent to that peer
//   - full ACK     -> no message is sent at all
func TestTriggerPublisher_ResendBehavior_MultiTriggerBatch(t *testing.T) {
	ctx := testutils.Context(t)
	lggr := logger.Test(t)

	capabilityDONID := uint32(1)
	workflowDONID := uint32(2)

	peers := make([]p2ptypes.PeerID, 2)
	require.NoError(t, peers[0].UnmarshalText([]byte(peerID1)))
	require.NoError(t, peers[1].UnmarshalText([]byte(peerID2)))

	// Capability DON has a single member; F=0 so a single response aggregates.
	capDonInfo := commoncap.DON{
		ID:      capabilityDONID,
		Members: []p2ptypes.PeerID{peers[0]},
		F:       0,
	}

	// Workflow DON has both peers — each should receive trigger events.
	workflowDonInfo := commoncap.DON{
		ID:      workflowDONID,
		Members: []p2ptypes.PeerID{peers[0], peers[1]},
		F:       0,
	}

	workflowDONs := map[uint32]commoncap.DON{
		workflowDonInfo.ID: workflowDonInfo,
	}

	underlying := newMultiTrigger(commoncap.CapabilityInfo{ID: capID})
	dispatcher := mocks.NewDispatcher(t)

	// Long RegistrationExpiry/MessageExpiry keep registrations and cached ACKs
	// alive for the whole test; MaxBatchSize=2 lets both triggers share one
	// batch; a short BatchCollectionPeriod keeps the test fast.
	config := &commoncap.RemoteTriggerConfig{
		RegistrationRefresh:     100 * time.Millisecond,
		RegistrationExpiry:      100 * time.Second,
		MinResponsesToAggregate: 1,
		MessageExpiry:           100 * time.Second,
		MaxBatchSize:            2,
		BatchCollectionPeriod:   10 * time.Millisecond,
	}

	publisher := remote.NewTriggerPublisher(capID, "", dispatcher, lggr)
	require.NoError(t, publisher.SetConfig(config, underlying, capDonInfo, workflowDONs))
	require.NoError(t, publisher.Start(ctx))
	defer func() {
		require.NoError(t, publisher.Close())
	}()

	// Register two triggers; wait for each registration to reach the
	// underlying capability before proceeding.
	for _, trig := range []string{"triggerA", "triggerB"} {
		reg := newRegisterTriggerMessageWithTriggerID(t, workflowDONID, peers[0], trig)
		publisher.Receive(ctx, reg)
		<-underlying.registrationsCh
	}

	// sendRecords captures every dispatcher.Send call (guarded by mu, since
	// the publisher sends from its own goroutine).
	var mu sync.Mutex
	sendRecords := make([]struct {
		peer       p2ptypes.PeerID
		triggerIDs []string
	}, 0)

	// sendCh signals each Send so subtests can block until the expected
	// number of sends has happened.
	sendCh := make(chan struct{}, 10)

	dispatcher.On("Send", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
		mu.Lock()
		defer mu.Unlock()

		peer := args.Get(0).(p2ptypes.PeerID)
		msg := args.Get(1).(*remotetypes.MessageBody)
		meta := msg.Metadata.(*remotetypes.MessageBody_TriggerEventMetadata)

		sendRecords = append(sendRecords, struct {
			peer       p2ptypes.PeerID
			triggerIDs []string
		}{
			peer: peer,
			// Clone the slice: the publisher may reuse the backing array.
			triggerIDs: append([]string(nil), meta.TriggerEventMetadata.TriggerIds...),
		})

		sendCh <- struct{}{}
	}).Return(nil)

	t.Run("initial send to both peers with both triggerIDs", func(t *testing.T) {
		mu.Lock()
		sendRecords = nil
		mu.Unlock()

		// Fire the same event on both triggers so they get batched together.
		underlying.SendEvent("triggerA", commoncap.TriggerResponse{
			Event: commoncap.TriggerEvent{ID: "event1"},
		})
		underlying.SendEvent("triggerB", commoncap.TriggerResponse{
			Event: commoncap.TriggerEvent{ID: "event1"},
		})

		// Expect 2 sends — one per workflow-DON member.
		<-sendCh
		<-sendCh

		mu.Lock()
		defer mu.Unlock()

		require.Len(t, sendRecords, 2)
		for _, rec := range sendRecords {
			require.ElementsMatch(t, []string{"triggerA", "triggerB"}, rec.triggerIDs)
		}
	})

	t.Run("partial ACK trims only missing triggerIDs per peer", func(t *testing.T) {
		// peers[0] ACKs triggerA only; peers[1] has ACKed nothing.
		publisher.Receive(ctx, newAckEventMessage(t, "event1", "triggerA", workflowDONID, peers[0]))

		mu.Lock()
		sendRecords = nil
		mu.Unlock()

		underlying.SendEvent("triggerA", commoncap.TriggerResponse{
			Event: commoncap.TriggerEvent{ID: "event1"},
		})
		underlying.SendEvent("triggerB", commoncap.TriggerResponse{
			Event: commoncap.TriggerEvent{ID: "event1"},
		})

		// Expect 2 sends — both peers still have at least one un-ACKed trigger.
		<-sendCh
		<-sendCh

		mu.Lock()
		defer mu.Unlock()

		require.Len(t, sendRecords, 2)

		// peers[0] should only get the un-ACKed triggerB; peers[1] gets both.
		for _, rec := range sendRecords {
			if rec.peer == peers[0] {
				require.ElementsMatch(t, []string{"triggerB"}, rec.triggerIDs)
			}
			if rec.peer == peers[1] {
				require.ElementsMatch(t, []string{"triggerA", "triggerB"}, rec.triggerIDs)
			}
		}
	})

	t.Run("full ACK suppresses resend", func(t *testing.T) {
		// Complete the ACK matrix: every peer has now ACKed every trigger.
		publisher.Receive(ctx, newAckEventMessage(t, "event1", "triggerA", workflowDONID, peers[1]))
		publisher.Receive(ctx, newAckEventMessage(t, "event1", "triggerB", workflowDONID, peers[0]))
		publisher.Receive(ctx, newAckEventMessage(t, "event1", "triggerB", workflowDONID, peers[1]))

		mu.Lock()
		sendRecords = nil
		mu.Unlock()

		underlying.SendEvent("triggerA", commoncap.TriggerResponse{
			Event: commoncap.TriggerEvent{ID: "event1"},
		})
		underlying.SendEvent("triggerB", commoncap.TriggerResponse{
			Event: commoncap.TriggerEvent{ID: "event1"},
		})

		// No send should occur; give the batching loop ample time to fire.
		select {
		case <-sendCh:
			t.Fatal("unexpected resend after full ACK")
		case <-time.After(100 * time.Millisecond):
		}

		mu.Lock()
		defer mu.Unlock()
		require.Empty(t, sendRecords)
	})
}

func newRegisterTriggerMessageWithTriggerID(t *testing.T, callerDonID uint32, sender p2ptypes.PeerID, triggerID string) *remotetypes.MessageBody {
triggerRequest := commoncap.TriggerRegistrationRequest{
TriggerID: triggerID,
Expand Down
2 changes: 1 addition & 1 deletion core/capabilities/remote/trigger_subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (s *triggerSubscriber) Info(ctx context.Context) (commoncap.CapabilityInfo,
}

func (s *triggerSubscriber) AckEvent(ctx context.Context, triggerID string, eventID string, method string) error {
s.lggr.Debugf("AckEvent called on subscriber (triggerID=%s, eventID=%s)", triggerID, eventID)
s.lggr.Debugw("AckEvent called on subscriber", "triggerID", triggerID, "eventID", eventID)
cfg := s.cfg.Load()
for _, peerID := range cfg.capDonInfo.Members {
m := &types.MessageBody{
Expand Down
8 changes: 7 additions & 1 deletion core/internal/cltest/cltest.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"net/url"
"os"
"reflect"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -721,7 +722,12 @@ func (ta *TestApplication) Start(ctx context.Context) error {
if err != nil {
return err
}
ta.t.Cleanup(func() { require.NoError(ta.t, ta.Stop()) })
ta.t.Cleanup(func() {
err := ta.Stop()
if err != nil && !strings.Contains(err.Error(), "stopped") {
require.NoError(ta.t, err)
}
})
return nil
}

Expand Down
7 changes: 6 additions & 1 deletion core/services/workflows/v2/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,11 @@ func (e *Engine) startExecution(ctx context.Context, wrappedTriggerEvent enqueue
if errors.Is(addErr, store.ErrDuplicateExecution) {
e.logger().Infow("Skipping duplicate execution", "executionID", executionID, "triggerID", wrappedTriggerEvent.triggerCapID, "triggerIndex", wrappedTriggerEvent.triggerIndex)
e.metrics.With(platform.KeyTriggerID, wrappedTriggerEvent.triggerCapID).IncrementWorkflowTriggerEventErrorCounter(ctx)
registrationID := TriggerRegistrationID(e.cfg.WorkflowID, wrappedTriggerEvent.triggerIndex)
err = e.ackTriggerEvent(ctx, registrationID, &triggerEvent)
if err != nil {
e.lggr.Errorw("failed to re-ACK trigger event", "eventID", triggerEvent.ID, "err", err)
}
return
}
e.logger().Errorw("Failed to register execution in store, proceeding anyway", "executionID", executionID, "err", addErr)
Expand Down Expand Up @@ -674,7 +679,7 @@ func (e *Engine) startExecution(ctx context.Context, wrappedTriggerEvent enqueue
registrationID := TriggerRegistrationID(e.cfg.WorkflowID, wrappedTriggerEvent.triggerIndex)
err = e.ackTriggerEvent(ctx, registrationID, &triggerEvent)
if err != nil {
e.lggr.Errorf("failed to ACK trigger event (eventID=%s): %v", triggerEvent.ID, err)
e.lggr.Errorw("failed to ACK trigger event", "eventID", triggerEvent.ID, "err", err)
}
e.metrics.With("workflowID", e.cfg.WorkflowID, "workflowName", e.cfg.WorkflowName.String()).IncrementWorkflowExecutionStartedCounter(ctx)

Expand Down
Loading