Skip to content

Commit 7ecddea

Browse files
author
Jesse
authored
[sqlalchemy] Add table and column comment support (databricks#329)
Signed-off-by: Christophe Bornet <[email protected]> Signed-off-by: Jesse Whitehouse <[email protected]> Co-authored-by: Jesse Whitehouse <[email protected]>
1 parent 918752f commit 7ecddea

File tree

9 files changed

+285
-63
lines changed

9 files changed

+285
-63
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# 3.1.0 (TBD)
44

5+
- SQLAlchemy: Added support for table and column comments (thanks @cbornet!)
56
- Fix: `server_hostname` URIs that included `https://` would raise an exception
67

78
## 3.0.1 (2023-12-01)

src/databricks/sqlalchemy/_ddl.py

+36-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from sqlalchemy.sql import compiler
2+
from sqlalchemy.sql import compiler, sqltypes
33
import logging
44

55
logger = logging.getLogger(__name__)
@@ -16,7 +16,13 @@ def __init__(self, dialect):
1616

1717
class DatabricksDDLCompiler(compiler.DDLCompiler):
1818
def post_create_table(self, table):
19-
return " USING DELTA"
19+
post = " USING DELTA"
20+
if table.comment:
21+
comment = self.sql_compiler.render_literal_value(
22+
table.comment, sqltypes.String()
23+
)
24+
post += " COMMENT " + comment
25+
return post
2026

2127
def visit_unique_constraint(self, constraint, **kw):
2228
logger.warning("Databricks does not support unique constraints")
@@ -39,17 +45,40 @@ def visit_identity_column(self, identity, **kw):
3945
)
4046
return text
4147

48+
def visit_set_column_comment(self, create, **kw):
49+
return "ALTER TABLE %s ALTER COLUMN %s COMMENT %s" % (
50+
self.preparer.format_table(create.element.table),
51+
self.preparer.format_column(create.element),
52+
self.sql_compiler.render_literal_value(
53+
create.element.comment, sqltypes.String()
54+
),
55+
)
56+
57+
def visit_drop_column_comment(self, create, **kw):
58+
return "ALTER TABLE %s ALTER COLUMN %s COMMENT ''" % (
59+
self.preparer.format_table(create.element.table),
60+
self.preparer.format_column(create.element),
61+
)
62+
4263
def get_column_specification(self, column, **kwargs):
43-
"""Currently we override this method only to emit a log message if a user attempts to set
44-
autoincrement=True on a column. See comments in test_suite.py. We may implement implicit
45-
IDENTITY using this feature in the future, similar to the Microsoft SQL Server dialect.
64+
"""
65+
Emit a log message if a user attempts to set autoincrement=True on a column.
66+
See comments in test_suite.py. We may implement implicit IDENTITY using this
67+
feature in the future, similar to the Microsoft SQL Server dialect.
4668
"""
4769
if column is column.table._autoincrement_column or column.autoincrement is True:
48-
logger.warn(
70+
logger.warning(
4971
"Databricks dialect ignores SQLAlchemy's autoincrement semantics. Use explicit Identity() instead."
5072
)
5173

52-
return super().get_column_specification(column, **kwargs)
74+
colspec = super().get_column_specification(column, **kwargs)
75+
if column.comment is not None:
76+
literal = self.sql_compiler.render_literal_value(
77+
column.comment, sqltypes.STRINGTYPE
78+
)
79+
colspec += " COMMENT " + literal
80+
81+
return colspec
5382

5483

5584
class DatabricksStatementCompiler(compiler.SQLCompiler):

src/databricks/sqlalchemy/_parse.py

+24
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,20 @@ def match_dte_rows_by_value(dte_output: List[Dict[str, str]], match: str) -> Lis
251251
return output_rows
252252

253253

254+
def match_dte_rows_by_key(dte_output: List[Dict[str, str]], match: str) -> List[dict]:
255+
"""Return a list of dictionaries containing only the col_name:data_type pairs where the `col_name`
256+
value contains the match argument.
257+
"""
258+
259+
output_rows = []
260+
261+
for row_dict in dte_output:
262+
if match in row_dict["col_name"]:
263+
output_rows.append(row_dict)
264+
265+
return output_rows
266+
267+
254268
def get_fk_strings_from_dte_output(dte_output: List[Dict[str, str]]) -> List[dict]:
255269
"""If the DESCRIBE TABLE EXTENDED output contains foreign key constraints, return a list of dictionaries,
256270
one dictionary per defined constraint
@@ -275,6 +289,15 @@ def get_pk_strings_from_dte_output(
275289
return output
276290

277291

292+
def get_comment_from_dte_output(dte_output: List[Dict[str, str]]) -> Optional[str]:
293+
"""Returns the value of the first "Comment" col_name data in dte_output"""
294+
output = match_dte_rows_by_key(dte_output, "Comment")
295+
if not output:
296+
return None
297+
else:
298+
return output[0]["data_type"]
299+
300+
278301
# The keys of this dictionary are the values we expect to see in a
279302
# TGetColumnsRequest's .TYPE_NAME attribute.
280303
# These are enumerated in ttypes.py as class TTypeId.
@@ -354,6 +377,7 @@ def parse_column_info_from_tgetcolumnsresponse(thrift_resp_row) -> ReflectedColu
354377
"type": final_col_type,
355378
"nullable": bool(thrift_resp_row.NULLABLE),
356379
"default": thrift_resp_row.COLUMN_DEF,
380+
"comment": thrift_resp_row.REMARKS or None,
357381
}
358382

359383
# TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects

src/databricks/sqlalchemy/base.py

+30-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import re
2-
from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple
1+
from typing import Any, List, Optional, Dict, Union
32

43
import databricks.sqlalchemy._ddl as dialect_ddl_impl
54
import databricks.sqlalchemy._types as dialect_type_impl
@@ -11,19 +10,20 @@
1110
build_pk_dict,
1211
get_fk_strings_from_dte_output,
1312
get_pk_strings_from_dte_output,
13+
get_comment_from_dte_output,
1414
parse_column_info_from_tgetcolumnsresponse,
1515
)
1616

1717
import sqlalchemy
1818
from sqlalchemy import DDL, event
1919
from sqlalchemy.engine import Connection, Engine, default, reflection
20-
from sqlalchemy.engine.reflection import ObjectKind
2120
from sqlalchemy.engine.interfaces import (
2221
ReflectedForeignKeyConstraint,
2322
ReflectedPrimaryKeyConstraint,
2423
ReflectedColumn,
25-
TableKey,
24+
ReflectedTableComment,
2625
)
26+
from sqlalchemy.engine.reflection import ReflectionDefaults
2727
from sqlalchemy.exc import DatabaseError, SQLAlchemyError
2828

2929
try:
@@ -285,7 +285,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs):
285285
views_result = self.get_view_names(connection=connection, schema=schema)
286286

287287
# In Databricks, SHOW TABLES FROM <schema> returns both tables and views.
288-
# Potential optimisation: rewrite this to instead query informtation_schema
288+
# Potential optimisation: rewrite this to instead query information_schema
289289
tables_minus_views = [
290290
row.tableName for row in tables_result if row.tableName not in views_result
291291
]
@@ -328,7 +328,7 @@ def get_materialized_view_names(
328328
def get_temp_view_names(
329329
self, connection: Connection, schema: Optional[str] = None, **kw: Any
330330
) -> List[str]:
331-
"""A wrapper around get_view_names taht fetches only the names of temporary views"""
331+
"""A wrapper around get_view_names that fetches only the names of temporary views"""
332332
return self.get_view_names(connection, schema, only_temp=True)
333333

334334
def do_rollback(self, dbapi_connection):
@@ -375,6 +375,30 @@ def get_schema_names(self, connection, **kw):
375375
schema_list = [row[0] for row in result]
376376
return schema_list
377377

378+
@reflection.cache
379+
def get_table_comment(
380+
self,
381+
connection: Connection,
382+
table_name: str,
383+
schema: Optional[str] = None,
384+
**kw: Any,
385+
) -> ReflectedTableComment:
386+
result = self._describe_table_extended(
387+
connection=connection,
388+
table_name=table_name,
389+
schema_name=schema,
390+
)
391+
392+
if result is None:
393+
return ReflectionDefaults.table_comment()
394+
395+
comment = get_comment_from_dte_output(result)
396+
397+
if comment:
398+
return dict(text=comment)
399+
else:
400+
return ReflectionDefaults.table_comment()
401+
378402

379403
@event.listens_for(Engine, "do_connect")
380404
def receive_do_connect(dialect, conn_rec, cargs, cparams):

src/databricks/sqlalchemy/requirements.py

+12
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,18 @@ def table_reflection(self):
159159
"""target database has general support for table reflection"""
160160
return sqlalchemy.testing.exclusions.open()
161161

162+
@property
163+
def comment_reflection(self):
164+
"""Indicates if the database support table comment reflection"""
165+
return sqlalchemy.testing.exclusions.open()
166+
167+
@property
168+
def comment_reflection_full_unicode(self):
169+
"""Indicates if the database support table comment reflection in the
170+
full unicode range, including emoji etc.
171+
"""
172+
return sqlalchemy.testing.exclusions.open()
173+
162174
@property
163175
def temp_table_reflection(self):
164176
"""ComponentReflection test is intricate and simply cannot function without this exclusion being defined here.

src/databricks/sqlalchemy/test/_future.py

-48
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@
1313
ComponentReflectionTest,
1414
ComponentReflectionTestExtra,
1515
CTETest,
16-
FutureTableDDLTest,
1716
InsertBehaviorTest,
18-
TableDDLTest,
1917
)
2018
from sqlalchemy.testing.suite import (
2119
ArrayTest,
@@ -53,7 +51,6 @@ class FutureFeature(Enum):
5351
PROVISION = "event-driven engine configuration"
5452
REGEXP = "_visit_regexp"
5553
SANE_ROWCOUNT = "sane_rowcount support"
56-
TBL_COMMENTS = "table comment reflection"
5754
TBL_OPTS = "get_table_options method"
5855
TEST_DESIGN = "required test-fixture overrides"
5956
TUPLE_LITERAL = "tuple-like IN markers completely"
@@ -251,36 +248,7 @@ class FutureWeCanSetDefaultSchemaWEventsTest(FutureWeCanSetDefaultSchemaWEventsT
251248
pass
252249

253250

254-
class FutureTableDDLTest(FutureTableDDLTest):
255-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
256-
def test_add_table_comment(self):
257-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
258-
pass
259-
260-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
261-
def test_drop_table_comment(self):
262-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
263-
pass
264-
265-
266-
class TableDDLTest(TableDDLTest):
267-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
268-
def test_add_table_comment(self, connection):
269-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
270-
pass
271-
272-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
273-
def test_drop_table_comment(self, connection):
274-
"""We could use requirements.comment_reflection here to disable this but prefer a more meaningful skip message"""
275-
pass
276-
277-
278251
class ComponentReflectionTest(ComponentReflectionTest):
279-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
280-
def test_get_multi_table_comment(self):
281-
"""There are 84 permutations of this test that are skipped."""
282-
pass
283-
284252
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_OPTS, True))
285253
def test_multi_get_table_options_tables(self):
286254
"""It's not clear what the expected ouput from this method would even _be_. Requires research."""
@@ -302,22 +270,6 @@ def test_get_multi_pk_constraint(self):
302270
def test_get_multi_check_constraints(self):
303271
pass
304272

305-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
306-
def test_get_comments(self):
307-
pass
308-
309-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
310-
def test_get_comments_with_schema(self):
311-
pass
312-
313-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
314-
def test_comments_unicode(self):
315-
pass
316-
317-
@pytest.mark.skip(reason=render_future_feature(FutureFeature.TBL_COMMENTS))
318-
def test_comments_unicode_full(self):
319-
pass
320-
321273

322274
class ComponentReflectionTestExtra(ComponentReflectionTestExtra):
323275
@pytest.mark.skip(render_future_feature(FutureFeature.CHECK))

0 commit comments

Comments
 (0)