Skip to content

Commit d112c4d

Browse files
committed
feat(sqlite): support expressions and multiple no-data statements in the macros
1 parent 0def87b commit d112c4d

File tree

8 files changed

+390
-38
lines changed

8 files changed

+390
-38
lines changed

sqlx-core/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pub mod types;
4646
#[macro_use]
4747
pub mod query;
4848

49+
mod column;
4950
mod common;
5051
pub mod database;
5152
pub mod describe;
+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
use crate::describe::{Column, Describe};
2+
use crate::error::Error;
3+
use crate::sqlite::connection::explain::explain;
4+
use crate::sqlite::statement::SqliteStatement;
5+
use crate::sqlite::type_info::DataType;
6+
use crate::sqlite::{Sqlite, SqliteConnection, SqliteTypeInfo};
7+
use futures_core::future::BoxFuture;
8+
9+
pub(super) async fn describe(
10+
conn: &mut SqliteConnection,
11+
query: &str,
12+
) -> Result<Describe<Sqlite>, Error> {
13+
describe_with(conn, query, vec![]).await
14+
}
15+
16+
pub(super) fn describe_with<'c: 'e, 'q: 'e, 'e>(
17+
conn: &'c mut SqliteConnection,
18+
query: &'q str,
19+
fallback: Vec<SqliteTypeInfo>,
20+
) -> BoxFuture<'e, Result<Describe<Sqlite>, Error>> {
21+
Box::pin(async move {
22+
// describing a statement from SQLite can be involved
23+
// each SQLx statement is comprised of multiple SQL statements
24+
25+
let SqliteConnection {
26+
ref mut handle,
27+
ref worker,
28+
..
29+
} = conn;
30+
31+
let statement = SqliteStatement::prepare(handle, query, false);
32+
33+
let mut columns = Vec::new();
34+
let mut num_params = 0;
35+
36+
let mut statement = statement?;
37+
38+
// we start by finding the first statement that *can* return results
39+
while let Some((statement, _)) = statement.execute()? {
40+
num_params += statement.bind_parameter_count();
41+
42+
let mut stepped = false;
43+
44+
let num = statement.column_count();
45+
if num == 0 {
46+
// no columns in this statement; skip
47+
continue;
48+
}
49+
50+
// next we try to use [column_decltype] to inspect the type of each column
51+
columns.reserve(num);
52+
53+
for col in 0..num {
54+
let name = statement.column_name(col).to_owned();
55+
56+
let type_info = if let Some(ty) = statement.column_decltype(col) {
57+
ty
58+
} else {
59+
// if that fails, we back up and attempt to step the statement
60+
// once *if* its read-only and then use [column_type] as a
61+
// fallback to [column_decltype]
62+
if !stepped && statement.read_only() && fallback.is_empty() {
63+
stepped = true;
64+
65+
worker.execute(statement);
66+
worker.wake();
67+
68+
let _ = worker.step(statement).await?;
69+
}
70+
71+
let mut ty = statement.column_type_info(col);
72+
73+
if ty.0 == DataType::Null {
74+
if fallback.is_empty() {
75+
// this will _still_ fail if there are no actual rows to return
76+
// this happens more often than not for the macros as we tell
77+
// users to execute against an empty database
78+
79+
// as a last resort, we explain the original query and attempt to
80+
// infer what would the expression types be as a fallback
81+
// to [column_decltype]
82+
83+
let fallback = explain(conn, statement.sql()).await?;
84+
85+
return describe_with(conn, query, fallback).await;
86+
}
87+
88+
if let Some(fallback) = fallback.get(col).cloned() {
89+
ty = fallback;
90+
}
91+
}
92+
93+
ty
94+
};
95+
96+
let not_null = statement.column_not_null(col)?;
97+
98+
columns.push(Column {
99+
name,
100+
type_info: Some(type_info),
101+
not_null,
102+
});
103+
}
104+
}
105+
106+
// println!("describe ->> {:#?}", columns);
107+
108+
Ok(Describe {
109+
columns,
110+
params: vec![None; num_params],
111+
})
112+
})
113+
}

sqlx-core/src/sqlite/connection/executor.rs

+4-31
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ use std::sync::Arc;
33
use either::Either;
44
use futures_core::future::BoxFuture;
55
use futures_core::stream::BoxStream;
6-
use futures_util::TryStreamExt;
6+
use futures_util::{FutureExt, TryStreamExt};
77
use hashbrown::HashMap;
88

99
use crate::common::StatementCache;
10-
use crate::describe::{Column, Describe};
10+
use crate::describe::Describe;
1111
use crate::error::Error;
1212
use crate::executor::{Execute, Executor};
1313
use crate::ext::ustr::UStr;
14+
use crate::sqlite::connection::describe::describe;
1415
use crate::sqlite::connection::ConnectionHandle;
1516
use crate::sqlite::statement::{SqliteStatement, StatementHandle};
1617
use crate::sqlite::{Sqlite, SqliteArguments, SqliteConnection, SqliteRow};
@@ -176,34 +177,6 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
176177
'c: 'e,
177178
E: Execute<'q, Self::Database>,
178179
{
179-
let query = query.query();
180-
let statement = SqliteStatement::prepare(&mut self.handle, query, false);
181-
182-
Box::pin(async move {
183-
let mut params = Vec::new();
184-
let mut columns = Vec::new();
185-
186-
if let Some(statement) = statement?.handles.get(0) {
187-
// NOTE: we can infer *nothing* about parameters apart from the count
188-
params.resize(statement.bind_parameter_count(), None);
189-
190-
let num_columns = statement.column_count();
191-
columns.reserve(num_columns);
192-
193-
for i in 0..num_columns {
194-
let name = statement.column_name(i).to_owned();
195-
let type_info = statement.column_decltype(i);
196-
let not_null = statement.column_not_null(i)?;
197-
198-
columns.push(Column {
199-
name,
200-
type_info,
201-
not_null,
202-
})
203-
}
204-
}
205-
206-
Ok(Describe { params, columns })
207-
})
180+
describe(self, query.query()).boxed()
208181
}
209182
}
+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
use crate::error::Error;
2+
use crate::query_as::query_as;
3+
use crate::sqlite::type_info::DataType;
4+
use crate::sqlite::{SqliteConnection, SqliteTypeInfo};
5+
use hashbrown::HashMap;
6+
7+
const OP_INIT: &str = "Init";
8+
const OP_GOTO: &str = "Goto";
9+
const OP_COLUMN: &str = "Column";
10+
const OP_AGG_STEP: &str = "AggStep";
11+
const OP_MOVE: &str = "Move";
12+
const OP_COPY: &str = "Copy";
13+
const OP_SCOPY: &str = "SCopy";
14+
const OP_INT_COPY: &str = "IntCopy";
15+
const OP_STRING8: &str = "String8";
16+
const OP_INT64: &str = "Int64";
17+
const OP_INTEGER: &str = "Integer";
18+
const OP_REAL: &str = "Real";
19+
const OP_NOT: &str = "Not";
20+
const OP_BLOB: &str = "Blob";
21+
const OP_COUNT: &str = "Count";
22+
const OP_ROWID: &str = "Rowid";
23+
const OP_OR: &str = "Or";
24+
const OP_AND: &str = "And";
25+
const OP_BIT_AND: &str = "BitAnd";
26+
const OP_BIT_OR: &str = "BitOr";
27+
const OP_SHIFT_LEFT: &str = "ShiftLeft";
28+
const OP_SHIFT_RIGHT: &str = "ShiftRight";
29+
const OP_ADD: &str = "Add";
30+
const OP_SUBTRACT: &str = "Subtract";
31+
const OP_MULTIPLY: &str = "Multiply";
32+
const OP_DIVIDE: &str = "Divide";
33+
const OP_REMAINDER: &str = "Remainder";
34+
const OP_CONCAT: &str = "Concat";
35+
const OP_RESULT_ROW: &str = "ResultRow";
36+
37+
fn to_type(op: &str) -> DataType {
38+
match op {
39+
OP_REAL => DataType::Float,
40+
OP_BLOB => DataType::Blob,
41+
OP_AND | OP_OR => DataType::Bool,
42+
OP_ROWID | OP_COUNT | OP_INT64 | OP_INTEGER => DataType::Int64,
43+
OP_STRING8 => DataType::Text,
44+
OP_COLUMN | _ => DataType::Null,
45+
}
46+
}
47+
48+
pub(super) async fn explain(
49+
conn: &mut SqliteConnection,
50+
query: &str,
51+
) -> Result<Vec<SqliteTypeInfo>, Error> {
52+
let mut r = HashMap::<i64, DataType>::with_capacity(6);
53+
54+
let program =
55+
query_as::<_, (i64, String, i64, i64, i64, String)>(&*format!("EXPLAIN {}", query))
56+
.fetch_all(&mut *conn)
57+
.await?;
58+
59+
let mut program_i = 0;
60+
let program_size = program.len();
61+
62+
while program_i < program_size {
63+
let (_, ref opcode, p1, p2, p3, ref p4) = program[program_i];
64+
65+
match &**opcode {
66+
OP_INIT => {
67+
// start at <p2>
68+
program_i = p2 as usize;
69+
continue;
70+
}
71+
72+
OP_GOTO => {
73+
// goto <p2>
74+
program_i = p2 as usize;
75+
continue;
76+
}
77+
78+
OP_COLUMN => {
79+
// r[p3] = <value of column>
80+
r.insert(p3, DataType::Null);
81+
}
82+
83+
OP_AGG_STEP => {
84+
if p4.starts_with("count(") {
85+
// count(_) -> INTEGER
86+
r.insert(p3, DataType::Int64);
87+
} else if let Some(v) = r.get(&p2).copied() {
88+
// r[p3] = AGG ( r[p2] )
89+
r.insert(p3, v);
90+
}
91+
}
92+
93+
OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => {
94+
// r[p2] = r[p1]
95+
if let Some(v) = r.get(&p1).copied() {
96+
r.insert(p2, v);
97+
}
98+
}
99+
100+
OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => {
101+
// r[p2] = <value of constant>
102+
r.insert(p2, to_type(&opcode));
103+
}
104+
105+
OP_NOT => {
106+
// r[p2] = NOT r[p1]
107+
if let Some(a) = r.get(&p1).copied() {
108+
r.insert(p2, a);
109+
}
110+
}
111+
112+
OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT
113+
| OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => {
114+
// r[p3] = r[p1] + r[p2]
115+
match (r.get(&p1).copied(), r.get(&p2).copied()) {
116+
(Some(a), Some(b)) => {
117+
r.insert(p3, if matches!(a, DataType::Null) { b } else { a });
118+
}
119+
120+
(Some(v), None) => {
121+
r.insert(p3, v);
122+
}
123+
124+
(None, Some(v)) => {
125+
r.insert(p3, v);
126+
}
127+
128+
_ => {}
129+
}
130+
}
131+
132+
OP_RESULT_ROW => {
133+
// output = r[p1 .. p1 + p2]
134+
let mut output = Vec::with_capacity(p2 as usize);
135+
for i in p1..p1 + p2 {
136+
output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null)));
137+
}
138+
139+
return Ok(output);
140+
}
141+
142+
_ => {
143+
// ignore unsupported operations
144+
// if we fail to find an r later, we just give up
145+
}
146+
}
147+
148+
program_i += 1;
149+
}
150+
151+
// no rows
152+
Ok(vec![])
153+
}

sqlx-core/src/sqlite/connection/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ use crate::sqlite::connection::establish::establish;
1515
use crate::sqlite::statement::{SqliteStatement, StatementWorker};
1616
use crate::sqlite::{Sqlite, SqliteConnectOptions};
1717

18+
mod describe;
1819
mod establish;
1920
mod executor;
21+
mod explain;
2022
mod handle;
2123

2224
pub(crate) use handle::ConnectionHandle;

sqlx-core/src/sqlite/statement/handle.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use libsqlite3_sys::{
1111
sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype,
1212
sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name,
1313
sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type,
14-
sqlite3_column_value, sqlite3_db_handle, sqlite3_stmt, sqlite3_table_column_metadata,
15-
SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8,
14+
sqlite3_column_value, sqlite3_db_handle, sqlite3_sql, sqlite3_stmt, sqlite3_stmt_readonly,
15+
sqlite3_table_column_metadata, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8,
1616
};
1717

1818
use crate::error::{BoxDynError, Error};
@@ -38,6 +38,21 @@ impl StatementHandle {
3838
sqlite3_db_handle(self.0.as_ptr())
3939
}
4040

41+
pub(crate) fn read_only(&self) -> bool {
42+
// https://sqlite.org/c3ref/stmt_readonly.html
43+
unsafe { sqlite3_stmt_readonly(self.0.as_ptr()) != 0 }
44+
}
45+
46+
pub(crate) fn sql(&self) -> &str {
47+
// https://sqlite.org/c3ref/expanded_sql.html
48+
unsafe {
49+
let raw = sqlite3_sql(self.0.as_ptr());
50+
debug_assert!(!raw.is_null());
51+
52+
from_utf8_unchecked(CStr::from_ptr(raw).to_bytes())
53+
}
54+
}
55+
4156
#[inline]
4257
pub(crate) fn last_error(&self) -> SqliteError {
4358
SqliteError::new(unsafe { self.db_handle() })
@@ -68,6 +83,10 @@ impl StatementHandle {
6883
}
6984
}
7085

86+
pub(crate) fn column_type_info(&self, index: usize) -> SqliteTypeInfo {
87+
SqliteTypeInfo(DataType::from_code(self.column_type(index)))
88+
}
89+
7190
#[inline]
7291
pub(crate) fn column_decltype(&self, index: usize) -> Option<SqliteTypeInfo> {
7392
unsafe {

sqlx-core/src/sqlite/type_info.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use libsqlite3_sys::{SQLITE_BLOB, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, SQL
77
use crate::error::BoxDynError;
88
use crate::type_info::TypeInfo;
99

10-
#[derive(Debug, Clone, Eq, PartialEq)]
10+
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
1111
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
1212
pub(crate) enum DataType {
1313
Null,

0 commit comments

Comments
 (0)