Skip to content

Commit b5f746a

Browse files
authored
Automatically load extensions. (#115)
1 parent fff8b1c commit b5f746a

36 files changed

+261
-245
lines changed

conn.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
7272
c.arena = c.newArena(1024)
7373
c.ctx = context.WithValue(c.ctx, connKey{}, c)
7474
c.handle, err = c.openDB(filename, flags)
75+
if err == nil {
76+
err = initExtensions(c)
77+
}
7578
if err != nil {
7679
return nil, err
7780
}

ext/array/array.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ import (
1515
// The argument must be bound to a Go slice or array of
1616
// ints, floats, bools, strings or byte slices,
1717
// using [sqlite3.BindPointer] or [sqlite3.Pointer].
18-
func Register(db *sqlite3.Conn) {
19-
sqlite3.CreateModule(db, "array", nil,
18+
func Register(db *sqlite3.Conn) error {
19+
return sqlite3.CreateModule(db, "array", nil,
2020
func(db *sqlite3.Conn, _, _, _ string, _ ...string) (array, error) {
2121
err := db.DeclareVTab(`CREATE TABLE x(value, array HIDDEN)`)
2222
return array{}, err

ext/array/array_test.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@ import (
1515
)
1616

1717
func Example_driver() {
18-
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
19-
array.Register(c)
20-
return nil
21-
})
18+
db, err := driver.Open(":memory:", array.Register)
2219
if err != nil {
2320
log.Fatal(err)
2421
}
@@ -53,14 +50,14 @@ func Example_driver() {
5350
}
5451

5552
func Example() {
53+
sqlite3.AutoExtension(array.Register)
54+
5655
db, err := sqlite3.Open(":memory:")
5756
if err != nil {
5857
log.Fatal(err)
5958
}
6059
defer db.Close()
6160

62-
array.Register(db)
63-
6461
stmt, _, err := db.Prepare(`
6562
SELECT name
6663
FROM pragma_function_list
@@ -91,10 +88,7 @@ func Example() {
9188
func Test_cursor_Column(t *testing.T) {
9289
t.Parallel()
9390

94-
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
95-
array.Register(c)
96-
return nil
97-
})
91+
db, err := driver.Open(":memory:", array.Register)
9892
if err != nil {
9993
t.Fatal(err)
10094
}
@@ -139,7 +133,10 @@ func Test_array_errors(t *testing.T) {
139133
}
140134
defer db.Close()
141135

142-
array.Register(db)
136+
err = array.Register(db)
137+
if err != nil {
138+
t.Fatal(err)
139+
}
143140

144141
err = db.Exec(`SELECT * FROM array()`)
145142
if err == nil {

ext/blobio/blob.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,11 @@ import (
2929
// along with the [sqlite3.Blob] handle.
3030
//
3131
// https://sqlite.org/c3ref/blob.html
32-
func Register(db *sqlite3.Conn) {
33-
db.CreateFunction("readblob", 6, 0, readblob)
34-
db.CreateFunction("writeblob", 6, 0, writeblob)
35-
db.CreateFunction("openblob", -1, 0, openblob)
32+
func Register(db *sqlite3.Conn) error {
33+
return errors.Join(
34+
db.CreateFunction("readblob", 6, 0, readblob),
35+
db.CreateFunction("writeblob", 6, 0, writeblob),
36+
db.CreateFunction("openblob", -1, 0, openblob))
3637
}
3738

3839
// OpenCallback is the type for the openblob callback.

ext/blobio/blob_test.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ import (
1818

1919
func Example() {
2020
// Open the database, registering the extension.
21-
db, err := driver.Open("file:/test.db?vfs=memdb", func(conn *sqlite3.Conn) error {
22-
blobio.Register(conn)
23-
return nil
24-
})
21+
db, err := driver.Open("file:/test.db?vfs=memdb", blobio.Register)
2522

2623
if err != nil {
2724
log.Fatal(err)
@@ -60,6 +57,11 @@ func Example() {
6057
// Hello BLOB!
6158
}
6259

60+
func init() {
61+
sqlite3.AutoExtension(blobio.Register)
62+
sqlite3.AutoExtension(array.Register)
63+
}
64+
6365
func Test_readblob(t *testing.T) {
6466
t.Parallel()
6567

@@ -69,9 +71,6 @@ func Test_readblob(t *testing.T) {
6971
}
7072
defer db.Close()
7173

72-
blobio.Register(db)
73-
array.Register(db)
74-
7574
err = db.Exec(`SELECT readblob()`)
7675
if err == nil {
7776
t.Fatal("want error")
@@ -129,9 +128,6 @@ func Test_openblob(t *testing.T) {
129128
}
130129
defer db.Close()
131130

132-
blobio.Register(db)
133-
array.Register(db)
134-
135131
err = db.Exec(`SELECT openblob()`)
136132
if err == nil {
137133
t.Fatal("want error")

ext/bloom/bloom.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import (
2020
// Register registers the bloom_filter virtual table:
2121
//
2222
// CREATE VIRTUAL TABLE foo USING bloom_filter(nElements, falseProb, kHashes)
23-
func Register(db *sqlite3.Conn) {
24-
sqlite3.CreateModule(db, "bloom_filter", create, connect)
23+
func Register(db *sqlite3.Conn) error {
24+
return sqlite3.CreateModule(db, "bloom_filter", create, connect)
2525
}
2626

2727
type bloom struct {

ext/bloom/bloom_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ import (
1212
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
1313
)
1414

15+
func init() {
16+
sqlite3.AutoExtension(bloom.Register)
17+
}
18+
1519
func TestRegister(t *testing.T) {
1620
t.Parallel()
1721

@@ -21,8 +25,6 @@ func TestRegister(t *testing.T) {
2125
}
2226
defer db.Close()
2327

24-
bloom.Register(db)
25-
2628
err = db.Exec(`
2729
CREATE VIRTUAL TABLE sports_cars USING bloom_filter(20);
2830
INSERT INTO sports_cars VALUES ('ferrari'), ('lamborghini'), ('alfa romeo')
@@ -90,8 +92,6 @@ func Test_compatible(t *testing.T) {
9092
}
9193
defer db.Close()
9294

93-
bloom.Register(db)
94-
9595
query, _, err := db.Prepare(`SELECT COUNT(*) FROM plants(?)`)
9696
if err != nil {
9797
t.Fatal(err)

ext/csv/csv.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ import (
2323

2424
// Register registers the CSV virtual table.
2525
// If a filename is specified, [os.Open] is used to open the file.
26-
func Register(db *sqlite3.Conn) {
27-
RegisterFS(db, osutil.FS{})
26+
func Register(db *sqlite3.Conn) error {
27+
return RegisterFS(db, osutil.FS{})
2828
}
2929

3030
// RegisterFS registers the CSV virtual table.
3131
// If a filename is specified, fsys is used to open the file.
32-
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
32+
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
3333
declare := func(db *sqlite3.Conn, _, _, _ string, arg ...string) (_ *table, err error) {
3434
var (
3535
filename string
@@ -118,7 +118,7 @@ func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
118118
return table, nil
119119
}
120120

121-
sqlite3.CreateModule(db, "csv", declare, declare)
121+
return sqlite3.CreateModule(db, "csv", declare, declare)
122122
}
123123

124124
type table struct {

ext/csv/csv_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ func Example() {
1818
}
1919
defer db.Close()
2020

21-
csv.Register(db)
21+
err = csv.Register(db)
22+
if err != nil {
23+
log.Fatal(err)
24+
}
2225

2326
err = db.Exec(`
2427
CREATE VIRTUAL TABLE eurofxref USING csv(
@@ -51,6 +54,10 @@ func Example() {
5154
// On Twosday, 1€ = $1.1342
5255
}
5356

57+
func init() {
58+
sqlite3.AutoExtension(csv.Register)
59+
}
60+
5461
func TestRegister(t *testing.T) {
5562
t.Parallel()
5663

@@ -60,8 +67,6 @@ func TestRegister(t *testing.T) {
6067
}
6168
defer db.Close()
6269

63-
csv.Register(db)
64-
6570
const data = `
6671
# Comment
6772
"Rob" "Pike" rob
@@ -124,8 +129,6 @@ func TestAffinity(t *testing.T) {
124129
}
125130
defer db.Close()
126131

127-
csv.Register(db)
128-
129132
const data = "01\n0.10\ne"
130133
err = db.Exec(`
131134
CREATE VIRTUAL TABLE temp.nums USING csv(
@@ -168,8 +171,6 @@ func TestRegister_errors(t *testing.T) {
168171
}
169172
defer db.Close()
170173

171-
csv.Register(db)
172-
173174
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`)
174175
if err == nil {
175176
t.Fatal("want error")

ext/fileio/fileio.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,26 @@ import (
1414

1515
// Register registers SQL functions readfile, writefile, lsmode,
1616
// and the table-valued function fsdir.
17-
func Register(db *sqlite3.Conn) {
18-
RegisterFS(db, nil)
17+
func Register(db *sqlite3.Conn) error {
18+
return RegisterFS(db, nil)
1919
}
2020

2121
// Register registers SQL functions readfile, lsmode,
2222
// and the table-valued function fsdir;
2323
// fsys will be used to read files and list directories.
24-
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) {
25-
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode)
26-
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys))
24+
func RegisterFS(db *sqlite3.Conn, fsys fs.FS) error {
25+
var err error
2726
if fsys == nil {
28-
db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
27+
err = db.CreateFunction("writefile", -1, sqlite3.DIRECTONLY, writefile)
2928
}
30-
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
31-
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
32-
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
33-
return fsdir{fsys}, err
34-
})
29+
return errors.Join(err,
30+
db.CreateFunction("readfile", 1, sqlite3.DIRECTONLY, readfile(fsys)),
31+
db.CreateFunction("lsmode", 1, sqlite3.DETERMINISTIC, lsmode),
32+
sqlite3.CreateModule(db, "fsdir", nil, func(db *sqlite3.Conn, _, _, _ string, _ ...string) (fsdir, error) {
33+
err := db.DeclareVTab(`CREATE TABLE x(name,mode,mtime TIMESTAMP,data,path HIDDEN,dir HIDDEN)`)
34+
db.VTabConfig(sqlite3.VTAB_DIRECTONLY)
35+
return fsdir{fsys}, err
36+
}))
3537
}
3638

3739
func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {

0 commit comments

Comments
 (0)