Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions sqlx-core/src/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,22 @@ impl ReadDir {
}
}
}

#[cfg(feature = "_rt-tokio")]
pub async fn open_file<P: AsRef<Path>>(path: P) -> Result<tokio::fs::File, io::Error> {
if rt::rt_tokio::available() {
return tokio::fs::File::open(path).await;
}

rt::missing_rt(path);
}

#[cfg(all(feature = "_rt-async-io", not(feature = "_rt-tokio")))]
pub async fn open_file<P: AsRef<Path>>(path: P) -> Result<async_fs::File, io::Error> {
async_fs::File::open(path).await
}

#[cfg(all(not(feature = "_rt-async-io"), not(feature = "_rt-tokio")))]
pub async fn open_file<P: AsRef<Path>>(path: P) -> Result<futures_util::io::Empty, io::Error> {
rt::missing_rt(path)
}
17 changes: 16 additions & 1 deletion sqlx-mysql/src/connection/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::executor::{Execute, Executor};
use crate::ext::ustr::UStr;
use crate::io::MySqlBufExt;
use crate::logger::QueryLogger;
use crate::protocol::response::Status;
use crate::protocol::response::{LocalInfilePacket, Status};
use crate::protocol::statement::{
BinaryRow, Execute as StatementExecute, Prepare, PrepareOk, StmtClose,
};
Expand All @@ -22,7 +22,9 @@ use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::TryStreamExt;
use sqlx_core::column::{ColumnOrigin, TableColumn};
use sqlx_core::fs::open_file;
use sqlx_core::sql_str::SqlStr;
use std::path::PathBuf;
use std::{pin::pin, sync::Arc};

impl MySqlConnection {
Expand Down Expand Up @@ -209,6 +211,19 @@ impl MySqlConnection {
return Ok(());
}

if packet[0] == 0xfb {
// LocalInfileRequest
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html
let packet = packet.decode::<LocalInfilePacket>()?;
let path = PathBuf::from(String::from_utf8_lossy(&packet.filename).into_owned());
let file = open_file(&path).await.map_err(|_| err_protocol!("cannot open file {} for local infile request", path.display()))?;

self.inner.stream.send_stream(file).await?;

continue;
}


// otherwise, this first packet is the start of the result-set metadata,
*self.inner.stream.waiting.front_mut().unwrap() = Waiting::Row;

Expand Down
40 changes: 38 additions & 2 deletions sqlx-mysql/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use bytes::{Buf, Bytes, BytesMut};

use crate::error::Error;
use crate::io::MySqlBufExt;
use crate::io::{ProtocolDecode, ProtocolEncode};
use crate::io::{AsyncRead, ProtocolDecode, ProtocolEncode};
use crate::net::{BufferedSocket, Socket};
use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
use crate::protocol::{Capabilities, Packet};
Expand Down Expand Up @@ -43,7 +43,8 @@ impl<S: Socket> MySqlStream<S> {
| Capabilities::MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::PS_MULTI_RESULTS
| Capabilities::SSL;
| Capabilities::SSL
| Capabilities::LOCAL_FILES;

if options.database.is_some() {
capabilities |= Capabilities::CONNECT_WITH_DB;
Expand Down Expand Up @@ -108,6 +109,41 @@ impl<S: Socket> MySqlStream<S> {
Ok(())
}

/// Send data from a stream to the database server as MySQL packets
///
/// This is used to send data for a LOCAL INFILE query
pub(crate) async fn send_stream(
&mut self,
mut source: impl AsyncRead + Unpin,
) -> Result<(), Error> {
loop {
let buf = self.socket.write_buffer_mut();

// Write the CopyData format code and reserve space for the length + sequence_id
// This is safe even if empty, since we always need to send an empty packet at the end
buf.put_slice(b"\0\0\0\0");

let read = buf.read_from(&mut source).await?;
let read32 = i32::try_from(read)
.map_err(|_| err_protocol!("number of bytes read exceeds 2^31 - 1: {}", read))?;

// rewrite header (len + sequenceid)
let mut header = read32.to_le_bytes();
header[3] = self.sequence_id;
self.sequence_id = self.sequence_id.wrapping_add(1);

buf.get_mut()[..4].copy_from_slice(&header);

self.socket.flush().await?;

if read32 == 0 {
break;
}
}

Ok(())
}

pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error>
where
T: ProtocolEncode<'en, Capabilities>,
Expand Down
37 changes: 37 additions & 0 deletions sqlx-mysql/src/protocol/response/local_infile.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use bytes::{Buf, Bytes};
use sqlx_core::io::{BufExt, ProtocolDecode};

use crate::error::Error;

/// Requests the client to send a file to the server, following a LOCAL INFILE statement
///
/// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_query_response_local_infile_request.html
#[derive(Debug)]
pub struct LocalInfilePacket {
pub filename: Vec<u8>,
}

impl ProtocolDecode<'_> for LocalInfilePacket {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let header = buf.get_u8();
if header != 0xfb {
return Err(err_protocol!(
"expected 0xfb (LocalInfileRequest) but found 0x{:02x}",
header
));
}

let filename = buf.get_bytes(buf.len()).to_vec();

Ok(Self { filename })
}
}

#[test]
fn test_decode_localinfile_packet() {
const DATA: &[u8] = b"\xfb\x64\x75\x6d\x6d\x79";

let p = LocalInfilePacket::decode(DATA.into()).unwrap();

assert_eq!(p.filename, b"dummy");
}
2 changes: 2 additions & 0 deletions sqlx-mysql/src/protocol/response/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

mod eof;
mod err;
mod local_infile;
mod ok;
mod status;

pub use eof::EofPacket;
pub use err::ErrPacket;
pub use local_infile::LocalInfilePacket;
pub use ok::OkPacket;
pub use status::Status;
2 changes: 2 additions & 0 deletions tests/mysql/fixtures/load_data_infile.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1,a
2,b
44 changes: 42 additions & 2 deletions tests/mysql/mysql.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::Context;
use futures_util::TryStreamExt;
use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow};
use sqlx::{Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo};
use sqlx::{AssertSqlSafe, Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo};
use sqlx_core::connection::ConnectOptions;
use sqlx_core::types::Type;
use sqlx_mysql::MySqlConnectOptions;
Expand Down Expand Up @@ -599,7 +599,7 @@ async fn select_statement_count(conn: &mut MySqlConnection) -> Result<i64, sqlx:
SELECT COUNT(*)
FROM performance_schema.threads AS t
INNER JOIN performance_schema.prepared_statements_instances AS psi
ON psi.OWNER_THREAD_ID = t.THREAD_ID
ON psi.OWNER_THREAD_ID = t.THREAD_ID
WHERE t.processlist_id = CONNECTION_ID()
"#,
)
Expand Down Expand Up @@ -727,3 +727,43 @@ async fn any_blob_conversions() -> anyhow::Result<()> {

Ok(())
}

#[sqlx_macros::test]
async fn it_can_load_a_file() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;

let _ = conn
.execute(
r#"
CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY, name TEXT);
"#,
)
.await?;

let _ = conn.execute("SET GLOBAL local_infile = 1;").await?;

let file_path = env::current_dir()
.unwrap()
.join("tests/mysql/fixtures/load_data_infile.txt");

// Execute LOAD DATA LOCAL INFILE
let load_query = format!(
"LOAD DATA LOCAL INFILE '{}' INTO TABLE users FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n'",
file_path.display()
);

let result = conn.execute(AssertSqlSafe(load_query)).await;

if let Err(e) = result {
assert!(false, "{:?}", e)
}

let name = sqlx::query("SELECT name FROM users WHERE id = 1")
.try_map(|row: MySqlRow| row.try_get::<String, _>(0))
.fetch_one(&mut conn)
.await?;

assert_eq!("a", name);

Ok(())
}
Loading