Skip to content

Commit

Permalink
upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li committed Feb 19, 2025
1 parent 0251725 commit 9488368
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 29 deletions.
14 changes: 7 additions & 7 deletions bindings/python/tests/asyncio/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async def _(context):
assert row.values() == (b"xyz",), f"Binary: {row.values()}"

# Tuple
row = await context.conn.query_row("select to_binary(?)", params = ("xyz"))
row = await context.conn.query_row("select to_binary(?)", params=("xyz"))
assert row.values() == (b"xyz",), f"Binary: {row.values()}"

# Interval
Expand All @@ -87,9 +87,9 @@ async def _(context):

# Array
row = await context.conn.query_row("select [10::Decimal(15,2), 1.1+2.3]")
assert row.values() == ([Decimal("10.00"), Decimal("3.40")],), (
f"Array: {row.values()}"
)
assert row.values() == (
[Decimal("10.00"), Decimal("3.40")],
), f"Array: {row.values()}"

# Map
row = await context.conn.query_row("select {'xx':to_date('2020-01-01')}")
Expand All @@ -99,9 +99,9 @@ async def _(context):
row = await context.conn.query_row(
"select (10, '20', to_datetime('2024-04-16 12:34:56.789'))"
)
assert row.values() == ((10, "20", datetime(2024, 4, 16, 12, 34, 56, 789000)),), (
f"Tuple: {row.values()}"
)
assert row.values() == (
(10, "20", datetime(2024, 4, 16, 12, 34, 56, 789000)),
), f"Tuple: {row.values()}"


@then("Select numbers should iterate all rows")
Expand Down
14 changes: 7 additions & 7 deletions bindings/python/tests/blocking/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,17 @@ async def _(context):
assert row.values() == (timedelta(microseconds=1),), f"Interval: {row.values()}"

# Decimal
row = context.conn.query_row("SELECT 15.7563::Decimal(?,?), 2.0+3.0", params = [8, 4])
row = context.conn.query_row("SELECT 15.7563::Decimal(?,?), 2.0+3.0", params=[8, 4])
assert row.values() == (
Decimal("15.7563"),
Decimal("5.0"),
), f"Decimal: {row.values()}"

# Array
row = context.conn.query_row("select [10::Decimal(15,2), 1.1+2.3]")
assert row.values() == ([Decimal("10.00"), Decimal("3.40")],), (
f"Array: {row.values()}"
)
assert row.values() == (
[Decimal("10.00"), Decimal("3.40")],
), f"Array: {row.values()}"

# Map
row = context.conn.query_row("select {'xx':to_date('2020-01-01')}")
Expand All @@ -86,9 +86,9 @@ async def _(context):
row = context.conn.query_row(
"select (10, '20', to_datetime('2024-04-16 12:34:56.789'))"
)
assert row.values() == ((10, "20", datetime(2024, 4, 16, 12, 34, 56, 789000)),), (
f"Tuple: {row.values()}"
)
assert row.values() == (
(10, "20", datetime(2024, 4, 16, 12, 34, 56, 789000)),
), f"Tuple: {row.values()}"


@then("Select numbers should iterate all rows")
Expand Down
11 changes: 9 additions & 2 deletions driver/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl Params {
return v.replace_sql(self, &stmt, sql);
}
}
return sql.to_string();
sql.to_string()
}
}

Expand Down Expand Up @@ -144,7 +144,7 @@ impl Param for serde_json::Value {
}
s.push_str(&v.as_sql_string());
}
s.push_str("]");
s.push(']');
s
}
serde_json::Value::Object(map) => {
Expand Down Expand Up @@ -368,5 +368,12 @@ mod tests {
let replaced_sql = params.replace(sql);
assert_eq!(replaced_sql, "SELECT b = '44', a = 1 FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
}

{
let params = params! {1, "44", 2, 3, "55", "66"};
let sql = "SELECT $3, $2, $1 FROM table WHERE a = $1 AND '?' = cj AND b = $2 AND c = $3 AND d = $4 AND e = $5 AND f = $6";
let replaced_sql = params.replace(sql);
assert_eq!(replaced_sql, "SELECT 2, '44', 1 FROM table WHERE a = 1 AND '?' = cj AND b = '44' AND c = 2 AND d = 3 AND e = '55' AND f = '66'");
}
}
}
52 changes: 39 additions & 13 deletions driver/src/place_holder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

use std::vec;

use databend_common_ast::ast::ColumnID;
use databend_common_ast::ast::ColumnPosition;
use databend_common_ast::ast::ColumnRef;
use databend_common_ast::ast::Expr;
use databend_common_ast::ast::Identifier;
use databend_common_ast::ast::IdentifierType;
Expand All @@ -25,16 +28,18 @@ use derive_visitor::Visitor;
use crate::Params;

#[derive(Visitor)]
#[visitor(Expr(enter), Identifier(enter))]
#[visitor(Expr(enter), Identifier(enter), ColumnRef(enter))]
pub(crate) struct PlaceholderVisitor {
place_holders: Vec<Range>,
column_positions: Vec<(usize, Range)>,
names: Vec<(String, Range)>,
}

impl PlaceholderVisitor {
pub fn new() -> Self {
PlaceholderVisitor {
place_holders: vec![],
column_positions: vec![],
names: Vec::new(),
}
}
Expand All @@ -45,21 +50,29 @@ impl PlaceholderVisitor {
name,
span: Some(range),
} => {
self.names.push((name.clone(), range.clone()));
self.names.push((name.clone(), *range));
}
Expr::Placeholder { span: Some(range) } => {
self.place_holders.push(range.clone());
self.place_holders.push(*range);
}
_ => {}
}
}

fn enter_identifier(&mut self, ident: &Identifier) {
match (ident.ident_type, ident.span) {
(IdentifierType::Hole, Some(range)) => {
self.names.push((ident.name.clone(), range));
}
_ => {}
if let (IdentifierType::Hole, Some(range)) = (ident.ident_type, ident.span) {
self.names.push((ident.name.clone(), range));
}
}

fn enter_column_ref(&mut self, r: &ColumnRef) {
if let ColumnID::Position(ColumnPosition {
span: Some(range),
pos,
..
}) = r.column
{
self.column_positions.push((pos, range));
}
}

Expand All @@ -71,26 +84,39 @@ impl PlaceholderVisitor {

for (index, range) in self.place_holders.iter().enumerate() {
if let Some(v) = params.get_by_index(index + 1) {
results.push((v.to_string(), range.clone()));
results.push((v.to_string(), *range));
}
}

for (name, range) in self.names.iter() {
if let Some(v) = params.get_by_name(name) {
results.push((v.to_string(), range.clone()));
results.push((v.to_string(), *range));
}
}

let mut sql = sql.to_string();
if !results.is_empty() {
let mut sql = sql.to_string();
results.sort_by(|a, b| a.1.start.cmp(&b.1.start));
for (value, r) in results.iter().rev() {
let start = r.start as usize;
let end = r.end as usize;
sql.replace_range(start..end, value);
}
return sql;
}
sql.to_string()

if !self.column_positions.is_empty() {
self.column_positions
.sort_by(|a, b| a.1.start.cmp(&b.1.start));

for (index, r) in self.column_positions.iter().rev() {
if let Some(value) = params.get_by_index(*index) {
let start = r.start as usize;
let end = r.end as usize;
sql.replace_range(start..end, value);
}
}
}

sql
}
}

0 comments on commit 9488368

Please sign in to comment.