Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
self._connector = RedshiftConnector(connection_info)
elif data_source == DataSource.postgres:
self._connector = PostgresConnector(connection_info)
elif data_source == DataSource.mysql:
self._connector = MySqlConnector(connection_info)
else:
self._connector = SimpleConnector(data_source, connection_info)

Expand Down Expand Up @@ -294,6 +296,39 @@ def close(self) -> None:
self.connection = None


class MySqlConnector(SimpleConnector):
def __init__(self, connection_info: ConnectionInfo):
super().__init__(DataSource.mysql, connection_info)

def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table:
result_table = ibis_table
for name, dtype in ibis_table.schema().items():
if isinstance(dtype, Decimal):
# Round decimal columns to a specified scale
result_table = self._round_decimal_columns(
result_table=result_table, col_name=name, **kwargs
)
elif isinstance(dtype, UUID):
# Convert UUID to string for compatibility
result_table = self._cast_uuid_columns(
result_table=result_table, col_name=name
)
elif isinstance(dtype, dt.JSON):
# ibis doesn't handle JSON type for MySQL properly.
# We need to convert JSON columns to string for compatibility manually.
result_table = self._cast_json_columns(
result_table=result_table, col_name=name
)

return result_table

def _cast_json_columns(self, result_table: Table, col_name: str) -> Table:
col = result_table[col_name]
# Convert JSON to string for compatibility
casted_col = col.cast("string")
return result_table.mutate(**{col_name: casted_col})


class MSSqlConnector(SimpleConnector):
def __init__(self, connection_info: ConnectionInfo):
super().__init__(DataSource.mssql, connection_info)
Expand Down
52 changes: 52 additions & 0 deletions ibis-server/tests/routers/v3/connector/mysql/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import pathlib

import pytest
import sqlalchemy
from testcontainers.mysql import MySqlContainer

from app.config import get_config
from tests.conftest import file_path

pytestmark = pytest.mark.mysql

base_url = "/v3/connector/mysql"
Expand All @@ -18,10 +22,58 @@ def pytest_collection_modifyitems(items):
@pytest.fixture(scope="session")
def mysql(request) -> MySqlContainer:
mysql = MySqlContainer(image="mysql:8.0.40", dialect="pymysql").start()
connection_url = mysql.get_connection_url()
engine = sqlalchemy.create_engine(connection_url)
with engine.connect() as conn:
conn.execute(
sqlalchemy.text(
"""
CREATE TABLE json_test (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
object_col JSON NOT NULL,
array_col JSON NOT NULL,
CHECK (JSON_TYPE(object_col) = 'OBJECT'),
CHECK (JSON_TYPE(array_col) = 'ARRAY')
) ENGINE=InnoDB;
"""
)
)
conn.execute(
sqlalchemy.text(
"""
INSERT INTO json_test (object_col, array_col) VALUES
('{"name": "Alice", "age": 30, "city": "New York"}', '["apple", "banana", "cherry"]'),
('{"name": "Bob", "age": 25, "city": "Los Angeles"}', '["dog", "cat", "mouse"]'),
('{"name": "Charlie", "age": 35, "city": "Chicago"}', '["red", "green", "blue"]');
"""
)
)
conn.commit()

request.addfinalizer(mysql.stop)
return mysql


function_list_path = file_path("../resources/function_list")
white_function_list_path = file_path("../resources/white_function_list")


@pytest.fixture(autouse=True)
def set_remote_function_list_path():
config = get_config()
config.set_remote_function_list_path(function_list_path)
yield
config.set_remote_function_list_path(None)


@pytest.fixture(autouse=True)
def set_remote_white_function_list_path():
config = get_config()
config.set_remote_white_function_list_path(white_function_list_path)
yield
config.set_remote_white_function_list_path(None)


@pytest.fixture(scope="module")
def connection_info(mysql: MySqlContainer) -> dict[str, str]:
return {
Expand Down
27 changes: 6 additions & 21 deletions ibis-server/tests/routers/v3/connector/mysql/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import pytest

from app.config import get_config
from tests.conftest import DATAFUSION_FUNCTION_COUNT, file_path
from tests.routers.v3.connector.mysql.conftest import base_url
from tests.conftest import DATAFUSION_FUNCTION_COUNT
from tests.routers.v3.connector.mysql.conftest import (
base_url,
function_list_path,
white_function_list_path,
)

manifest = {
"dataSource": "mysql",
Expand All @@ -25,31 +29,12 @@
],
}

function_list_path = file_path("../resources/function_list")
white_function_list_path = file_path("../resources/white_function_list")


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


@pytest.fixture(autouse=True)
def set_remote_function_list_path():
config = get_config()
config.set_remote_function_list_path(function_list_path)
yield
config.set_remote_function_list_path(None)


@pytest.fixture(autouse=True)
def set_remote_white_function_list_path():
config = get_config()
config.set_remote_white_function_list_path(white_function_list_path)
yield
config.set_remote_white_function_list_path(None)


async def test_function_list(client):
config = get_config()

Expand Down
67 changes: 67 additions & 0 deletions ibis-server/tests/routers/v3/connector/mysql/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import base64

import orjson
import pytest

from app.dependencies import X_WREN_FALLBACK_DISABLE
from tests.routers.v3.connector.mysql.conftest import base_url

manifest = {
"catalog": "wren",
"schema": "public",
"models": [
{
"name": "json_test",
"tableReference": {
"table": "json_test",
},
"columns": [
{"name": "id", "type": "bigint"},
{"name": "object_col", "type": "json"},
{"name": "array_col", "type": "json"},
],
},
{
"name": "orders",
"tableReference": {
"table": "orders",
},
"columns": [
{"name": "o_orderkey", "type": "integer"},
{"name": "o_orderdate", "type": "date"},
],
},
],
"dataSource": "mysql",
}


@pytest.fixture(scope="module")
async def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


async def test_json_query(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT object_col, array_col FROM wren.public.json_test limit 1",
},
headers={
X_WREN_FALLBACK_DISABLE: "true",
},
)
assert response.status_code == 200
result = response.json()
assert result["data"] == [
[
'{"age": 30, "city": "New York", "name": "Alice"}',
'["apple", "banana", "cherry"]',
]
]
assert result["dtypes"] == {
"object_col": "string",
"array_col": "string",
}