|
1 |
| -import re |
2 |
| -from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple |
| 1 | +from typing import Any, List, Optional, Dict, Union |
3 | 2 |
|
4 | 3 | import databricks.sqlalchemy._ddl as dialect_ddl_impl
|
5 | 4 | import databricks.sqlalchemy._types as dialect_type_impl
|
|
11 | 10 | build_pk_dict,
|
12 | 11 | get_fk_strings_from_dte_output,
|
13 | 12 | get_pk_strings_from_dte_output,
|
| 13 | + get_comment_from_dte_output, |
14 | 14 | parse_column_info_from_tgetcolumnsresponse,
|
15 | 15 | )
|
16 | 16 |
|
17 | 17 | import sqlalchemy
|
18 | 18 | from sqlalchemy import DDL, event
|
19 | 19 | from sqlalchemy.engine import Connection, Engine, default, reflection
|
20 |
| -from sqlalchemy.engine.reflection import ObjectKind |
21 | 20 | from sqlalchemy.engine.interfaces import (
|
22 | 21 | ReflectedForeignKeyConstraint,
|
23 | 22 | ReflectedPrimaryKeyConstraint,
|
24 | 23 | ReflectedColumn,
|
25 |
| - TableKey, |
| 24 | + ReflectedTableComment, |
26 | 25 | )
|
| 26 | +from sqlalchemy.engine.reflection import ReflectionDefaults |
27 | 27 | from sqlalchemy.exc import DatabaseError, SQLAlchemyError
|
28 | 28 |
|
29 | 29 | try:
|
@@ -285,7 +285,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs):
|
285 | 285 | views_result = self.get_view_names(connection=connection, schema=schema)
|
286 | 286 |
|
287 | 287 | # In Databricks, SHOW TABLES FROM <schema> returns both tables and views.
|
288 |
| - # Potential optimisation: rewrite this to instead query informtation_schema |
| 288 | + # Potential optimisation: rewrite this to instead query information_schema |
289 | 289 | tables_minus_views = [
|
290 | 290 | row.tableName for row in tables_result if row.tableName not in views_result
|
291 | 291 | ]
|
@@ -328,7 +328,7 @@ def get_materialized_view_names(
|
328 | 328 | def get_temp_view_names(
|
329 | 329 | self, connection: Connection, schema: Optional[str] = None, **kw: Any
|
330 | 330 | ) -> List[str]:
|
331 |
| - """A wrapper around get_view_names taht fetches only the names of temporary views""" |
| 331 | + """A wrapper around get_view_names that fetches only the names of temporary views""" |
332 | 332 | return self.get_view_names(connection, schema, only_temp=True)
|
333 | 333 |
|
334 | 334 | def do_rollback(self, dbapi_connection):
|
@@ -375,6 +375,30 @@ def get_schema_names(self, connection, **kw):
|
375 | 375 | schema_list = [row[0] for row in result]
|
376 | 376 | return schema_list
|
377 | 377 |
|
| 378 | + @reflection.cache |
| 379 | + def get_table_comment( |
| 380 | + self, |
| 381 | + connection: Connection, |
| 382 | + table_name: str, |
| 383 | + schema: Optional[str] = None, |
| 384 | + **kw: Any, |
| 385 | + ) -> ReflectedTableComment: |
| 386 | + result = self._describe_table_extended( |
| 387 | + connection=connection, |
| 388 | + table_name=table_name, |
| 389 | + schema_name=schema, |
| 390 | + ) |
| 391 | + |
| 392 | + if result is None: |
| 393 | + return ReflectionDefaults.table_comment() |
| 394 | + |
| 395 | + comment = get_comment_from_dte_output(result) |
| 396 | + |
| 397 | + if comment: |
| 398 | + return dict(text=comment) |
| 399 | + else: |
| 400 | + return ReflectionDefaults.table_comment() |
| 401 | + |
378 | 402 |
|
379 | 403 | @event.listens_for(Engine, "do_connect")
|
380 | 404 | def receive_do_connect(dialect, conn_rec, cargs, cparams):
|
|
0 commit comments