diff --git a/bindings/nodejs/README.md b/bindings/nodejs/README.md index 3cff7df3b..ed91e3380 100644 --- a/bindings/nodejs/README.md +++ b/bindings/nodejs/README.md @@ -14,7 +14,7 @@ pnpm run build const { Client } = require("databend-driver"); const client = new Client( - "databend+http://root:root@localhost:8000/?sslmode=disable", + "databend://root:root@localhost:8000/?sslmode=disable", ); const conn = await client.getConn(); diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index cf44144ec..fa65859c1 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -17,6 +17,7 @@ chrono = { workspace = true } databend-driver = { workspace = true, features = ["rustls", "flight-sql"] } tokio-stream = { workspace = true } +csv = "1.3" ctor = "0.2" once_cell = "1.20" pyo3 = { version = "0.23.3", features = ["abi3-py37", "chrono"] } diff --git a/bindings/python/README.md b/bindings/python/README.md index daf5fdf03..385448f09 100644 --- a/bindings/python/README.md +++ b/bindings/python/README.md @@ -9,12 +9,40 @@ maturin develop ## Usage +### PEP 249 cursor object + +```python +from databend_driver import BlockingDatabendClient + +client = BlockingDatabendClient('databend://root:root@localhost:8000/?sslmode=disable') +cursor = client.cursor() + +cursor.execute( + """ + CREATE TABLE test ( + i64 Int64, + u64 UInt64, + f64 Float64, + s String, + s2 String, + d Date, + t DateTime + ) + """ +) +cursor.execute("INSERT INTO test VALUES", (1, 1, 1.0, 'hello', 'world', '2021-01-01', '2021-01-01 00:00:00')) +cursor.execute("SELECT * FROM test") +rows = cursor.fetchall() +for row in rows: + print(row.values()) +``` + ### Blocking ```python from databend_driver import BlockingDatabendClient -client = BlockingDatabendClient('databend+http://root:root@localhost:8000/?sslmode=disable') +client = BlockingDatabendClient('databend://root:root@localhost:8000/?sslmode=disable') conn = client.get_conn() conn.exec( """ @@ -41,7 +69,7 @@ import asyncio from 
databend_driver import AsyncDatabendClient async def main(): - client = AsyncDatabendClient('databend+http://root:root@localhost:8000/?sslmode=disable') + client = AsyncDatabendClient('databend://root:root@localhost:8000/?sslmode=disable') conn = await client.get_conn() await conn.exec( """ @@ -141,6 +169,7 @@ class AsyncDatabendConnection: class BlockingDatabendClient: def __init__(self, dsn: str): ... def get_conn(self) -> BlockingDatabendConnection: ... + def cursor(self) -> BlockingDatabendCursor: ... ``` ### BlockingDatabendConnection @@ -156,11 +185,26 @@ class BlockingDatabendConnection: def load_file(self, sql: str, file: str, format_option: dict, copy_options: dict = None) -> ServerStats: ... ``` +### BlockingDatabendCursor + +```python +class BlockingDatabendCursor: + def close(self) -> None: ... + def execute(self, operation: str, params: list[string] | tuple[string] = None) -> None | int: ... + def executemany(self, operation: str, params: list[list[string] | tuple[string]]) -> None | int: ... + def fetchone(self) -> Row: ... + def fetchall(self) -> list[Row]: ... +``` + ### Row ```python class Row: def values(self) -> tuple: ... + def __len__(self) -> int: ... + def __iter__(self) -> list: ... + def __dict__(self) -> dict: ... + def __getitem__(self, key: int | str) -> any: ... 
``` ### RowIterator diff --git a/bindings/python/pyproject.toml b/bindings/python/pyproject.toml index e47159d38..ad9b4d6bf 100644 --- a/bindings/python/pyproject.toml +++ b/bindings/python/pyproject.toml @@ -13,6 +13,7 @@ license = { text = "Apache-2.0" } name = "databend-driver" readme = "README.md" requires-python = ">=3.7, < 3.14" +dynamic = ["version"] [project.urls] Repository = "https://github.com/databendlabs/bendsql" diff --git a/bindings/python/src/blocking.rs b/bindings/python/src/blocking.rs index 8f0adf869..1b9be255c 100644 --- a/bindings/python/src/blocking.rs +++ b/bindings/python/src/blocking.rs @@ -16,7 +16,11 @@ use std::collections::BTreeMap; use std::path::Path; use std::sync::Arc; +use pyo3::exceptions::{PyAttributeError, PyException}; use pyo3::prelude::*; +use pyo3::types::{PyList, PyTuple}; +use tokio::sync::Mutex; +use tokio_stream::StreamExt; use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION}; use crate::utils::wait_for_future; @@ -41,6 +45,14 @@ impl BlockingDatabendClient { })?; Ok(BlockingDatabendConnection(Arc::new(conn))) } + + pub fn cursor(&self, py: Python) -> PyResult { + let this = self.0.clone(); + let conn = wait_for_future(py, async move { + this.get_conn().await.map_err(DriverError::new) + })?; + Ok(BlockingDatabendCursor::new(conn)) + } } #[pyclass(module = "databend_driver")] @@ -142,3 +154,177 @@ impl BlockingDatabendConnection { Ok(ServerStats::new(ret)) } } + +/// BlockingDatabendCursor is an object that follows PEP 249 +/// https://peps.python.org/pep-0249/#cursor-objects +#[pyclass(module = "databend_driver")] +pub struct BlockingDatabendCursor { + conn: Arc>, + rows: Option>>, + // buffer is used to store only the first row after execute + buffer: Vec, +} + +impl BlockingDatabendCursor { + fn new(conn: Box) -> Self { + Self { + conn: Arc::new(conn), + rows: None, + buffer: Vec::new(), + } + } +} + +impl BlockingDatabendCursor { + fn reset(&mut self) { + self.rows = None; + 
self.buffer.clear(); + } +} + +#[pymethods] +impl BlockingDatabendCursor { + pub fn close(&mut self, py: Python) -> PyResult<()> { + self.reset(); + wait_for_future(py, async move { + self.conn.close().await.map_err(DriverError::new) + })?; + Ok(()) + } + + #[pyo3(signature = (operation, parameters=None))] + pub fn execute<'p>( + &'p mut self, + py: Python<'p>, + operation: String, + parameters: Option>, + ) -> PyResult { + if let Some(param) = parameters { + return self.executemany(py, operation, [param].to_vec()); + } + + self.reset(); + let conn = self.conn.clone(); + // fetch first row after execute + // then we could finish the query directly if there's no result + let (first, rows) = wait_for_future(py, async move { + let mut rows = conn.query_iter(&operation).await?; + let first = rows.next().await.transpose()?; + Ok::<_, databend_driver::Error>((first, rows)) + }) + .map_err(DriverError::new)?; + if let Some(first) = first { + self.buffer.push(Row::new(first)); + } + self.rows = Some(Arc::new(Mutex::new(rows))); + Ok(py.None()) + } + + pub fn executemany<'p>( + &'p mut self, + py: Python<'p>, + operation: String, + parameters: Vec>, + ) -> PyResult { + self.reset(); + let conn = self.conn.clone(); + if let Some(param) = parameters.first() { + if param.downcast::().is_ok() || param.downcast::().is_ok() { + let bytes = format_csv(parameters)?; + let size = bytes.len() as u64; + let reader = Box::new(std::io::Cursor::new(bytes)); + let stats = wait_for_future(py, async move { + conn.load_data(&operation, reader, size, None, None) + .await + .map_err(DriverError::new) + })?; + let result = stats.write_rows.into_pyobject(py)?; + return Ok(result.into()); + } else { + return Err(PyAttributeError::new_err( + "Invalid parameter type, expected list or tuple", + )); + } + } + Ok(py.None()) + } + + pub fn fetchone(&mut self, py: Python) -> PyResult> { + if let Some(row) = self.buffer.pop() { + return Ok(Some(row)); + } + match self.rows { + Some(ref rows) => { + match 
wait_for_future(py, async move { rows.lock().await.next().await }) { + Some(row) => Ok(Some(Row::new(row.map_err(DriverError::new)?))), + None => Ok(None), + } + } + None => Ok(None), + } + } + + pub fn fetchall(&mut self, py: Python) -> PyResult> { + let mut result = self.buffer.drain(..).collect::>(); + match self.rows.take() { + Some(rows) => { + let fetched = wait_for_future(py, async move { + let mut rows = rows.lock().await; + let mut result = Vec::new(); + while let Some(row) = rows.next().await { + result.push(row); + } + result + }); + for row in fetched { + result.push(Row::new(row.map_err(DriverError::new)?)); + } + Ok(result) + } + None => Ok(vec![]), + } + } +} + +fn format_csv<'p>(parameters: Vec>) -> PyResult> { + let mut wtr = csv::WriterBuilder::new().from_writer(vec![]); + for row in parameters { + let iter = row.try_iter()?; + let data = iter + .map(|v| match v { + Ok(v) => to_csv_field(v), + Err(e) => Err(e.into()), + }) + .collect::, _>>()?; + wtr.write_record(data) + .map_err(|e| PyException::new_err(e.to_string())) + .unwrap(); + } + let bytes = wtr + .into_inner() + .map_err(|e| PyException::new_err(e.to_string())) + .unwrap(); + Ok(bytes) +} + +fn to_csv_field(v: Bound) -> PyResult { + match v.downcast::() { + Ok(v) => { + if let Ok(v) = v.extract::() { + Ok(v) + } else if let Ok(v) = v.extract::() { + Ok(v.to_string()) + } else if let Ok(v) = v.extract::() { + Ok(v.to_string()) + } else if let Ok(v) = v.extract::() { + Ok(v.to_string()) + } else { + Err(PyAttributeError::new_err(format!( + "Invalid parameter type for: {:?}, expected str, bool, int or float", + v + ))) + } + } + Err(e) => Err(e.into()), + } +} diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 20111b94c..d6bc14d43 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -20,7 +20,7 @@ mod utils; use pyo3::prelude::*; use crate::asyncio::{AsyncDatabendClient, AsyncDatabendConnection}; -use crate::blocking::{BlockingDatabendClient, 
BlockingDatabendConnection}; +use crate::blocking::{BlockingDatabendClient, BlockingDatabendConnection, BlockingDatabendCursor}; use crate::types::{ConnectionInfo, Field, Row, RowIterator, Schema, ServerStats}; #[pymodule] @@ -29,6 +29,7 @@ fn _databend_driver(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 9029740e9..95f5e95b5 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use chrono::{NaiveDate, NaiveDateTime}; use once_cell::sync::Lazy; -use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration}; +use pyo3::exceptions::{PyAttributeError, PyException, PyStopAsyncIteration, PyStopIteration}; use pyo3::sync::GILOnceCell; use pyo3::types::{PyBytes, PyDict, PyList, PyTuple, PyType}; use pyo3::{intern, IntoPyObjectExt}; @@ -162,6 +162,53 @@ impl Row { let tuple = PyTuple::new(py, vals)?; Ok(tuple) } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + pub fn __iter__<'p>(&'p self, py: Python<'p>) -> PyResult> { + let vals = self.0.values().iter().map(|v| Value(v.clone())); + let list = PyList::new(py, vals)?; + Ok(list.into_bound()) + } + + pub fn __dict__<'p>(&'p self, py: Python<'p>) -> PyResult> { + let dict = PyDict::new(py); + let schema = self.0.schema(); + for (field, value) in schema.fields().iter().zip(self.0.values()) { + dict.set_item(&field.name, Value(value.clone()))?; + } + Ok(dict.into_bound()) + } + + fn get_by_index(&self, idx: usize) -> PyResult { + Ok(Value(self.0.values()[idx].clone())) + } + + fn get_by_field(&self, field: &str) -> PyResult { + let schema = self.0.schema(); + let idx = schema + .fields() + .iter() + .position(|f| f.name == field) + .ok_or_else(|| { + PyException::new_err(format!("field '{}' not found in schema", field)) + })?; + 
Ok(Value(self.0.values()[idx].clone())) + } + + pub fn __getitem__<'p>(&'p self, key: Bound<'p, PyAny>) -> PyResult { + if let Ok(idx) = key.extract::() { + self.get_by_index(idx) + } else if let Ok(field) = key.extract::() { + self.get_by_field(&field) + } else { + Err(PyAttributeError::new_err( + "key must be an integer or a string", + )) + } + } } #[pyclass(module = "databend_driver")] diff --git a/bindings/python/tests/asyncio/steps/binding.py b/bindings/python/tests/asyncio/steps/binding.py index f56dd79d0..a23368123 100644 --- a/bindings/python/tests/asyncio/steps/binding.py +++ b/bindings/python/tests/asyncio/steps/binding.py @@ -26,7 +26,7 @@ async def _(context): dsn = os.getenv( "TEST_DATABEND_DSN", - "databend+http://root:root@localhost:8000/?sslmode=disable", + "databend://root:root@localhost:8000/?sslmode=disable", ) client = databend_driver.AsyncDatabendClient(dsn) context.conn = await client.get_conn() diff --git a/bindings/python/tests/blocking/steps/binding.py b/bindings/python/tests/blocking/steps/binding.py index 44e7d54a7..6ff9e7349 100644 --- a/bindings/python/tests/blocking/steps/binding.py +++ b/bindings/python/tests/blocking/steps/binding.py @@ -24,7 +24,7 @@ def _(context): dsn = os.getenv( "TEST_DATABEND_DSN", - "databend+http://root:root@localhost:8000/?sslmode=disable", + "databend://root:root@localhost:8000/?sslmode=disable", ) client = databend_driver.BlockingDatabendClient(dsn) context.conn = client.get_conn() diff --git a/bindings/python/tests/cursor/binding.feature b/bindings/python/tests/cursor/binding.feature new file mode 120000 index 000000000..fcf9cd614 --- /dev/null +++ b/bindings/python/tests/cursor/binding.feature @@ -0,0 +1 @@ +../../../tests/features/binding.feature \ No newline at end of file diff --git a/bindings/python/tests/cursor/steps/binding.py b/bindings/python/tests/cursor/steps/binding.py new file mode 100644 index 000000000..4cbabdc1c --- /dev/null +++ b/bindings/python/tests/cursor/steps/binding.py @@ -0,0 
+1,152 @@ +# Copyright 2021 Datafuse Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from datetime import datetime, date +from decimal import Decimal + +from behave import given, when, then +import databend_driver + + +@given("A new Databend Driver Client") +def _(context): + dsn = os.getenv( + "TEST_DATABEND_DSN", + "databend://root:root@localhost:8000/?sslmode=disable", + ) + client = databend_driver.BlockingDatabendClient(dsn) + context.cursor = client.cursor() + + +@when("Create a test table") +def _(context): + context.cursor.execute("DROP TABLE IF EXISTS test") + context.cursor.execute( + """ + CREATE TABLE test ( + i64 Int64, + u64 UInt64, + f64 Float64, + s String, + s2 String, + d Date, + t DateTime + ) + """ + ) + + +@then("Select string {input} should be equal to {output}") +def _(context, input, output): + context.cursor.execute(f"SELECT '{input}'") + row = context.cursor.fetchone() + assert output == row[0], f"output: {output}" + + +@then("Select types should be expected native types") +def _(context): + # Binary + context.cursor.execute("select to_binary('xyz')") + row = context.cursor.fetchone() + assert row[0] == b"xyz", f"Binary: {row.values()}" + + # Decimal + context.cursor.execute("SELECT 15.7563::Decimal(8,4), 2.0+3.0") + row = context.cursor.fetchone() + assert row.values() == ( + Decimal("15.7563"), + Decimal("5.0"), + ), f"Decimal: {row.values()}" + + # Array + context.cursor.execute("select [10::Decimal(15,2), 
 1.1+2.3]") + row = context.cursor.fetchone() + assert row.values() == ( + [Decimal("10.00"), Decimal("3.40")], + ), f"Array: {row.values()}" + + # Map + context.cursor.execute("select {'xx':to_date('2020-01-01')}") + row = context.cursor.fetchone() + assert row.values() == ({"xx": date(2020, 1, 1)},), f"Map: {row.values()}" + + # Tuple + context.cursor.execute("select (10, '20', to_datetime('2024-04-16 12:34:56.789'))") + row = context.cursor.fetchone() + assert row.values() == ( + (10, "20", datetime(2024, 4, 16, 12, 34, 56, 789000)), + ), f"Tuple: {row.values()}" + + +@then("Select numbers should iterate all rows") +def _(context): + context.cursor.execute("SELECT number FROM numbers(5)") + rows = context.cursor.fetchall() + ret = [] + for row in rows: + ret.append(row[0]) + expected = [0, 1, 2, 3, 4] + assert ret == expected, f"ret: {ret}" + + +@then("Insert and Select should be equal") +def _(context): + context.cursor.execute( + """ + INSERT INTO test VALUES + (-1, 1, 1.0, '1', '1', '2011-03-06', '2011-03-06 06:20:00'), + (-2, 2, 2.0, '2', '2', '2012-05-31', '2012-05-31 11:20:00'), + (-3, 3, 3.0, '3', '2', '2016-04-04', '2016-04-04 11:30:00') + """ + ) + context.cursor.execute("SELECT * FROM test") + rows = context.cursor.fetchall() + ret = [] + for row in rows: + ret.append(row.values()) + expected = [ + (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)), + (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)), + (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)), + ] + assert ret == expected, f"ret: {ret}" + + +@then("Stream load and Select should be equal") +def _(context): + values = [ + [-1, 1, 1.0, "1", "1", "2011-03-06", "2011-03-06T06:20:00Z"], + (-2, "2", 2.0, "2", "2", "2012-05-31", "2012-05-31T11:20:00Z"), + ["-3", 3, 3.0, "3", "2", "2016-04-04", "2016-04-04T11:30:00Z"], + ] + count = context.cursor.executemany("INSERT INTO test VALUES", values) + assert count == 3, f"count: {count}" + + 
context.cursor.execute("SELECT * FROM test") + rows = context.cursor.fetchall() + ret = [] + for row in rows: + ret.append(row.values()) + expected = [ + (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)), + (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)), + (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)), + ] + assert ret == expected, f"ret: {ret}" + + +@then("Load file and Select should be equal") +def _(context): + print("SKIP") diff --git a/cli/test.sh b/cli/test.sh index 0e0f3085a..4cd849fec 100755 --- a/cli/test.sh +++ b/cli/test.sh @@ -33,7 +33,7 @@ case $TEST_HANDLER in ;; "http") echo "==> Testing REST API handler" - export BENDSQL_DSN="databend+http://${DATABEND_USER}:${DATABEND_PASSWORD}@${DATABEND_HOST}:${DATABEND_PORT}/?sslmode=disable&presign=on" + export BENDSQL_DSN="databend://${DATABEND_USER}:${DATABEND_PASSWORD}@${DATABEND_HOST}:${DATABEND_PORT}/?sslmode=disable&presign=on" ;; *) echo "Usage: $0 [flight|http]" diff --git a/driver/src/rest_api.rs b/driver/src/rest_api.rs index 638c2021e..39ef627c5 100644 --- a/driver/src/rest_api.rs +++ b/driver/src/rest_api.rs @@ -289,21 +289,26 @@ impl Stream for RestAPIRows { if let Some(ss) = self.stats.take() { return Poll::Ready(Some(Ok(RowWithStats::Stats(ss)))); } - if let Some(row) = self.data.pop_front() { - let row = Row::try_from((self.schema.clone(), row))?; - return Poll::Ready(Some(Ok(RowWithStats::Row(row)))); + // Skip to fetch next page if there is only one row left in buffer. + // Therefore we could guarantee that `/final` is called before the last row.
+ if self.data.len() > 1 { + if let Some(row) = self.data.pop_front() { + let row = Row::try_from((self.schema.clone(), row))?; + return Poll::Ready(Some(Ok(RowWithStats::Row(row)))); + } } match self.next_page { Some(ref mut next_page) => match Pin::new(next_page).poll(cx) { Poll::Ready(Ok(resp)) => { - self.data = resp.data.into(); if self.schema.fields().is_empty() { self.schema = Arc::new(resp.schema.try_into()?); } self.next_uri = resp.next_uri; self.next_page = None; - self.stats = Some(ServerStats::from(resp.stats)); - self.poll_next(cx) + let mut new_data = resp.data.into(); + self.data.append(&mut new_data); + let stats = ServerStats::from(resp.stats); + Poll::Ready(Some(Ok(RowWithStats::Stats(stats)))) } Poll::Ready(Err(e)) => { self.next_page = None; @@ -325,7 +330,13 @@ impl Stream for RestAPIRows { })); self.poll_next(cx) } - None => Poll::Ready(None), + None => match self.data.pop_front() { + Some(row) => { + let row = Row::try_from((self.schema.clone(), row))?; + Poll::Ready(Some(Ok(RowWithStats::Row(row)))) + } + None => Poll::Ready(None), + }, }, } }