Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Parameterized Queries #597

Merged
merged 9 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions bindings/nodejs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,20 @@ row.setOpts({ variantAsObject: true });
console.log(row.data());
```

Parameter binding

```javascript
const row = await this.conn.queryRow(
"SELECT $1, $2, $3, $4",
(params = [3, false, 4, "55"]),
);
const row = await this.conn.queryRow(
"SELECT :a, :b, :c, :d",
(params = { a: 3, b: false, c: 4, d: "55" }),
);
const row = await this.conn.queryRow("SELECT ?, ?, ?, ?", [3, false, 4, "55"]);
```

## Development

```shell
Expand Down
14 changes: 8 additions & 6 deletions bindings/nodejs/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,17 @@ export declare class Connection {
info(): Promise<ConnectionInfo>
/** Get the databend version. */
version(): Promise<string>
formatSql(sql: string, params?: Params | undefined | null): string
/** Execute a SQL query, return the number of affected rows. */
exec(sql: string): Promise<number>
exec(sql: string, params?: Params | undefined | null): Promise<number>
/** Execute a SQL query, and only return the first row. */
queryRow(sql: string): Promise<Row | null>
queryRow(sql: string, params?: Params | undefined | null): Promise<Row | null>
/** Execute a SQL query and fetch all data into the result */
queryAll(sql: string): Promise<Array<Row>>
queryAll(sql: string, params?: Params | undefined | null): Promise<Array<Row>>
/** Execute a SQL query, and return all rows. */
queryIter(sql: string): Promise<RowIterator>
queryIter(sql: string, params?: Params | undefined | null): Promise<RowIterator>
/** Execute a SQL query, and return all rows with schema and stats. */
queryIterExt(sql: string): Promise<RowIteratorExt>
queryIterExt(sql: string, params?: Params | undefined | null): Promise<RowIteratorExt>
/**
* Load data with stage attachment.
* The SQL can be `INSERT INTO tbl VALUES` or `REPLACE INTO tbl VALUES`.
Expand All @@ -52,7 +53,7 @@ export declare class Connection {
* Load file with stage attachment.
* The SQL can be `INSERT INTO tbl VALUES` or `REPLACE INTO tbl VALUES`.
*/
loadFile(sql: string, file: string, formatOptions: Record<string, string>, copyOptions?: Record<string, string> | undefined | null): Promise<ServerStats>
loadFile(sql: string, file: string, formatOptions?: Record<string, string> | undefined | null, copyOptions?: Record<string, string> | undefined | null): Promise<ServerStats>
}
export declare class ConnectionInfo {
get handler(): string
Expand Down Expand Up @@ -108,5 +109,6 @@ export declare class ServerStats {
get readBytes(): bigint
get writeRows(): bigint
get writeBytes(): bigint
get spillFileNums(): bigint
get runningTimeMs(): number
}
38 changes: 26 additions & 12 deletions bindings/nodejs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,14 @@ impl Client {

#[napi]
pub struct Connection {
inner: Box<dyn databend_driver::Connection>,
inner: databend_driver::Connection,
opts: ValueOptions,
}

pub type Params = serde_json::Value;

impl Connection {
pub fn new(inner: Box<dyn databend_driver::Connection>, opts: ValueOptions) -> Self {
pub fn new(inner: databend_driver::Connection, opts: ValueOptions) -> Self {
Self { inner, opts }
}
}
Expand All @@ -101,18 +103,26 @@ impl Connection {
self.inner.version().await.map_err(format_napi_error)
}

#[napi]
pub fn format_sql(&self, sql: String, params: Option<Params>) -> Result<String> {
Ok(self.inner.format_sql(&sql, params))
}

/// Execute a SQL query, return the number of affected rows.
#[napi]
pub async fn exec(&self, sql: String) -> Result<i64> {
self.inner.exec(&sql).await.map_err(format_napi_error)
pub async fn exec(&self, sql: String, params: Option<Params>) -> Result<i64> {
self.inner
.exec(&sql, params)
.await
.map_err(format_napi_error)
}

/// Execute a SQL query, and only return the first row.
#[napi]
pub async fn query_row(&self, sql: String) -> Result<Option<Row>> {
pub async fn query_row(&self, sql: String, params: Option<Params>) -> Result<Option<Row>> {
let ret = self
.inner
.query_row(&sql)
.query_row(&sql, params)
.await
.map_err(format_napi_error)?;
let row = ret.map(|r| Row::new(r, self.opts.clone()));
Expand All @@ -121,10 +131,10 @@ impl Connection {

/// Execute a SQL query and fetch all data into the result
#[napi]
pub async fn query_all(&self, sql: String) -> Result<Vec<Row>> {
pub async fn query_all(&self, sql: String, params: Option<Params>) -> Result<Vec<Row>> {
Ok(self
.inner
.query_all(&sql)
.query_all(&sql, params)
.await
.map_err(format_napi_error)?
.into_iter()
Expand All @@ -134,21 +144,25 @@ impl Connection {

/// Execute a SQL query, and return all rows.
#[napi]
pub async fn query_iter(&self, sql: String) -> Result<RowIterator> {
pub async fn query_iter(&self, sql: String, params: Option<Params>) -> Result<RowIterator> {
let iterator = self
.inner
.query_iter(&sql)
.query_iter(&sql, params)
.await
.map_err(format_napi_error)?;
Ok(RowIterator::new(iterator, self.opts.clone()))
}

/// Execute a SQL query, and return all rows with schema and stats.
#[napi]
pub async fn query_iter_ext(&self, sql: String) -> Result<RowIteratorExt> {
pub async fn query_iter_ext(
&self,
sql: String,
params: Option<Params>,
) -> Result<RowIteratorExt> {
let iterator = self
.inner
.query_iter_ext(&sql)
.query_iter_ext(&sql, params)
.await
.map_err(format_napi_error)?;
Ok(RowIteratorExt::new(iterator, self.opts.clone()))
Expand Down
33 changes: 32 additions & 1 deletion bindings/nodejs/tests/binding.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ Then("Select string {string} should be equal to {string}", async function (input
assert.equal(output, value);
});

Then("Select params binding", async function () {
{
const row = await this.conn.queryRow("SELECT $1, $2, $3, $4", (params = [3, false, 4, "55"]));
assert.deepEqual(row.values(), [3, false, 4, "55"]);
}

{
const row = await this.conn.queryRow("SELECT :a, :b, :c, :d", (params = { a: 3, b: false, c: 4, d: "55" }));
assert.deepEqual(row.values(), [3, false, 4, "55"]);
}

{
const row = await this.conn.queryRow("SELECT ?, ?, ?, ?", [3, false, 4, "55"]);
assert.deepEqual(row.values(), [3, false, 4, "55"]);
}
});

Then("Select types should be expected native types", async function () {
// BOOLEAN
{
Expand Down Expand Up @@ -74,7 +91,7 @@ Then("Select types should be expected native types", async function () {

// FLOAT
{
const row = await this.conn.queryRow("SELECT 1.11::FLOAT, 2.22::FLOAT");
const row = await this.conn.queryRow("SELECT ?::FLOAT, ?::FLOAT", (params = [1.11, 2.22]));
assert.deepEqual(
row.values().map((v) => v.toFixed(2)),
[1.11, 2.22],
Expand Down Expand Up @@ -166,6 +183,20 @@ Then("Select types should be expected native types", async function () {
],
});
}

// Variant as param
{
const value = [3, "15", { aa: 3 }];
const row = await this.conn.queryRow(`SELECT ?, ?, ?`, (params = value));
row.setOpts({ variantAsObject: true });
assert.deepEqual(row.values(), [
3,
"15",
{
aa: 3,
},
]);
}
});

Then("Select numbers should iterate all rows", async function () {
Expand Down
28 changes: 21 additions & 7 deletions bindings/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ async def main():
asyncio.run(main())
```

### Parameter bindings

```python
# Test with positional parameters
row = await context.conn.query_row("SELECT ?, ?, ?, ?", (3, False, 4, "55"))
row = await context.conn.query_row(
"SELECT :a, :b, :c, :d", {"a": 3, "b": False, "c": 4, "d": "55"}
)
row = await context.conn.query_row(
"SELECT ?", 3
)
row = await context.conn.query_row("SELECT ?, ?, ?, ?", params = (3, False, 4, "55"))
```

## Type Mapping

[Databend Types](https://docs.databend.com/sql/sql-reference/data-types/)
Expand Down Expand Up @@ -156,9 +170,9 @@ class AsyncDatabendClient:
class AsyncDatabendConnection:
async def info(self) -> ConnectionInfo: ...
async def version(self) -> str: ...
async def exec(self, sql: str) -> int: ...
async def query_row(self, sql: str) -> Row: ...
async def query_iter(self, sql: str) -> RowIterator: ...
async def exec(self, sql: str, params: list[string] | tuple[string] | any = None) -> int: ...
async def query_row(self, sql: str, params: list[string] | tuple[string] | any = None) -> Row: ...
async def query_iter(self, sql: str, params: list[string] | tuple[string] | any = None) -> RowIterator: ...
async def stream_load(self, sql: str, data: list[list[str]]) -> ServerStats: ...
async def load_file(self, sql: str, file: str, format_option: dict, copy_options: dict = None) -> ServerStats: ...
```
Expand All @@ -178,9 +192,9 @@ class BlockingDatabendClient:
class BlockingDatabendConnection:
def info(self) -> ConnectionInfo: ...
def version(self) -> str: ...
def exec(self, sql: str) -> int: ...
def query_row(self, sql: str) -> Row: ...
def query_iter(self, sql: str) -> RowIterator: ...
def exec(self, sql: str, params: list[string] | tuple[string] | any = None) -> int: ...
def query_row(self, sql: str, params: list[string] | tuple[string] | any = None) -> Row: ...
def query_iter(self, sql: str, params: list[string] | tuple[string] | any = None) -> RowIterator: ...
def stream_load(self, sql: str, data: list[list[str]]) -> ServerStats: ...
def load_file(self, sql: str, file: str, format_option: dict, copy_options: dict = None) -> ServerStats: ...
```
Expand All @@ -195,7 +209,7 @@ class BlockingDatabendCursor:
def rowcount(self) -> int: ...
def close(self) -> None: ...
def execute(self, operation: str, params: list[string] | tuple[string] = None) -> None | int: ...
def executemany(self, operation: str, params: list[list[string] | tuple[string]]) -> None | int: ...
def executemany(self, operation: str, params: list[string] | tuple[string] = None, values: list[list[string] | tuple[string]]) -> None | int: ...
def fetchone(self) -> Row | None: ...
def fetchmany(self, size: int = 1) -> list[Row]: ...
def fetchall(self) -> list[Row]: ...
Expand Down
70 changes: 60 additions & 10 deletions bindings/python/src/asyncio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ use std::sync::Arc;
use pyo3::prelude::*;
use pyo3_async_runtimes::tokio::future_into_py;

use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION};
use crate::{
types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION},
utils::to_sql_params,
};

#[pyclass(module = "databend_driver")]
pub struct AsyncDatabendClient(databend_driver::Client);
Expand All @@ -44,7 +47,7 @@ impl AsyncDatabendClient {
}

#[pyclass(module = "databend_driver")]
pub struct AsyncDatabendConnection(Arc<Box<dyn databend_driver::Connection>>);
pub struct AsyncDatabendConnection(Arc<databend_driver::Connection>);

#[pymethods]
impl AsyncDatabendConnection {
Expand All @@ -64,27 +67,63 @@ impl AsyncDatabendConnection {
})
}

pub fn exec<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<Bound<'p, PyAny>> {
#[pyo3(signature = (sql, params=None))]
pub fn format_sql(
&self,
_py: Python,
sql: String,
params: Option<Bound<PyAny>>,
) -> PyResult<String> {
let this = self.0.clone();
let params = to_sql_params(params);
Ok(this.format_sql(&sql, params))
}

#[pyo3(signature = (sql, params=None))]
pub fn exec<'p>(
&'p self,
py: Python<'p>,
sql: String,
params: Option<Bound<'p, PyAny>>,
) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
let params = to_sql_params(params);
future_into_py(py, async move {
let res = this.exec(&sql).await.map_err(DriverError::new)?;
let res = this.exec(&sql, params).await.map_err(DriverError::new)?;
Ok(res)
})
}

pub fn query_row<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<Bound<'p, PyAny>> {
#[pyo3(signature = (sql, params=None))]
pub fn query_row<'p>(
&'p self,
py: Python<'p>,
sql: String,
params: Option<Bound<'p, PyAny>>,
) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
let params = to_sql_params(params);
future_into_py(py, async move {
let row = this.query_row(&sql).await.map_err(DriverError::new)?;
let row = this
.query_row(&sql, params)
.await
.map_err(DriverError::new)?;
Ok(row.map(Row::new))
})
}

pub fn query_all<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<Bound<'p, PyAny>> {
#[pyo3(signature = (sql, params=None))]
pub fn query_all<'p>(
&'p self,
py: Python<'p>,
sql: String,
params: Option<Bound<'p, PyAny>>,
) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
let params = to_sql_params(params);
future_into_py(py, async move {
let rows: Vec<Row> = this
.query_all(&sql)
.query_all(&sql, params)
.await
.map_err(DriverError::new)?
.into_iter()
Expand All @@ -94,10 +133,21 @@ impl AsyncDatabendConnection {
})
}

pub fn query_iter<'p>(&'p self, py: Python<'p>, sql: String) -> PyResult<Bound<'p, PyAny>> {
#[pyo3(signature = (sql, params=None))]
pub fn query_iter<'p>(
&'p self,
py: Python<'p>,
sql: String,
params: Option<Bound<'p, PyAny>>,
) -> PyResult<Bound<'p, PyAny>> {
let this = self.0.clone();
let params = to_sql_params(params);

future_into_py(py, async move {
let streamer = this.query_iter(&sql).await.map_err(DriverError::new)?;
let streamer = this
.query_iter(&sql, params)
.await
.map_err(DriverError::new)?;
Ok(RowIterator::new(streamer))
})
}
Expand Down
Loading
Loading