Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,29 @@ def _get_cursor(self, raw_cursor: str) -> CursorType:
valid_cursors = ", ".join(cursor_types.keys())
raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are: {valid_cursors}")

def _get_cursor_config(self, raw_cursor: str) -> tuple[str, Any]:
cursor = self._get_cursor(raw_cursor)

if USE_PSYCOPG3:
return "row_factory", cursor

return "cursor_factory", cursor

def _create_connection(self, conn_args: dict[str, Any]) -> CompatConnection:
if USE_PSYCOPG3:
from psycopg.connection import Connection as pgConnection

connection = pgConnection.connect(**cast("Any", conn_args))

register_default_adapters(connection)

if self.enable_log_db_messages and hasattr(connection, "add_notice_handler"):
connection.add_notice_handler(self._notice_handler)

return connection

return ppg2_connect(**conn_args)

def _generate_cursor_name(self):
"""Generate a unique name for server-side cursor."""
import uuid
Expand Down Expand Up @@ -262,30 +285,13 @@ def get_conn(self) -> CompatConnection:
if arg_name not in self.ignored_extra_options:
conn_args[arg_name] = arg_val

if USE_PSYCOPG3:
from psycopg.connection import Connection as pgConnection

raw_cursor = conn.extra_dejson.get("cursor")
if raw_cursor:
conn_args["row_factory"] = self._get_cursor(raw_cursor)

# Use Any type for the connection args to avoid type conflicts
connection = pgConnection.connect(**cast("Any", conn_args))
self.conn = cast("CompatConnection", connection)

# Register JSON handlers for both json and jsonb types
# This ensures JSON data is properly decoded from bytes to Python objects
register_default_adapters(connection)
raw_cursor = conn.extra_dejson.get("cursor")

# Add the notice handler AFTER the connection is established
if self.enable_log_db_messages and hasattr(self.conn, "add_notice_handler"):
self.conn.add_notice_handler(self._notice_handler)
else: # psycopg2
raw_cursor = conn.extra_dejson.get("cursor", False)
if raw_cursor:
conn_args["cursor_factory"] = self._get_cursor(raw_cursor)
if raw_cursor:
key, value = self._get_cursor_config(raw_cursor)
conn_args[key] = value
Comment thread
SameerMesiah97 marked this conversation as resolved.

self.conn = cast("CompatConnection", ppg2_connect(**conn_args))
self.conn = self._create_connection(conn_args)

return self.conn

Expand Down
Loading
Loading