Skip to content

Commit

Permalink
feat(bindings): support load file (#547)
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc authored Dec 24, 2024
1 parent c3f68e4 commit b568749
Show file tree
Hide file tree
Showing 16 changed files with 176 additions and 6 deletions.
5 changes: 5 additions & 0 deletions bindings/nodejs/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ export declare class Connection {
* The SQL can be `INSERT INTO tbl VALUES` or `REPLACE INTO tbl VALUES`.
*/
streamLoad(sql: string, data: Array<Array<string>>): Promise<ServerStats>
/**
 * Load file with stage attachment.
 * The SQL can be `INSERT INTO tbl VALUES` or `REPLACE INTO tbl VALUES`.
 */
loadFile(sql: string, file: string, formatOptions?: Record<string, string> | undefined | null, copyOptions?: Record<string, string> | undefined | null): Promise<ServerStats>
}
export declare class ConnectionInfo {
get handler(): string
Expand Down
31 changes: 30 additions & 1 deletion bindings/nodejs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
#[macro_use]
extern crate napi_derive;

use std::collections::HashMap;
use std::{
collections::{BTreeMap, HashMap},
path::Path,
};

use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
use napi::{bindgen_prelude::*, Env};
Expand Down Expand Up @@ -162,6 +165,32 @@ impl Connection {
.map_err(format_napi_error)?;
Ok(ServerStats(ss))
}

/// Load file with stage attachment.
/// The SQL can be `INSERT INTO tbl VALUES` or `REPLACE INTO tbl VALUES`.
#[napi]
pub async fn load_file(
    &self,
    sql: String,
    file: String,
    format_options: Option<BTreeMap<String, String>>,
    copy_options: Option<BTreeMap<String, String>>,
) -> Result<ServerStats> {
    // Parameters:
    //   sql            - statement the uploaded file is attached to.
    //   file           - path of the file to load, as a string.
    //   format_options - optional file format options (e.g. `type`).
    //   copy_options   - optional copy options.
    // Driver errors are converted with `format_napi_error`.
    //
    // The driver call takes maps of borrowed `&str`, while napi hands us owned
    // `String`s; build borrowed views that live until the call completes.
    let format_refs: Option<BTreeMap<&str, &str>> = format_options
        .as_ref()
        .map(|opts| opts.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
    let copy_refs: Option<BTreeMap<&str, &str>> = copy_options
        .as_ref()
        .map(|opts| opts.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
    let ss = self
        .inner
        .load_file(&sql, Path::new(&file), format_refs, copy_refs)
        .await
        .map_err(format_napi_error)?;
    Ok(ServerStats(ss))
}
}

#[napi]
Expand Down
18 changes: 18 additions & 0 deletions bindings/nodejs/tests/binding.js
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,21 @@ Then("Stream load and Select should be equal", async function () {
];
assert.deepEqual(ret, expected);
});

// Cucumber step: load tests/data/test.csv into the `test` table, check the
// reported write statistics, then read the table back and compare every row.
Then("Load file and Select should be equal", async function () {
  // Rows expected in the table once the CSV has been loaded.
  const expected = [
    [-1, 1, 1.0, "1", "1", new Date("2011-03-06"), new Date("2011-03-06T06:20:00Z")],
    [-2, 2, 2.0, "2", "2", new Date("2012-05-31"), new Date("2012-05-31T11:20:00Z")],
    [-3, 3, 3.0, "3", "2", new Date("2016-04-04"), new Date("2016-04-04T11:30:00Z")],
  ];

  const stats = await this.conn.loadFile(`INSERT INTO test VALUES`, "tests/data/test.csv", { type: "CSV" });
  assert.equal(stats.writeRows, 3);
  assert.equal(stats.writeBytes, 187);

  // Drain the async row iterator into plain value arrays.
  const actual = [];
  const iterator = await this.conn.queryIter("SELECT * FROM test");
  for await (const record of iterator) {
    actual.push(record.values());
  }
  assert.deepEqual(actual, expected);
});
1 change: 1 addition & 0 deletions bindings/nodejs/tests/data
2 changes: 2 additions & 0 deletions bindings/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class AsyncDatabendConnection:
async def query_row(self, sql: str) -> Row: ...
async def query_iter(self, sql: str) -> RowIterator: ...
async def stream_load(self, sql: str, data: list[list[str]]) -> ServerStats: ...
    async def load_file(self, sql: str, fp: str, format_options: dict, copy_options: dict = None) -> ServerStats: ...
```

### BlockingDatabendClient
Expand All @@ -152,6 +153,7 @@ class BlockingDatabendConnection:
def query_row(self, sql: str) -> Row: ...
def query_iter(self, sql: str) -> RowIterator: ...
def stream_load(self, sql: str, data: list[list[str]]) -> ServerStats: ...
    def load_file(self, sql: str, fp: str, format_options: dict = None, copy_options: dict = None) -> ServerStats: ...
```

### Row
Expand Down
33 changes: 33 additions & 0 deletions bindings/python/src/asyncio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::path::Path;
use std::sync::Arc;

use pyo3::prelude::*;
Expand Down Expand Up @@ -119,4 +121,35 @@ impl AsyncDatabendConnection {
Ok(ServerStats::new(ss))
})
}

/// Load a file into a table and return the resulting server stats.
///
/// * `sql` - statement the uploaded file is attached to.
/// * `fp` - path of the file to load, as a string.
/// * `format_options` - optional file format options (e.g. `type`).
/// * `copy_options` - optional copy options.
///
/// Returns an awaitable that resolves to `ServerStats`; driver errors are
/// wrapped in `DriverError`.
// `format_options` defaults to `None`, matching its `Option` type and the
// blocking connection's signature.
#[pyo3(signature = (sql, fp, format_options=None, copy_options=None))]
pub fn load_file<'p>(
    &'p self,
    py: Python<'p>,
    sql: String,
    fp: String,
    format_options: Option<BTreeMap<String, String>>,
    copy_options: Option<BTreeMap<String, String>>,
) -> PyResult<Bound<'p, PyAny>> {
    let this = self.0.clone();
    future_into_py(py, async move {
        // The driver call takes maps of borrowed `&str`; the owned maps were
        // moved into this future, so the borrowed views stay valid across the
        // await below.
        let format_refs: Option<BTreeMap<&str, &str>> = format_options
            .as_ref()
            .map(|opts| opts.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
        let copy_refs: Option<BTreeMap<&str, &str>> = copy_options
            .as_ref()
            .map(|opts| opts.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
        let ss = this
            .load_file(&sql, Path::new(&fp), format_refs, copy_refs)
            .await
            .map_err(DriverError::new)?;
        Ok(ServerStats::new(ss))
    })
}
}
32 changes: 32 additions & 0 deletions bindings/python/src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::path::Path;
use std::sync::Arc;

use pyo3::prelude::*;
Expand Down Expand Up @@ -109,4 +111,34 @@ impl BlockingDatabendConnection {
})?;
Ok(ServerStats::new(ret))
}

/// Load a file into a table and return the resulting server stats.
///
/// * `sql` - statement the uploaded file is attached to.
/// * `fp` - path of the file to load, as a string.
/// * `format_options` - optional file format options (e.g. `type`).
/// * `copy_options` - optional copy options.
///
/// Waits for the load to finish via `wait_for_future`; driver errors are
/// wrapped in `DriverError`.
#[pyo3(signature = (sql, fp, format_options=None, copy_options=None))]
pub fn load_file<'p>(
    &'p self,
    py: Python<'p>,
    sql: String,
    fp: String,
    format_options: Option<BTreeMap<String, String>>,
    copy_options: Option<BTreeMap<String, String>>,
) -> PyResult<ServerStats> {
    let this = self.0.clone();
    let ret = wait_for_future(py, async move {
        // The driver call takes maps of borrowed `&str`; the owned maps were
        // moved into this future, so the borrowed views stay valid across the
        // await below.
        let format_refs: Option<BTreeMap<&str, &str>> = format_options
            .as_ref()
            .map(|opts| opts.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
        let copy_refs: Option<BTreeMap<&str, &str>> = copy_options
            .as_ref()
            .map(|opts| opts.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect());
        this.load_file(&sql, Path::new(&fp), format_refs, copy_refs)
            .await
            .map_err(DriverError::new)
    })?;
    Ok(ServerStats::new(ret))
}
}
20 changes: 20 additions & 0 deletions bindings/python/tests/asyncio/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,23 @@ async def _(context):
(-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
]
assert ret == expected, f"ret: {ret}"


@then("Load file and Select should be equal")
async def _(context):
    """Load tests/data/test.csv into `test`, check the stats, then compare the rows read back."""
    # Rows expected in the table once the CSV has been loaded.
    expected = [
        (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)),
        (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)),
        (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
    ]

    progress = await context.conn.load_file(
        "INSERT INTO test VALUES", "tests/data/test.csv", {"type": "CSV"}
    )
    assert progress.write_rows == 3, f"progress.write_rows: {progress.write_rows}"
    assert progress.write_bytes == 187, f"progress.write_bytes: {progress.write_bytes}"

    rows = await context.conn.query_iter("SELECT * FROM test")
    ret = [row.values() for row in rows]
    assert ret == expected, f"ret: {ret}"
20 changes: 20 additions & 0 deletions bindings/python/tests/blocking/steps/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,23 @@ def _(context):
(-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
]
assert ret == expected, f"ret: {ret}"


@then("Load file and Select should be equal")
def _(context):
    """Load tests/data/test.csv into `test`, check the stats, then compare the rows read back."""
    # Rows expected in the table once the CSV has been loaded.
    expected = [
        (-1, 1, 1.0, "1", "1", date(2011, 3, 6), datetime(2011, 3, 6, 6, 20)),
        (-2, 2, 2.0, "2", "2", date(2012, 5, 31), datetime(2012, 5, 31, 11, 20)),
        (-3, 3, 3.0, "3", "2", date(2016, 4, 4), datetime(2016, 4, 4, 11, 30)),
    ]

    progress = context.conn.load_file(
        "INSERT INTO test VALUES", "tests/data/test.csv", {"type": "CSV"}
    )
    assert progress.write_rows == 3, f"progress.write_rows: {progress.write_rows}"
    assert progress.write_bytes == 187, f"progress.write_bytes: {progress.write_bytes}"

    rows = context.conn.query_iter("SELECT * FROM test")
    ret = [row.values() for row in rows]
    assert ret == expected, f"ret: {ret}"
1 change: 1 addition & 0 deletions bindings/python/tests/data
3 changes: 3 additions & 0 deletions bindings/tests/data/test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-1,1,1.0,1,1,2011-03-06,2011-03-06 06:20:00
-2,2,2.0,2,2,2012-05-31,2012-05-31 11:20:00
-3,3,3.0,3,2,2016-04-04,2016-04-04 11:30:00
5 changes: 5 additions & 0 deletions bindings/tests/features/binding.feature
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ Feature: Databend Driver
Given A new Databend Driver Client
When Create a test table
Then Stream load and Select should be equal

Scenario: Load file
Given A new Databend Driver Client
When Create a test table
Then Load file and Select should be equal
2 changes: 1 addition & 1 deletion driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ pub trait Connection: Send + Sync {
&self,
sql: &str,
fp: &Path,
format_options: BTreeMap<&str, &str>,
format_options: Option<BTreeMap<&str, &str>>,
copy_options: Option<BTreeMap<&str, &str>>,
) -> Result<ServerStats>;
async fn stream_load(&self, sql: &str, data: Vec<Vec<&str>>) -> Result<ServerStats>;
Expand Down
2 changes: 1 addition & 1 deletion driver/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl Connection for FlightSQLConnection {
&self,
_sql: &str,
_fp: &Path,
_format_options: BTreeMap<&str, &str>,
_format_options: Option<BTreeMap<&str, &str>>,
_copy_options: Option<BTreeMap<&str, &str>>,
) -> Result<ServerStats> {
Err(Error::Protocol(
Expand Down
3 changes: 2 additions & 1 deletion driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ impl Connection for RestAPIConnection {
&self,
sql: &str,
fp: &Path,
mut format_options: BTreeMap<&str, &str>,
format_options: Option<BTreeMap<&str, &str>>,
copy_options: Option<BTreeMap<&str, &str>>,
) -> Result<ServerStats> {
info!(
Expand All @@ -173,6 +173,7 @@ impl Connection for RestAPIConnection {
let metadata = file.metadata().await?;
let data = Box::new(file);
let size = metadata.len();
let mut format_options = format_options.unwrap_or_else(Self::default_file_format_options);
if !format_options.contains_key("type") {
let file_type = fp
.extension()
Expand Down
4 changes: 2 additions & 2 deletions driver/tests/driver/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{collections::BTreeMap, path::Path, vec};
use std::{path::Path, vec};

use chrono::{NaiveDateTime, Utc};
use databend_driver::Client;
Expand Down Expand Up @@ -78,7 +78,7 @@ async fn prepare_data_with_file(table: &str, file_type: &str, client: &Client) {
let fp = format!("tests/driver/data/books.{}", file_type);
let sql = format!("INSERT INTO `{}` VALUES", table);
let stats = conn
.load_file(&sql, Path::new(&fp), BTreeMap::new(), None)
.load_file(&sql, Path::new(&fp), None, None)
.await
.unwrap();
assert_eq!(stats.write_rows, 3);
Expand Down

0 comments on commit b568749

Please sign in to comment.