Skip to content

Commit

Permalink
feat(sql): support inserts with default constraints (#9844)
Browse files Browse the repository at this point in the history
Co-authored-by: Phillip Cloud <[email protected]>
  • Loading branch information
IndexSeek and cpcloud authored Aug 25, 2024
1 parent 949fbea commit 86a3c06
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def current_catalog(self) -> str:

@property
def current_database(self) -> str:
return NotImplementedError()
raise NotImplementedError()

def list_catalogs(self, like: str | None = None) -> list[str]:
code = (
Expand Down
11 changes: 5 additions & 6 deletions ibis/backends/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,13 @@ def _build_insert_from_table(
compiler = self.compiler
quoted = compiler.quoted
# Compare the columns between the target table and the object to be inserted
# If they don't match, assume auto-generated column names and use positional
# ordering.
source_cols = source.columns
# If source is a subset of target, use source columns for insert list
# Otherwise, assume auto-generated column names and use positional ordering.
target_cols = self.get_schema(target).keys()

columns = (
source_cols
if not set(target_cols := self.get_schema(target).names).difference(
source_cols
)
if (source_cols := source.schema().keys()) <= target_cols
else target_cols
)

Expand Down
43 changes: 43 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import pytest
import rich.console
import sqlglot as sg
import toolz
from packaging.version import parse as vparse
from pytest import mark, param
Expand All @@ -33,6 +34,7 @@
OracleDatabaseError,
PsycoPg2InternalError,
PsycoPg2UndefinedObject,
Py4JJavaError,
PyODBCProgrammingError,
PySparkAnalysisException,
SnowflakeProgrammingError,
Expand Down Expand Up @@ -1738,3 +1740,44 @@ def test_cross_database_join(con_create_database, monkeypatch):
con.drop_table(left_table)
con.drop_table(right_table, database=dbname)
con.drop_database(dbname)


@pytest.mark.notimpl(
["druid"], raises=AttributeError, reason="doesn't implement `raw_sql`"
)
@pytest.mark.notimpl(["clickhouse"], reason="create table isn't implemented")
@pytest.mark.notyet(["flink"], raises=Py4JJavaError)
@pytest.mark.notyet(["pandas", "dask", "polars"], reason="Doesn't support insert")
@pytest.mark.notyet(["exasol"], reason="Backend does not support raw_sql")
@pytest.mark.notimpl(
["impala", "pyspark", "trino"], reason="Default constraints are not supported"
)
def test_insert_into_table_missing_columns(con, temp_table):
try:
db = getattr(con, "current_database", None)
except NotImplementedError:
db = None

# UGH
if con.name == "oracle":
db = None

try:
catalog = getattr(con, "current_catalog", None)
except NotImplementedError:
catalog = None

raw_ident = ".".join(
sg.to_identifier(i, quoted=True).sql("duckdb")
for i in filter(None, (catalog, db, temp_table))
)

ct_sql = f'CREATE TABLE {raw_ident} ("a" INT DEFAULT 1, "b" INT)'
sg_expr = sg.parse_one(ct_sql, read="duckdb")
con.raw_sql(sg_expr.sql(dialect=con.dialect))
con.insert(temp_table, [{"b": 1}])

result = con.table(temp_table).to_pyarrow().to_pydict()
expected_result = {"a": [1], "b": [1]}

assert result == expected_result

0 comments on commit 86a3c06

Please sign in to comment.