Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 80 additions & 26 deletions db2_sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@
)


def _add_dummy_table_if_needed(expression: exp.Select) -> exp.Select:
"""
Add SYSIBM.SYSDUMMY1 as FROM clause if SELECT has no FROM.
Db2 requires a FROM clause for all SELECT statements.
"""
# Note: SQLGlot uses 'from_' (with underscore) as the key for FROM clauses
if not expression.args.get("from_"):
from_table = exp.From(
this=exp.Table(
this=exp.Identifier(this="SYSDUMMY1"),
db=exp.Identifier(this="SYSIBM")
)
)
expression.set("from_", from_table)
return expression


def _date_add_sql(
kind: str,
) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]:
Expand Down Expand Up @@ -83,6 +100,8 @@ class Db2(generator.Generator):
exp.DType.NVARCHAR: "NVARCHAR",
exp.DType.TIMESTAMPTZ: "TIMESTAMP",
exp.DType.DATETIME: "TIMESTAMP",
# UUID is not natively supported in Db2, use CHAR(36)
exp.DType.UUID: "CHAR(36)",
}

AFTER_HAVING_MODIFIER_TRANSFORMS = {
Expand All @@ -91,33 +110,34 @@ class Db2(generator.Generator):
"sort": lambda self, e: "",
}

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ArgMax: rename_func("MAX"),
exp.ArgMin: rename_func("MIN"),
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.DateDiff: lambda self, e: (
f"{self.func('DAYS', e.this)} - "
f"{self.func('DAYS', e.expression)}"
),
exp.CurrentDate: lambda self, e: "CURRENT DATE",
exp.CurrentTimestamp: lambda self, e: "CURRENT TIMESTAMP",
exp.ILike: no_ilike_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.StrPosition: rename_func("POSSTR"),
exp.TimeToStr: rename_func("VARCHAR_FORMAT"),
exp.TryCast: no_trycast_sql,
exp.Trim: trim_sql,
}

TRANSFORMS = {**generator.Generator.TRANSFORMS,
exp.ArgMax: rename_func("MAX"),
exp.ArgMin: rename_func("MIN"),
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
exp.DateDiff: lambda self,
e: f"{self.func('DAYS', e.this)} - {self.func('DAYS', e.expression)}",
exp.CurrentDate: lambda self,
e: "CURRENT DATE",
exp.CurrentTimestamp: lambda self,
e: "CURRENT TIMESTAMP",
exp.ILike: no_ilike_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on,
_add_dummy_table_if_needed,
]),
exp.StrPosition: rename_func("POSSTR"),
exp.TimeToStr: rename_func("VARCHAR_FORMAT"),
exp.TryCast: no_trycast_sql,
exp.Trim: trim_sql,
}

# Note: Db2-specific types (GRAPHIC, VARGRAPHIC, DBCLOB) are automatically
# handled by SQLGlot's default datatype_sql() when parsed as USERDEFINED
# types. The 'kind' field preserves the original type name, so no custom
# override needed.
# handled by SQLGlot's default datatype_sql() when parsed as USERDEFINED types.
# The 'kind' field preserves the original type name, so no custom override needed.

def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
Expand All @@ -138,3 +158,37 @@ def fetch_sql(self, expression: exp.Fetch) -> str:

def boolean_sql(self, expression: exp.Boolean) -> str:
return "1" if expression.this else "0"

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
"""
Override cast_sql to handle UUID casts.
Since Db2 doesn't have native UUID type, we use CHAR(36).
When casting to UUID, we just return the value without the cast.
"""
# Check if casting to UUID
to_type = expression.to
if isinstance(to_type, exp.DataType) and to_type.this == exp.DataType.Type.UUID:
# Just return the expression being cast, without the CAST
return self.sql(expression.this)

# For all other casts, use the default behavior
return super().cast_sql(expression, safe_prefix)

def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
"""
Override columnconstraint_sql to handle UUID DEFAULT values.
Db2 doesn't have UUID generation functions, so we replace
gen_random_uuid() and similar functions with a placeholder.
"""
kind = expression.args.get("kind")

# Check if this is a DEFAULT constraint with UUID generation
if isinstance(kind, exp.DefaultColumnConstraint):
default_value = kind.this
# Check if the default is a Uuid() function call
if isinstance(default_value, exp.Uuid):
# Replace with a simple string default
# Note: In production, this should be handled with a trigger or application logic
kind.set("this", exp.Literal.string("0"))

return super().columnconstraint_sql(expression)
Loading