@@ -3,6 +3,7 @@ package sqlite3
3
3
import (
4
4
"context"
5
5
"io"
6
+ "iter"
6
7
"sync"
7
8
8
9
"github.com/tetratelabs/wazero/api"
@@ -45,7 +46,7 @@ func (c Conn) AnyCollationNeeded() error {
45
46
// CreateCollation defines a new collating sequence.
46
47
//
47
48
// https://sqlite.org/c3ref/create_collation.html
48
- func (c * Conn ) CreateCollation (name string , fn func ( a , b [] byte ) int ) error {
49
+ func (c * Conn ) CreateCollation (name string , fn CollatingFunction ) error {
49
50
var funcPtr ptr_t
50
51
defer c .arena .mark ()()
51
52
namePtr := c .arena .string (name )
@@ -57,6 +58,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
57
58
return c .error (rc )
58
59
}
59
60
61
+ // Collating function is the type of a collation callback.
62
+ // Implementations must not retain a or b.
63
+ type CollatingFunction func (a , b []byte ) int
64
+
60
65
// CreateFunction defines a new scalar SQL function.
61
66
//
62
67
// https://sqlite.org/c3ref/create_function.html
@@ -77,34 +82,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala
77
82
// Implementations must not retain arg.
78
83
type ScalarFunction func (ctx Context , arg ... Value )
79
84
85
+ // CreateAggregateFunction defines a new aggregate SQL function.
86
+ //
87
+ // https://sqlite.org/c3ref/create_function.html
88
+ func (c * Conn ) CreateAggregateFunction (name string , nArg int , flag FunctionFlag , fn AggregateSeqFunction ) error {
89
+ var funcPtr ptr_t
90
+ defer c .arena .mark ()()
91
+ namePtr := c .arena .string (name )
92
+ if fn != nil {
93
+ funcPtr = util .AddHandle (c .ctx , AggregateConstructor (func () AggregateFunction {
94
+ var a aggregateFunc
95
+ coro := func (yieldCoro func (struct {}) bool ) {
96
+ seq := func (yieldSeq func ([]Value ) bool ) {
97
+ for yieldSeq (a .arg ) {
98
+ if ! yieldCoro (struct {}{}) {
99
+ break
100
+ }
101
+ }
102
+ }
103
+ fn (& a .ctx , seq )
104
+ }
105
+ a .next , a .stop = iter .Pull (coro )
106
+ return & a
107
+ }))
108
+ }
109
+ rc := res_t (c .call ("sqlite3_create_aggregate_function_go" ,
110
+ stk_t (c .handle ), stk_t (namePtr ), stk_t (nArg ),
111
+ stk_t (flag ), stk_t (funcPtr )))
112
+ return c .error (rc )
113
+ }
114
+
115
+ // AggregateSeqFunction is the type of an aggregate SQL function.
116
+ // Implementations must not retain the slices produced by seq.
117
+ type AggregateSeqFunction func (ctx * Context , seq iter.Seq [[]Value ])
118
+
80
119
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
81
- // If fn returns a [WindowFunction], then an aggregate window function is created.
120
+ // If fn returns a [WindowFunction], an aggregate window function is created.
82
121
// If fn returns an [io.Closer], it will be called to free resources.
83
122
//
84
123
// https://sqlite.org/c3ref/create_function.html
85
- func (c * Conn ) CreateWindowFunction (name string , nArg int , flag FunctionFlag , fn func () AggregateFunction ) error {
124
+ func (c * Conn ) CreateWindowFunction (name string , nArg int , flag FunctionFlag , fn AggregateConstructor ) error {
86
125
var funcPtr ptr_t
87
126
defer c .arena .mark ()()
88
127
namePtr := c .arena .string (name )
89
- call := "sqlite3_create_aggregate_function_go"
90
128
if fn != nil {
91
- agg := fn ()
92
- if c , ok := agg .(io. Closer ); ok {
93
- if err := c . Close ( ); err != nil {
94
- return err
129
+ funcPtr = util . AddHandle ( c . ctx , AggregateConstructor ( func () AggregateFunction {
130
+ agg := fn ()
131
+ if win , ok := agg .( WindowFunction ); ok {
132
+ return win
95
133
}
96
- }
97
- if _ , ok := agg .(WindowFunction ); ok {
98
- call = "sqlite3_create_window_function_go"
99
- }
100
- funcPtr = util .AddHandle (c .ctx , fn )
134
+ return windowFunc {agg , name }
135
+ }))
101
136
}
102
- rc := res_t (c .call (call ,
137
+ rc := res_t (c .call ("sqlite3_create_window_function_go" ,
103
138
stk_t (c .handle ), stk_t (namePtr ), stk_t (nArg ),
104
139
stk_t (flag ), stk_t (funcPtr )))
105
140
return c .error (rc )
106
141
}
107
142
143
+ // AggregateConstructor is a an [AggregateFunction] constructor.
144
+ type AggregateConstructor func () AggregateFunction
145
+
108
146
// AggregateFunction is the interface an aggregate function should implement.
109
147
//
110
148
// https://sqlite.org/appfunc.html
@@ -153,7 +191,7 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe
153
191
}
154
192
155
193
func compareCallback (ctx context.Context , mod api.Module , pApp ptr_t , nKey1 int32 , pKey1 ptr_t , nKey2 int32 , pKey2 ptr_t ) uint32 {
156
- fn := util .GetHandle (ctx , pApp ).(func ( a , b [] byte ) int )
194
+ fn := util .GetHandle (ctx , pApp ).(CollatingFunction )
157
195
return uint32 (fn (util .View (mod , pKey1 , int64 (nKey1 )), util .View (mod , pKey2 , int64 (nKey2 ))))
158
196
}
159
197
@@ -211,7 +249,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
211
249
}
212
250
213
251
// We need to create the aggregate.
214
- fn := util .GetHandle (db .ctx , pApp ).(func () AggregateFunction )()
252
+ fn := util .GetHandle (db .ctx , pApp ).(AggregateConstructor )()
215
253
if pAgg != 0 {
216
254
handle := util .AddHandle (db .ctx , fn )
217
255
util .Write32 (db .mod , pAgg , handle )
@@ -232,6 +270,7 @@ func callbackArgs(db *Conn, arg []Value, pArg ptr_t) {
232
270
var funcArgsPool sync.Pool
233
271
234
272
func putFuncArgs (p * [_MAX_FUNCTION_ARG ]Value ) {
273
+ clear (p [:])
235
274
funcArgsPool .Put (p )
236
275
}
237
276
@@ -242,3 +281,38 @@ func getFuncArgs() *[_MAX_FUNCTION_ARG]Value {
242
281
return p .(* [_MAX_FUNCTION_ARG ]Value )
243
282
}
244
283
}
284
+
285
+ type aggregateFunc struct {
286
+ ctx Context
287
+ arg []Value
288
+ next func () (struct {}, bool )
289
+ stop func ()
290
+ }
291
+
292
+ func (a * aggregateFunc ) Step (ctx Context , arg ... Value ) {
293
+ a .ctx = ctx
294
+ a .arg = arg
295
+ if _ , more := a .next (); ! more {
296
+ a .stop ()
297
+ }
298
+ }
299
+
300
+ func (a * aggregateFunc ) Value (ctx Context ) {
301
+ a .ctx = ctx
302
+ a .stop ()
303
+ }
304
+
305
+ func (a * aggregateFunc ) Close () error {
306
+ a .stop ()
307
+ return nil
308
+ }
309
+
310
+ type windowFunc struct {
311
+ AggregateFunction
312
+ name string
313
+ }
314
+
315
+ func (w windowFunc ) Inverse (ctx Context , arg ... Value ) {
316
+ // Implementing inverse allows certain queries that don't really need it to succeed.
317
+ ctx .ResultError (util .ErrorString (w .name + ": may not be used as a window function" ))
318
+ }
0 commit comments