Skip to content

Commit 70e221b

Browse files
committed
snapshot/revert views
1 parent 788681d commit 70e221b

File tree

2 files changed

+61
-3
lines changed

2 files changed

+61
-3
lines changed

action/protocol/protocol.go

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,49 @@ type (
120120

121121
// Views stores the view for all protocols
122122
Views struct {
123-
vm map[string]View
123+
snapshotID int
124+
snapshots map[int]map[string]View
125+
vm map[string]View
124126
}
125127
)
126128

127129
func NewViews() *Views {
128130
return &Views{
129-
vm: make(map[string]View),
131+
snapshotID: 0,
132+
snapshots: make(map[int]map[string]View),
133+
vm: make(map[string]View),
130134
}
131135
}
132136

137+
func (views *Views) Snapshot() int {
138+
views.snapshotID++
139+
views.snapshots[views.snapshotID] = make(map[string]View)
140+
keys := make([]string, 0, len(views.vm))
141+
for key := range views.vm {
142+
keys = append(keys, key)
143+
}
144+
for _, key := range keys {
145+
views.snapshots[views.snapshotID][key] = views.vm[key]
146+
views.vm[key] = views.vm[key].Clone()
147+
}
148+
return views.snapshotID
149+
}
150+
151+
func (views *Views) Revert(id int) error {
152+
if id > views.snapshotID || id < 0 {
153+
return errors.Errorf("invalid snapshot id %d, max id is %d", id, views.snapshotID)
154+
}
155+
for k, v := range views.snapshots[id] {
156+
views.vm[k] = v
157+
}
158+
views.snapshotID = id
159+
// clean up snapshots that are not needed anymore
160+
for i := id + 1; i <= views.snapshotID; i++ {
161+
delete(views.snapshots, i)
162+
}
163+
return nil
164+
}
165+
133166
func (views *Views) Clone() *Views {
134167
clone := NewViews()
135168
for key, view := range views.vm {

state/factory/workingset.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package factory
77

88
import (
99
"context"
10+
"math"
1011
"math/big"
1112
"sort"
1213
"time"
@@ -70,6 +71,7 @@ type (
7071
workingSetStoreFactory WorkingSetStoreFactory
7172
height uint64
7273
views *protocol.Views
74+
viewsSnapshots map[int]int
7375
store workingSetStore
7476
finalized bool
7577
txValidator *protocol.GenericValidator
@@ -81,6 +83,7 @@ func newWorkingSet(height uint64, views *protocol.Views, store workingSetStore,
8183
ws := &workingSet{
8284
height: height,
8385
views: views,
86+
viewsSnapshots: make(map[int]int),
8487
store: store,
8588
workingSetStoreFactory: storeFactory,
8689
}
@@ -280,14 +283,36 @@ func (ws *workingSet) finalizeTx(ctx context.Context) {
280283
}
281284

282285
func (ws *workingSet) Snapshot() int {
283-
return ws.store.Snapshot()
286+
id := ws.store.Snapshot()
287+
vid := ws.views.Snapshot()
288+
ws.viewsSnapshots[id] = vid
289+
290+
return id
284291
}
285292

286293
func (ws *workingSet) Revert(snapshot int) error {
294+
vid, ok := ws.viewsSnapshots[snapshot]
295+
if !ok {
296+
return errors.Errorf("snapshot %d not found", snapshot)
297+
}
298+
if err := ws.views.Revert(vid); err != nil {
299+
return errors.Wrapf(err, "failed to revert views to snapshot %d", vid)
300+
}
287301
return ws.store.RevertSnapshot(snapshot)
288302
}
289303

290304
func (ws *workingSet) ResetSnapshots() {
305+
if len(ws.viewsSnapshots) > 0 {
306+
minVID := math.MaxInt
307+
for _, vid := range ws.viewsSnapshots {
308+
if vid < minVID {
309+
minVID = vid
310+
}
311+
}
312+
if err := ws.views.Revert(minVID); err != nil {
313+
log.L().Panic("failed to revert views to minimum snapshot", zap.Error(err))
314+
}
315+
}
291316
ws.store.ResetSnapshots()
292317
}
293318

0 commit comments

Comments
 (0)