Skip to content

Commit 1a31bbf

Browse files
committed
fsm: add early abort observer option
1 parent 2048b32 commit 1a31bbf

File tree

2 files changed

+57
-9
lines changed

2 files changed

+57
-9
lines changed

fsm/fsm.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ var (
1313
ErrWaitForStateTimedOut = errors.New(
1414
"timed out while waiting for event",
1515
)
16-
ErrInvalidContextType = errors.New("invalid context")
16+
ErrInvalidContextType = errors.New("invalid context")
17+
ErrWaitingForStateEarlyAbortError = errors.New(
18+
"waiting for state early abort",
19+
)
1720
)
1821

1922
const (
@@ -73,6 +76,8 @@ type Notification struct {
7376
NextState StateType
7477
// Event is the event that was processed.
7578
Event EventType
79+
// LastActionError is the error returned by the last action executed.
80+
LastActionError error
7681
}
7782

7883
// Observer is an interface that can be implemented by types that want to
@@ -214,9 +219,10 @@ func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
214219
// Notify the state machine's observers.
215220
s.observerMutex.Lock()
216221
notification := Notification{
217-
PreviousState: s.previous,
218-
NextState: s.current,
219-
Event: event,
222+
PreviousState: s.previous,
223+
NextState: s.current,
224+
Event: event,
225+
LastActionError: s.LastActionError,
220226
}
221227

222228
for _, observer := range s.observers {

fsm/observer.go

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ type WaitForStateOption interface {
5555
// fsmOptions is a struct that holds all options that can be passed to the
5656
// WaitForState function.
5757
type fsmOptions struct {
58-
initialWait time.Duration
58+
initialWait time.Duration
59+
abortEarlyOnError bool
5960
}
6061

6162
// InitialWaitOption is an option that can be passed to the WaitForState
@@ -76,6 +77,24 @@ func (w *InitialWaitOption) apply(o *fsmOptions) {
7677
o.initialWait = w.initialWait
7778
}
7879

80+
// AbortEarlyOnErrorOption is an option that can be passed to the WaitForState
81+
// function to abort early if an error occurs.
82+
type AbortEarlyOnErrorOption struct {
83+
abortEarlyOnError bool
84+
}
85+
86+
// apply implements the WaitForStateOption interface.
87+
func (a *AbortEarlyOnErrorOption) apply(o *fsmOptions) {
88+
o.abortEarlyOnError = a.abortEarlyOnError
89+
}
90+
91+
// WithAbortEarlyOnErrorOption creates a new AbortEarlyOnErrorOption.
92+
func WithAbortEarlyOnErrorOption() WaitForStateOption {
93+
return &AbortEarlyOnErrorOption{
94+
abortEarlyOnError: true,
95+
}
96+
}
97+
7998
// WaitForState waits for the state machine to reach the given state.
8099
// If the optional initialWait parameter is set, the function will wait for
81100
// the given duration before checking the state. This is useful if the
@@ -105,7 +124,8 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
105124
defer cancel()
106125

107126
// Channel to notify when the desired state is reached
108-
ch := make(chan struct{})
127+
// or an error occurred.
128+
ch := make(chan error)
109129

110130
// Goroutine to wait on condition variable
111131
go func() {
@@ -115,8 +135,26 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
115135
for {
116136
// Check if the last state is the desired state
117137
if s.lastNotification.NextState == state {
118-
ch <- struct{}{}
119-
return
138+
select {
139+
case <-timeoutCtx.Done():
140+
return
141+
142+
case ch <- nil:
143+
return
144+
}
145+
}
146+
147+
// Check if an error occurred
148+
if s.lastNotification.Event == OnError {
149+
if options.abortEarlyOnError {
150+
select {
151+
case <-timeoutCtx.Done():
152+
return
153+
154+
case ch <- s.lastNotification.LastActionError:
155+
return
156+
}
157+
}
120158
}
121159

122160
// Otherwise, wait for the next notification
@@ -130,7 +168,11 @@ func (s *CachedObserver) WaitForState(ctx context.Context,
130168
return NewErrWaitingForStateTimeout(
131169
state, s.lastNotification.NextState,
132170
)
133-
case <-ch:
171+
172+
case lastActionErr := <-ch:
173+
if lastActionErr != nil {
174+
return lastActionErr
175+
}
134176
return nil
135177
}
136178
}

0 commit comments

Comments
 (0)