|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "database/sql" |
| 5 | + "fmt" |
| 6 | + "log" |
| 7 | + "math" |
| 8 | + "math/rand" |
| 9 | + |
| 10 | + sqlite "github.com/mattn/go-sqlite3" |
| 11 | +) |
| 12 | + |
| 13 | +// Computes x^y |
| 14 | +func pow(x, y int64) int64 { |
| 15 | + return int64(math.Pow(float64(x), float64(y))) |
| 16 | +} |
| 17 | + |
| 18 | +// Computes the bitwise exclusive-or of all its arguments |
| 19 | +func xor(xs ...int64) int64 { |
| 20 | + var ret int64 |
| 21 | + for _, x := range xs { |
| 22 | + ret ^= x |
| 23 | + } |
| 24 | + return ret |
| 25 | +} |
| 26 | + |
| 27 | +// Returns a random number. It's actually deterministic here because |
| 28 | +// we don't seed the RNG, but it's an example of a non-pure function |
| 29 | +// from SQLite's POV. |
| 30 | +func getrand() int64 { |
| 31 | + return rand.Int63() |
| 32 | +} |
| 33 | + |
| 34 | +// Computes the standard deviation of a GROUPed BY set of values |
| 35 | +type stddev struct { |
| 36 | + xs []int64 |
| 37 | + // Running average calculation |
| 38 | + sum int64 |
| 39 | + n int64 |
| 40 | +} |
| 41 | + |
| 42 | +func newStddev() *stddev { return &stddev{} } |
| 43 | + |
| 44 | +func (s *stddev) Step(x int64) { |
| 45 | + s.xs = append(s.xs, x) |
| 46 | + s.sum += x |
| 47 | + s.n++ |
| 48 | +} |
| 49 | + |
| 50 | +func (s *stddev) Done() float64 { |
| 51 | + mean := float64(s.sum) / float64(s.n) |
| 52 | + var sqDiff []float64 |
| 53 | + for _, x := range s.xs { |
| 54 | + sqDiff = append(sqDiff, math.Pow(float64(x)-mean, 2)) |
| 55 | + } |
| 56 | + var dev float64 |
| 57 | + for _, x := range sqDiff { |
| 58 | + dev += x |
| 59 | + } |
| 60 | + dev /= float64(len(sqDiff)) |
| 61 | + return math.Sqrt(dev) |
| 62 | +} |
| 63 | + |
| 64 | +func main() { |
| 65 | + sql.Register("sqlite3_custom", &sqlite.SQLiteDriver{ |
| 66 | + ConnectHook: func(conn *sqlite.SQLiteConn) error { |
| 67 | + if err := conn.RegisterFunc("pow", pow, true); err != nil { |
| 68 | + return err |
| 69 | + } |
| 70 | + if err := conn.RegisterFunc("xor", xor, true); err != nil { |
| 71 | + return err |
| 72 | + } |
| 73 | + if err := conn.RegisterFunc("rand", getrand, false); err != nil { |
| 74 | + return err |
| 75 | + } |
| 76 | + if err := conn.RegisterAggregator("stddev", newStddev, true); err != nil { |
| 77 | + return err |
| 78 | + } |
| 79 | + return nil |
| 80 | + }, |
| 81 | + }) |
| 82 | + |
| 83 | + db, err := sql.Open("sqlite3_custom", ":memory:") |
| 84 | + if err != nil { |
| 85 | + log.Fatal("Failed to open database:", err) |
| 86 | + } |
| 87 | + defer db.Close() |
| 88 | + |
| 89 | + var i int64 |
| 90 | + err = db.QueryRow("SELECT pow(2,3)").Scan(&i) |
| 91 | + if err != nil { |
| 92 | + log.Fatal("POW query error:", err) |
| 93 | + } |
| 94 | + fmt.Println("pow(2,3) =", i) // 8 |
| 95 | + |
| 96 | + err = db.QueryRow("SELECT xor(1,2,3,4,5,6)").Scan(&i) |
| 97 | + if err != nil { |
| 98 | + log.Fatal("XOR query error:", err) |
| 99 | + } |
| 100 | + fmt.Println("xor(1,2,3,4,5) =", i) // 7 |
| 101 | + |
| 102 | + err = db.QueryRow("SELECT rand()").Scan(&i) |
| 103 | + if err != nil { |
| 104 | + log.Fatal("RAND query error:", err) |
| 105 | + } |
| 106 | + fmt.Println("rand() =", i) // pseudorandom |
| 107 | + |
| 108 | + _, err = db.Exec("create table foo (department integer, profits integer)") |
| 109 | + if err != nil { |
| 110 | + log.Fatal("Failed to create table:", err) |
| 111 | + } |
| 112 | + _, err = db.Exec("insert into foo values (1, 10), (1, 20), (1, 45), (2, 42), (2, 115)") |
| 113 | + if err != nil { |
| 114 | + log.Fatal("Failed to insert records:", err) |
| 115 | + } |
| 116 | + |
| 117 | + rows, err := db.Query("select department, stddev(profits) from foo group by department") |
| 118 | + if err != nil { |
| 119 | + log.Fatal("STDDEV query error:", err) |
| 120 | + } |
| 121 | + defer rows.Close() |
| 122 | + for rows.Next() { |
| 123 | + var dept int64 |
| 124 | + var dev float64 |
| 125 | + if err := rows.Scan(&dept, &dev); err != nil { |
| 126 | + log.Fatal(err) |
| 127 | + } |
| 128 | + fmt.Printf("dept=%d stddev=%f\n", dept, dev) |
| 129 | + } |
| 130 | + if err := rows.Err(); err != nil { |
| 131 | + log.Fatal(err) |
| 132 | + } |
| 133 | +} |
0 commit comments