Skip to content

Commit 1462832

Browse files
authored
Merge pull request #1001 from ellemouton/sql23
[sql-23] firewalldb: thread contexts through for kv-store interfaces
2 parents 64ab73a + 627e7c4 commit 1462832

File tree

5 files changed

+116
-99
lines changed

5 files changed

+116
-99
lines changed

firewalldb/interface.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,24 @@ type SessionDB interface {
1414
// GetSession returns the session for a specific id.
1515
GetSession(context.Context, session.ID) (*session.Session, error)
1616
}
17+
18+
// DBExecutor provides an Update and View method that will allow the caller
19+
// to perform atomic read and write transactions defined by PrivacyMapTx on the
20+
// underlying BoltDB.
21+
type DBExecutor[T any] interface {
22+
// Update opens a database read/write transaction and executes the
23+
// function f with the transaction passed as a parameter. After f exits,
24+
// if f did not error, the transaction is committed. Otherwise, if f did
25+
// error, the transaction is rolled back. If the rollback fails, the
26+
// original error returned by f is still returned. If the commit fails,
27+
// the commit error is returned.
28+
Update(ctx context.Context, f func(ctx context.Context,
29+
tx T) error) error
30+
31+
// View opens a database read transaction and executes the function f
32+
// with the transaction passed as a parameter. After f exits, the
33+
// transaction is rolled back. If f errors, its error is returned, not a
34+
// rollback error (if any occur).
35+
View(ctx context.Context, f func(ctx context.Context,
36+
tx T) error) error
37+
}

firewalldb/kvstores.go

Lines changed: 19 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -54,21 +54,7 @@ var (
5454
// KVStores provides an Update and View method that will allow the caller to
5555
// perform atomic read and write transactions on and of the key value stores
5656
// offered the KVStoreTx.
57-
type KVStores interface {
58-
// Update opens a database read/write transaction and executes the
59-
// function f with the transaction passed as a parameter. After f exits,
60-
// if f did not error, the transaction is committed. Otherwise, if f did
61-
// error, the transaction is rolled back. If the rollback fails, the
62-
// original error returned by f is still returned. If the commit fails,
63-
// the commit error is returned.
64-
Update(f func(tx KVStoreTx) error) error
65-
66-
// View opens a database read transaction and executes the function f
67-
// with the transaction passed as a parameter. After f exits, the
68-
// transaction is rolled back. If f errors, its error is returned, not a
69-
// rollback error (if any occur).
70-
View(f func(tx KVStoreTx) error) error
71-
}
57+
type KVStores = DBExecutor[KVStoreTx]
7258

7359
// KVStoreTx represents a database transaction that can be used for both read
7460
// and writes of the various different key value stores offered for the rule.
@@ -122,7 +108,7 @@ func (db *DB) GetKVStores(rule string, groupID session.ID,
122108
feature string) KVStores {
123109

124110
return &kvStores{
125-
DB: db,
111+
db: db.DB,
126112
ruleName: rule,
127113
groupID: groupID,
128114
featureName: feature,
@@ -131,25 +117,12 @@ func (db *DB) GetKVStores(rule string, groupID session.ID,
131117

132118
// kvStores implements the rules.KVStores interface.
133119
type kvStores struct {
134-
*DB
120+
db *bbolt.DB
135121
ruleName string
136122
groupID session.ID
137123
featureName string
138124
}
139125

140-
// beginTx starts db transaction. The transaction will be a read or read-write
141-
// transaction depending on the value of the `writable` parameter.
142-
func (s *kvStores) beginTx(writable bool) (*kvStoreTx, error) {
143-
boltTx, err := s.Begin(writable)
144-
if err != nil {
145-
return nil, err
146-
}
147-
return &kvStoreTx{
148-
kvStores: s,
149-
boltTx: boltTx,
150-
}, nil
151-
}
152-
153126
// Update opens a database read/write transaction and executes the function f
154127
// with the transaction passed as a parameter. After f exits, if f did not
155128
// error, the transaction is committed. Otherwise, if f did error, the
@@ -158,28 +131,17 @@ func (s *kvStores) beginTx(writable bool) (*kvStoreTx, error) {
158131
// returned.
159132
//
160133
// NOTE: this is part of the KVStores interface.
161-
func (s *kvStores) Update(f func(tx KVStoreTx) error) error {
162-
tx, err := s.beginTx(true)
163-
if err != nil {
164-
return err
165-
}
134+
func (s *kvStores) Update(ctx context.Context, fn func(ctx context.Context,
135+
tx KVStoreTx) error) error {
166136

167-
// Make sure the transaction rolls back in the event of a panic.
168-
defer func() {
169-
if tx != nil {
170-
_ = tx.boltTx.Rollback()
137+
return s.db.Update(func(tx *bbolt.Tx) error {
138+
boltTx := &kvStoreTx{
139+
boltTx: tx,
140+
kvStores: s,
171141
}
172-
}()
173-
174-
err = f(tx)
175-
if err != nil {
176-
// Want to return the original error, not a rollback error if
177-
// any occur.
178-
_ = tx.boltTx.Rollback()
179-
return err
180-
}
181142

182-
return tx.boltTx.Commit()
143+
return fn(ctx, boltTx)
144+
})
183145
}
184146

185147
// View opens a database read transaction and executes the function f with the
@@ -188,29 +150,17 @@ func (s *kvStores) Update(f func(tx KVStoreTx) error) error {
188150
// occur).
189151
//
190152
// NOTE: this is part of the KVStores interface.
191-
func (s *kvStores) View(f func(tx KVStoreTx) error) error {
192-
tx, err := s.beginTx(false)
193-
if err != nil {
194-
return err
195-
}
153+
func (s *kvStores) View(ctx context.Context, fn func(ctx context.Context,
154+
tx KVStoreTx) error) error {
196155

197-
// Make sure the transaction rolls back in the event of a panic.
198-
defer func() {
199-
if tx != nil {
200-
_ = tx.boltTx.Rollback()
156+
return s.db.View(func(tx *bbolt.Tx) error {
157+
boltTx := &kvStoreTx{
158+
boltTx: tx,
159+
kvStores: s,
201160
}
202-
}()
203161

204-
err = f(tx)
205-
rollbackErr := tx.boltTx.Rollback()
206-
if err != nil {
207-
return err
208-
}
209-
210-
if rollbackErr != nil {
211-
return rollbackErr
212-
}
213-
return nil
162+
return fn(ctx, boltTx)
163+
})
214164
}
215165

216166
// getBucketFunc defines the signature of the bucket creation/fetching function

firewalldb/kvstores_test.go

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestKVStoreTxs(t *testing.T) {
2828

2929
// Test that if an action fails midway through the transaction, then
3030
// it is rolled back.
31-
err = store.Update(func(tx KVStoreTx) error {
31+
err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error {
3232
err := tx.Global().Set(ctx, "test", []byte{1})
3333
if err != nil {
3434
return err
@@ -46,7 +46,7 @@ func TestKVStoreTxs(t *testing.T) {
4646
require.Error(t, err)
4747

4848
var v []byte
49-
err = store.View(func(tx KVStoreTx) error {
49+
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
5050
b, err := tx.Global().Get(ctx, "test")
5151
if err != nil {
5252
return err
@@ -94,7 +94,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {
9494

9595
store := db.GetKVStores("test-rule", [4]byte{1, 1, 1, 1}, featureName)
9696

97-
err = store.Update(func(tx KVStoreTx) error {
97+
err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error {
9898
// Set an item in the temp store.
9999
err := tx.LocalTemp().Set(ctx, "test", []byte{4, 3, 2})
100100
if err != nil {
@@ -112,7 +112,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {
112112
v1 []byte
113113
v2 []byte
114114
)
115-
err = store.View(func(tx KVStoreTx) error {
115+
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
116116
b, err := tx.LocalTemp().Get(ctx, "test")
117117
if err != nil {
118118
return err
@@ -144,7 +144,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {
144144

145145
// The temp store should no longer have the stored value but the perm
146146
// store should .
147-
err = store.View(func(tx KVStoreTx) error {
147+
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
148148
b, err := tx.LocalTemp().Get(ctx, "test")
149149
if err != nil {
150150
return err
@@ -188,29 +188,37 @@ func TestKVStoreNameSpaces(t *testing.T) {
188188
rulesDB3 := db.GetKVStores("test-rule", groupID2, "re-balance")
189189

190190
// Test that the three ruleDBs share the same global space.
191-
err = rulesDB1.Update(func(tx KVStoreTx) error {
191+
err = rulesDB1.Update(ctx, func(ctx context.Context,
192+
tx KVStoreTx) error {
193+
192194
return tx.Global().Set(
193195
ctx, "test-global", []byte("global thing!"),
194196
)
195197
})
196198
require.NoError(t, err)
197199

198-
err = rulesDB2.Update(func(tx KVStoreTx) error {
200+
err = rulesDB2.Update(ctx, func(ctx context.Context,
201+
tx KVStoreTx) error {
202+
199203
return tx.Global().Set(
200204
ctx, "test-global", []byte("different global thing!"),
201205
)
202206
})
203207
require.NoError(t, err)
204208

205-
err = rulesDB3.Update(func(tx KVStoreTx) error {
209+
err = rulesDB3.Update(ctx, func(ctx context.Context,
210+
tx KVStoreTx) error {
211+
206212
return tx.Global().Set(
207213
ctx, "test-global", []byte("yet another global thing"),
208214
)
209215
})
210216
require.NoError(t, err)
211217

212218
var v []byte
213-
err = rulesDB1.View(func(tx KVStoreTx) error {
219+
err = rulesDB1.View(ctx, func(ctx context.Context,
220+
tx KVStoreTx) error {
221+
214222
b, err := tx.Global().Get(ctx, "test-global")
215223
if err != nil {
216224
return err
@@ -221,7 +229,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
221229
require.NoError(t, err)
222230
require.True(t, bytes.Equal(v, []byte("yet another global thing")))
223231

224-
err = rulesDB2.View(func(tx KVStoreTx) error {
232+
err = rulesDB2.View(ctx, func(ctx context.Context,
233+
tx KVStoreTx) error {
234+
225235
b, err := tx.Global().Get(ctx, "test-global")
226236
if err != nil {
227237
return err
@@ -232,7 +242,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
232242
require.NoError(t, err)
233243
require.True(t, bytes.Equal(v, []byte("yet another global thing")))
234244

235-
err = rulesDB3.View(func(tx KVStoreTx) error {
245+
err = rulesDB3.View(ctx, func(ctx context.Context,
246+
tx KVStoreTx) error {
247+
236248
b, err := tx.Global().Get(ctx, "test-global")
237249
if err != nil {
238250
return err
@@ -244,22 +256,30 @@ func TestKVStoreNameSpaces(t *testing.T) {
244256
require.True(t, bytes.Equal(v, []byte("yet another global thing")))
245257

246258
// Test that the feature space is not shared by any of the dbs.
247-
err = rulesDB1.Update(func(tx KVStoreTx) error {
259+
err = rulesDB1.Update(ctx, func(ctx context.Context,
260+
tx KVStoreTx) error {
261+
248262
return tx.Local().Set(ctx, "count", []byte("1"))
249263
})
250264
require.NoError(t, err)
251265

252-
err = rulesDB2.Update(func(tx KVStoreTx) error {
266+
err = rulesDB2.Update(ctx, func(ctx context.Context,
267+
tx KVStoreTx) error {
268+
253269
return tx.Local().Set(ctx, "count", []byte("2"))
254270
})
255271
require.NoError(t, err)
256272

257-
err = rulesDB3.Update(func(tx KVStoreTx) error {
273+
err = rulesDB3.Update(ctx, func(ctx context.Context,
274+
tx KVStoreTx) error {
275+
258276
return tx.Local().Set(ctx, "count", []byte("3"))
259277
})
260278
require.NoError(t, err)
261279

262-
err = rulesDB1.View(func(tx KVStoreTx) error {
280+
err = rulesDB1.View(ctx, func(ctx context.Context,
281+
tx KVStoreTx) error {
282+
263283
b, err := tx.Local().Get(ctx, "count")
264284
if err != nil {
265285
return err
@@ -270,7 +290,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
270290
require.NoError(t, err)
271291
require.True(t, bytes.Equal(v, []byte("1")))
272292

273-
err = rulesDB2.View(func(tx KVStoreTx) error {
293+
err = rulesDB2.View(ctx, func(ctx context.Context,
294+
tx KVStoreTx) error {
295+
274296
b, err := tx.Local().Get(ctx, "count")
275297
if err != nil {
276298
return err
@@ -281,7 +303,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
281303
require.NoError(t, err)
282304
require.True(t, bytes.Equal(v, []byte("2")))
283305

284-
err = rulesDB3.View(func(tx KVStoreTx) error {
306+
err = rulesDB3.View(ctx, func(ctx context.Context,
307+
tx KVStoreTx) error {
308+
285309
b, err := tx.Local().Get(ctx, "count")
286310
if err != nil {
287311
return err
@@ -299,22 +323,30 @@ func TestKVStoreNameSpaces(t *testing.T) {
299323
rulesDB2 = db.GetKVStores("test-rule", groupID1, "")
300324
rulesDB3 = db.GetKVStores("test-rule", groupID2, "")
301325

302-
err = rulesDB1.Update(func(tx KVStoreTx) error {
326+
err = rulesDB1.Update(ctx, func(ctx context.Context,
327+
tx KVStoreTx) error {
328+
303329
return tx.Local().Set(ctx, "test", []byte("thing 1"))
304330
})
305331
require.NoError(t, err)
306332

307-
err = rulesDB2.Update(func(tx KVStoreTx) error {
333+
err = rulesDB2.Update(ctx, func(ctx context.Context,
334+
tx KVStoreTx) error {
335+
308336
return tx.Local().Set(ctx, "test", []byte("thing 2"))
309337
})
310338
require.NoError(t, err)
311339

312-
err = rulesDB3.Update(func(tx KVStoreTx) error {
340+
err = rulesDB3.Update(ctx, func(ctx context.Context,
341+
tx KVStoreTx) error {
342+
313343
return tx.Local().Set(ctx, "test", []byte("thing 3"))
314344
})
315345
require.NoError(t, err)
316346

317-
err = rulesDB1.View(func(tx KVStoreTx) error {
347+
err = rulesDB1.View(ctx, func(ctx context.Context,
348+
tx KVStoreTx) error {
349+
318350
b, err := tx.Local().Get(ctx, "test")
319351
if err != nil {
320352
return err
@@ -325,7 +357,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
325357
require.NoError(t, err)
326358
require.True(t, bytes.Equal(v, []byte("thing 2")))
327359

328-
err = rulesDB2.View(func(tx KVStoreTx) error {
360+
err = rulesDB2.View(ctx, func(ctx context.Context,
361+
tx KVStoreTx) error {
362+
329363
b, err := tx.Local().Get(ctx, "test")
330364
if err != nil {
331365
return err
@@ -336,7 +370,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
336370
require.NoError(t, err)
337371
require.True(t, bytes.Equal(v, []byte("thing 2")))
338372

339-
err = rulesDB3.View(func(tx KVStoreTx) error {
373+
err = rulesDB3.View(ctx, func(ctx context.Context,
374+
tx KVStoreTx) error {
375+
340376
b, err := tx.Local().Get(ctx, "test")
341377
if err != nil {
342378
return err

0 commit comments

Comments
 (0)