Skip to content

Commit

Permalink
fix(postgres): fix insertion of NaT/None into timestamp columns (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Jan 29, 2025
1 parent 2da91fb commit 847ed85
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 3 deletions.
10 changes: 9 additions & 1 deletion ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import TYPE_CHECKING, Any
from urllib.parse import unquote_plus

import psycopg
import sqlglot as sg
import sqlglot.expressions as sge
from pandas.api.types import is_float_dtype
Expand All @@ -31,10 +32,14 @@

import pandas as pd
import polars as pl
import psycopg
import pyarrow as pa


class NatDumper(psycopg.adapt.Dumper):
    """psycopg dumper that maps every value it is given to ``None``.

    ``do_connect`` registers this dumper for ``type(pd.NaT)``, so a pandas
    ``NaT`` passed as a query parameter is dumped as ``None``.
    """

    def dump(self, obj, context: Any | None = None) -> str | None:
        """Ignore ``obj`` and ``context``; always produce ``None``."""
        # NOTE(review): psycopg is expected to treat a ``None`` dump result as
        # SQL NULL -- confirm against the psycopg version in use.
        return None


class Backend(SQLBackend, CanListCatalog, CanCreateDatabase):
name = "postgres"
compiler = sc.postgres.compiler
Expand Down Expand Up @@ -233,6 +238,7 @@ def do_connect(
year int32
month int32
"""
import pandas as pd
import psycopg
import psycopg.types.json

Expand All @@ -248,6 +254,8 @@ def do_connect(
**kwargs,
)

self.con.adapters.register_dumper(type(pd.NaT), NatDumper)

self._post_connect()

@util.experimental
Expand Down
16 changes: 16 additions & 0 deletions ibis/backends/postgres/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis import literal as L
from ibis.util import gen_name

pytest.importorskip("psycopg")

Expand Down Expand Up @@ -1230,3 +1231,18 @@ def test_array_discovery(con):
)
)
assert t.schema() == expected


@pytest.mark.parametrize("tz", [None, "UTC", "America/New_York"])
def test_insert_null_timestamp(con, tz):
    """Create a table from a timestamp column that contains a missing value.

    The frame holds one real timestamp (naive or tz-aware, per ``tz``) and
    one ``None``. The resulting column must keep the input timezone, and the
    missing entry must not be counted as a value.
    """
    stamp = pd.Timestamp(datetime(2025, 1, 3, 12, 37, 38, 234236), tz=tz)
    frame = pd.DataFrame({"ts": [stamp, None]})
    table_name = gen_name("test_insert_nat")

    table = con.create_table(table_name, obj=frame, temp=True)

    # the column's timezone must mirror the timezone of the input timestamp
    assert table.ts.type().timezone == tz

    # two rows went in; only the real timestamp should be counted, which
    # shows the missing value was stored as NULL
    non_null_rows = table.ts.count().execute()
    assert non_null_rows == 1
4 changes: 2 additions & 2 deletions ibis/backends/risingwave/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def begin(self):
def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
import pandas as pd

from ibis.backends.postgres.converter import PostgresPandasData
from ibis.backends.risingwave.converter import RisingWavePandasData

try:
df = pd.DataFrame.from_records(
Expand All @@ -135,7 +135,7 @@ def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
# artificially locked tables
cursor.close()
raise
df = PostgresPandasData.convert_table(df, schema)
df = RisingWavePandasData.convert_table(df, schema)
return df

@property
Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/risingwave/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from ibis.formats.pandas import PandasData


class RisingWavePandasData(PandasData):
    """``PandasData`` subclass with a RisingWave-specific binary conversion."""

    @classmethod
    def convert_Binary(cls, s, dtype, pandas_type):
        """Return ``s`` with each non-missing element passed through ``bytes``.

        ``dtype`` and ``pandas_type`` are accepted but not used by this
        conversion.
        """
        # na_action="ignore" leaves missing entries as they are instead of
        # handing them to bytes()
        converted = s.map(bytes, na_action="ignore")
        return converted

0 comments on commit 847ed85

Please sign in to comment.