Skip to content

Commit 26adda4

Browse files
authored
Seq aggregate functions (ncruces#229)
1 parent 2f6cd8d commit 26adda4

File tree

7 files changed

+163
-32
lines changed

7 files changed

+163
-32
lines changed

ext/fileio/fileio.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func lsmode(ctx sqlite3.Context, arg ...sqlite3.Value) {
4242
ctx.ResultText(fs.FileMode(arg[0].Int()).String())
4343
}
4444

45-
func readfile(fsys fs.FS) func(ctx sqlite3.Context, arg ...sqlite3.Value) {
45+
func readfile(fsys fs.FS) sqlite3.ScalarFunction {
4646
return func(ctx sqlite3.Context, arg ...sqlite3.Value) {
4747
var err error
4848
var data []byte

ext/fileio/fsdir.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) {
6363

6464
type cursor struct {
6565
fsdir
66-
base string
67-
resume func() (entry, bool)
68-
cancel func()
69-
curr entry
70-
eof bool
71-
rowID int64
66+
base string
67+
next func() (entry, bool)
68+
stop func()
69+
curr entry
70+
eof bool
71+
rowID int64
7272
}
7373

7474
type entry struct {
@@ -78,8 +78,8 @@ type entry struct {
7878
}
7979

8080
func (c *cursor) Close() error {
81-
if c.cancel != nil {
82-
c.cancel()
81+
if c.stop != nil {
82+
c.stop()
8383
}
8484
return nil
8585
}
@@ -102,7 +102,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
102102
c.base = base
103103
}
104104

105-
c.resume, c.cancel = iter.Pull(func(yield func(entry) bool) {
105+
c.next, c.stop = iter.Pull(func(yield func(entry) bool) {
106106
walkDir := func(path string, d fs.DirEntry, err error) error {
107107
if yield(entry{d, err, path}) {
108108
return nil
@@ -121,7 +121,7 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
121121
}
122122

123123
func (c *cursor) Next() error {
124-
curr, ok := c.resume()
124+
curr, ok := c.next()
125125
c.curr = curr
126126
c.eof = !ok
127127
c.rowID++

ext/stats/boolean.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ const (
77
some
88
)
99

10-
func newBoolean(kind int) func() sqlite3.AggregateFunction {
10+
func newBoolean(kind int) sqlite3.AggregateConstructor {
1111
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
1212
}
1313

ext/stats/percentile.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const (
2121
percentile_disc
2222
)
2323

24-
func newPercentile(kind int) func() sqlite3.AggregateFunction {
24+
func newPercentile(kind int) sqlite3.AggregateConstructor {
2525
return func() sqlite3.AggregateFunction { return &percentile{kind: kind} }
2626
}
2727

ext/stats/stats.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ func special(kind int, n int64) (null, zero bool) {
130130
}
131131
}
132132

133-
func newVariance(kind int) func() sqlite3.AggregateFunction {
133+
func newVariance(kind int) sqlite3.AggregateConstructor {
134134
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
135135
}
136136

@@ -178,7 +178,7 @@ func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
178178
}
179179
}
180180

181-
func newCovariance(kind int) func() sqlite3.AggregateFunction {
181+
func newCovariance(kind int) sqlite3.AggregateConstructor {
182182
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
183183
}
184184

@@ -254,7 +254,7 @@ func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
254254
}
255255
}
256256

257-
func newMoments(kind int) func() sqlite3.AggregateFunction {
257+
func newMoments(kind int) sqlite3.AggregateConstructor {
258258
return func() sqlite3.AggregateFunction { return &momentfn{kind: kind} }
259259
}
260260

func.go

Lines changed: 90 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package sqlite3
33
import (
44
"context"
55
"io"
6+
"iter"
67
"sync"
78

89
"github.com/tetratelabs/wazero/api"
@@ -45,7 +46,7 @@ func (c Conn) AnyCollationNeeded() error {
4546
// CreateCollation defines a new collating sequence.
4647
//
4748
// https://sqlite.org/c3ref/create_collation.html
48-
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
49+
func (c *Conn) CreateCollation(name string, fn CollatingFunction) error {
4950
var funcPtr ptr_t
5051
defer c.arena.mark()()
5152
namePtr := c.arena.string(name)
@@ -57,6 +58,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
5758
return c.error(rc)
5859
}
5960

61+
// Collating function is the type of a collation callback.
62+
// Implementations must not retain a or b.
63+
type CollatingFunction func(a, b []byte) int
64+
6065
// CreateFunction defines a new scalar SQL function.
6166
//
6267
// https://sqlite.org/c3ref/create_function.html
@@ -77,34 +82,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala
7782
// Implementations must not retain arg.
7883
type ScalarFunction func(ctx Context, arg ...Value)
7984

85+
// CreateAggregateFunction defines a new aggregate SQL function.
86+
//
87+
// https://sqlite.org/c3ref/create_function.html
88+
func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error {
89+
var funcPtr ptr_t
90+
defer c.arena.mark()()
91+
namePtr := c.arena.string(name)
92+
if fn != nil {
93+
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
94+
var a aggregateFunc
95+
coro := func(yieldCoro func(struct{}) bool) {
96+
seq := func(yieldSeq func([]Value) bool) {
97+
for yieldSeq(a.arg) {
98+
if !yieldCoro(struct{}{}) {
99+
break
100+
}
101+
}
102+
}
103+
fn(&a.ctx, seq)
104+
}
105+
a.next, a.stop = iter.Pull(coro)
106+
return &a
107+
}))
108+
}
109+
rc := res_t(c.call("sqlite3_create_aggregate_function_go",
110+
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
111+
stk_t(flag), stk_t(funcPtr)))
112+
return c.error(rc)
113+
}
114+
115+
// AggregateSeqFunction is the type of an aggregate SQL function.
116+
// Implementations must not retain the slices produced by seq.
117+
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value])
118+
80119
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
81-
// If fn returns a [WindowFunction], then an aggregate window function is created.
120+
// If fn returns a [WindowFunction], an aggregate window function is created.
82121
// If fn returns an [io.Closer], it will be called to free resources.
83122
//
84123
// https://sqlite.org/c3ref/create_function.html
85-
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
124+
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error {
86125
var funcPtr ptr_t
87126
defer c.arena.mark()()
88127
namePtr := c.arena.string(name)
89-
call := "sqlite3_create_aggregate_function_go"
90128
if fn != nil {
91-
agg := fn()
92-
if c, ok := agg.(io.Closer); ok {
93-
if err := c.Close(); err != nil {
94-
return err
129+
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction {
130+
agg := fn()
131+
if win, ok := agg.(WindowFunction); ok {
132+
return win
95133
}
96-
}
97-
if _, ok := agg.(WindowFunction); ok {
98-
call = "sqlite3_create_window_function_go"
99-
}
100-
funcPtr = util.AddHandle(c.ctx, fn)
134+
return windowFunc{agg, name}
135+
}))
101136
}
102-
rc := res_t(c.call(call,
137+
rc := res_t(c.call("sqlite3_create_window_function_go",
103138
stk_t(c.handle), stk_t(namePtr), stk_t(nArg),
104139
stk_t(flag), stk_t(funcPtr)))
105140
return c.error(rc)
106141
}
107142

143+
// AggregateConstructor is a an [AggregateFunction] constructor.
144+
type AggregateConstructor func() AggregateFunction
145+
108146
// AggregateFunction is the interface an aggregate function should implement.
109147
//
110148
// https://sqlite.org/appfunc.html
@@ -153,7 +191,7 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe
153191
}
154192

155193
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 {
156-
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int)
194+
fn := util.GetHandle(ctx, pApp).(CollatingFunction)
157195
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2))))
158196
}
159197

@@ -211,7 +249,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
211249
}
212250

213251
// We need to create the aggregate.
214-
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
252+
fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)()
215253
if pAgg != 0 {
216254
handle := util.AddHandle(db.ctx, fn)
217255
util.Write32(db.mod, pAgg, handle)
@@ -232,6 +270,7 @@ func callbackArgs(db *Conn, arg []Value, pArg ptr_t) {
232270
var funcArgsPool sync.Pool
233271

234272
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) {
273+
clear(p[:])
235274
funcArgsPool.Put(p)
236275
}
237276

@@ -242,3 +281,38 @@ func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
242281
return p.(*[_MAX_FUNCTION_ARG]Value)
243282
}
244283
}
284+
285+
type aggregateFunc struct {
286+
ctx Context
287+
arg []Value
288+
next func() (struct{}, bool)
289+
stop func()
290+
}
291+
292+
func (a *aggregateFunc) Step(ctx Context, arg ...Value) {
293+
a.ctx = ctx
294+
a.arg = arg
295+
if _, more := a.next(); !more {
296+
a.stop()
297+
}
298+
}
299+
300+
func (a *aggregateFunc) Value(ctx Context) {
301+
a.ctx = ctx
302+
a.stop()
303+
}
304+
305+
func (a *aggregateFunc) Close() error {
306+
a.stop()
307+
return nil
308+
}
309+
310+
type windowFunc struct {
311+
AggregateFunction
312+
name string
313+
}
314+
315+
func (w windowFunc) Inverse(ctx Context, arg ...Value) {
316+
// Implementing inverse allows certain queries that don't really need it to succeed.
317+
ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function"))
318+
}

func_seq_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package sqlite3_test
2+
3+
import (
4+
"fmt"
5+
"iter"
6+
"log"
7+
8+
"github.com/ncruces/go-sqlite3"
9+
_ "github.com/ncruces/go-sqlite3/embed"
10+
)
11+
12+
func ExampleConn_CreateAggregateFunction() {
13+
db, err := sqlite3.Open(":memory:")
14+
if err != nil {
15+
log.Fatal(err)
16+
}
17+
defer db.Close()
18+
19+
err = db.Exec(`CREATE TABLE test (col)`)
20+
if err != nil {
21+
log.Fatal(err)
22+
}
23+
24+
err = db.Exec(`INSERT INTO test VALUES (1), (2), (3)`)
25+
if err != nil {
26+
log.Fatal(err)
27+
}
28+
29+
err = db.CreateAggregateFunction("seq_avg", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS,
30+
func(ctx *sqlite3.Context, seq iter.Seq[[]sqlite3.Value]) {
31+
count := 0
32+
total := 0.0
33+
for arg := range seq {
34+
total += arg[0].Float()
35+
count++
36+
}
37+
ctx.ResultFloat(total / float64(count))
38+
})
39+
if err != nil {
40+
log.Fatal(err)
41+
}
42+
43+
stmt, _, err := db.Prepare(`SELECT seq_avg(col) FROM test`)
44+
if err != nil {
45+
log.Fatal(err)
46+
}
47+
defer stmt.Close()
48+
49+
for stmt.Step() {
50+
fmt.Println(stmt.ColumnFloat(0))
51+
}
52+
if err := stmt.Err(); err != nil {
53+
log.Fatal(err)
54+
}
55+
// Output:
56+
// 2
57+
}

0 commit comments

Comments
 (0)