Skip to content

Commit 26917df

Browse files
committed
Implement support for aggregation functions implemented in Go.
1 parent b037a61 commit 26917df

File tree

5 files changed

+449
-33
lines changed

5 files changed

+449
-33
lines changed
6.3 MB
Binary file not shown.

_example/go_custom_funcs/main.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package main
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
"log"
7+
"math"
8+
"math/rand"
9+
10+
sqlite "github.com/mattn/go-sqlite3"
11+
)
12+
13+
// Computes x^y
14+
func pow(x, y int64) int64 {
15+
return int64(math.Pow(float64(x), float64(y)))
16+
}
17+
18+
// Computes the bitwise exclusive-or of all its arguments
19+
func xor(xs ...int64) int64 {
20+
var ret int64
21+
for _, x := range xs {
22+
ret ^= x
23+
}
24+
return ret
25+
}
26+
27+
// Returns a random number. It's actually deterministic here because
28+
// we don't seed the RNG, but it's an example of a non-pure function
29+
// from SQLite's POV.
30+
func getrand() int64 {
31+
return rand.Int63()
32+
}
33+
34+
// Computes the standard deviation of a GROUPed BY set of values
35+
type stddev struct {
36+
xs []int64
37+
// Running average calculation
38+
sum int64
39+
n int64
40+
}
41+
42+
func newStddev() *stddev { return &stddev{} }
43+
44+
func (s *stddev) Step(x int64) {
45+
s.xs = append(s.xs, x)
46+
s.sum += x
47+
s.n++
48+
}
49+
50+
func (s *stddev) Done() float64 {
51+
mean := float64(s.sum) / float64(s.n)
52+
var sqDiff []float64
53+
for _, x := range s.xs {
54+
sqDiff = append(sqDiff, math.Pow(float64(x)-mean, 2))
55+
}
56+
var dev float64
57+
for _, x := range sqDiff {
58+
dev += x
59+
}
60+
dev /= float64(len(sqDiff))
61+
return math.Sqrt(dev)
62+
}
63+
64+
func main() {
65+
sql.Register("sqlite3_custom", &sqlite.SQLiteDriver{
66+
ConnectHook: func(conn *sqlite.SQLiteConn) error {
67+
if err := conn.RegisterFunc("pow", pow, true); err != nil {
68+
return err
69+
}
70+
if err := conn.RegisterFunc("xor", xor, true); err != nil {
71+
return err
72+
}
73+
if err := conn.RegisterFunc("rand", getrand, false); err != nil {
74+
return err
75+
}
76+
if err := conn.RegisterAggregator("stddev", newStddev, true); err != nil {
77+
return err
78+
}
79+
return nil
80+
},
81+
})
82+
83+
db, err := sql.Open("sqlite3_custom", ":memory:")
84+
if err != nil {
85+
log.Fatal("Failed to open database:", err)
86+
}
87+
defer db.Close()
88+
89+
var i int64
90+
err = db.QueryRow("SELECT pow(2,3)").Scan(&i)
91+
if err != nil {
92+
log.Fatal("POW query error:", err)
93+
}
94+
fmt.Println("pow(2,3) =", i) // 8
95+
96+
err = db.QueryRow("SELECT xor(1,2,3,4,5,6)").Scan(&i)
97+
if err != nil {
98+
log.Fatal("XOR query error:", err)
99+
}
100+
fmt.Println("xor(1,2,3,4,5) =", i) // 7
101+
102+
err = db.QueryRow("SELECT rand()").Scan(&i)
103+
if err != nil {
104+
log.Fatal("RAND query error:", err)
105+
}
106+
fmt.Println("rand() =", i) // pseudorandom
107+
108+
_, err = db.Exec("create table foo (department integer, profits integer)")
109+
if err != nil {
110+
log.Fatal("Failed to create table:", err)
111+
}
112+
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115)")
113+
if err != nil {
114+
log.Fatal("Failed to insert records:", err)
115+
}
116+
117+
rows, err := db.Query("select department, stddev(profits) from foo group by department")
118+
if err != nil {
119+
log.Fatal("STDDEV query error:", err)
120+
}
121+
defer rows.Close()
122+
for rows.Next() {
123+
var dept int64
124+
var dev float64
125+
if err := rows.Scan(&dept, &dev); err != nil {
126+
log.Fatal(err)
127+
}
128+
fmt.Printf("dept=%d stddev=%f\n", dept, dev)
129+
}
130+
if err := rows.Err(); err != nil {
131+
log.Fatal(err)
132+
}
133+
}

callback.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ package sqlite3
1212

1313
/*
1414
#include <sqlite3-binding.h>
15+
#include <stdlib.h>
1516
1617
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
1718
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
@@ -32,6 +33,19 @@ func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value
3233
fi.Call(ctx, args)
3334
}
3435

36+
//export stepTrampoline
37+
func stepTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
38+
args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
39+
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
40+
ai.Step(ctx, args)
41+
}
42+
43+
//export doneTrampoline
44+
func doneTrampoline(ctx *C.sqlite3_context) {
45+
ai := (*aggInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
46+
ai.Done(ctx)
47+
}
48+
3549
// This is only here so that tests can refer to it.
3650
type callbackArgRaw C.sqlite3_value
3751

@@ -158,6 +172,33 @@ func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
158172
}
159173
}
160174

175+
func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
176+
var args []reflect.Value
177+
178+
if len(argv) < len(converters) {
179+
return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
180+
}
181+
182+
for i, arg := range argv[:len(converters)] {
183+
v, err := converters[i](arg)
184+
if err != nil {
185+
return nil, err
186+
}
187+
args = append(args, v)
188+
}
189+
190+
if variadic != nil {
191+
for _, arg := range argv[len(converters):] {
192+
v, err := variadic(arg)
193+
if err != nil {
194+
return nil, err
195+
}
196+
args = append(args, v)
197+
}
198+
}
199+
return args, nil
200+
}
201+
161202
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
162203

163204
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
@@ -233,6 +274,12 @@ func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
233274
}
234275
}
235276

277+
func callbackError(ctx *C.sqlite3_context, err error) {
278+
cstr := C.CString(err.Error())
279+
defer C.free(unsafe.Pointer(cstr))
280+
C.sqlite3_result_error(ctx, cstr, -1)
281+
}
282+
236283
// Test support code. Tests are not allowed to import "C", so we can't
237284
// declare any functions that use C.sqlite3_value.
238285
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {

0 commit comments

Comments
 (0)