Skip to content

Feat/compiler patcher #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
venv

# Unit test / coverage reports
htmlcov/
Expand All @@ -45,4 +46,4 @@ docs/_build/
.idea/

# Alembic
tests/migrations/versions/*
tests/migrations/versions/*
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

test:
#? Run test suite
@cd tests && pytest
96 changes: 93 additions & 3 deletions sqlalchemy_timescaledb/dialect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlalchemy import schema, event, DDL
from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler, PGDialect
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.sql.elements import ClauseElement

try:
import alembic
Expand All @@ -13,8 +14,79 @@
class TimescaledbImpl(postgresql.PostgresqlImpl):
__dialect__ = 'timescaledb'

def all_subclasses(cls, include_cls: bool = True) -> set:
    """
    Return every direct and indirect subclass of ``cls``.

    A recursive version of ``cls.__subclasses__()`` (i.e. including subclasses
    of subclasses). Each class is visited exactly once, so diamond-shaped or
    deep hierarchies do not cause repeated work.

    :param cls: the class whose subclass tree is collected.
    :param include_cls: when true (the default), ``cls`` itself is included
        in the returned set.
    :return: a set containing the subclasses (and ``cls`` when requested).
    :raises ValueError: if ``cls`` has no ``__subclasses__`` attribute
        (i.e. it is not a class).
    """
    if not hasattr(cls, "__subclasses__"):
        # Only non-class objects get here, so report the name of their type.
        raise ValueError(f"Can't get subclasses of {type(cls).__name__}")

    found = set()
    # Explicit worklist: never mutate a list while iterating over it.
    pending = list(cls.__subclasses__())
    while pending:
        subcls = pending.pop()
        if subcls in found:
            # Already reached through another base class.
            continue
        found.add(subcls)
        pending.extend(subcls.__subclasses__())

    if include_cls:
        found.add(cls)

    return found



class TimescaledbDDLCompiler(PGDDLCompiler):
def __init__(self, *args, **kwargs):
    """
    Initialise the underlying postgres DDL compiler, then make the
    postgres-specific compilers available to the timescaledb dialect.

    All positional and keyword arguments are forwarded unchanged to the
    parent class.
    """
    super().__init__(*args, **kwargs)

    # Patch sqlalchemy so that postgres compilers are used for the
    # timescaledb dialect. This runs each time a compiler is instantiated,
    # rather than at import time, so that load order cannot cause problems
    # (e.g. timescaledb being loaded before postgresql or before some
    # postgresql extension).
    self.patch_postgres_compilers()

@staticmethod
def patch_postgres_compilers():
    """
    Make postgres-specific compilers also apply to the timescaledb dialect.

    Iterates over every ClauseElement subclass, looks for compilers registered
    for the 'postgresql' dialect, and registers the same compiler under
    'timescaledb'.

    This allows the timescaledb dialect to use postgres compilers which
    specify the postgresql dialect via e.g.
        @compiles(class, dialect="postgresql")
        something.execute_if(dialect="postgresql")

    (Without this, the "dialect == postgresql" test fails for these compilers
    when running on the timescaledb dialect, because the dialect name is
    "timescaledb", not "postgresql".)

    The compiler dispatcher works by having a 'specs' dict, with the key being
    the db dialect and the value being the compiler for that type of SQL
    clause for that dialect. When compiling, it picks the dialect-specific
    compiler with something like:

        if dialect not in dispatcher.specs: compiler = default_compiler
        else: compiler = dispatcher.specs[dialect]

        compiled_sql = compiler(some_clauseelement)

    So a compiler registered only for 'postgresql' is skipped in favour of the
    default compiler when the current dialect is 'timescaledb'. Copying the
    'postgresql' entry to a 'timescaledb' entry makes both dialects behave the
    same, and saves re-implementing timescaledb compilers for everything (like
    the example commented out at the end of this file).

    A compiler that was registered explicitly for 'timescaledb' is left
    untouched: the postgres compiler is only used as a fallback when no
    'timescaledb' entry exists yet.
    """
    for cls in all_subclasses(ClauseElement):
        dispatcher = getattr(cls, "_compiler_dispatcher", None)
        specs = getattr(dispatcher, "specs", None)
        if specs is None or 'postgresql' not in specs:
            continue

        # setdefault (not plain assignment) so that a dedicated timescaledb
        # compiler is never overwritten; this method runs on every compiler
        # instantiation.
        specs.setdefault('timescaledb', specs['postgresql'])

def post_create_table(self, table):
hypertable = table.kwargs.get('timescaledb_hypertable', {})

Expand Down Expand Up @@ -54,7 +126,7 @@ def ddl_hypertable(table_name, hypertable):
)


class TimescaledbDialect:
class TimescaledbDialect(PGDialect):
name = 'timescaledb'
ddl_compiler = TimescaledbDDLCompiler
construct_arguments = [
Expand All @@ -66,11 +138,29 @@ class TimescaledbDialect:
]


class TimescaledbPsycopg2Dialect(TimescaledbDialect, PGDialect_psycopg2):
class TimescaledbPsycopg2Dialect(TimescaledbDialect,PGDialect_psycopg2):
driver = 'psycopg2'
supports_statement_cache = True


class TimescaledbAsyncpgDialect(TimescaledbDialect, PGDialect_asyncpg):
    # Timescaledb dialect combined with SQLAlchemy's asyncpg PostgreSQL dialect.
    driver = 'asyncpg'
    # NOTE(review): presumably opts this dialect into SQLAlchemy's statement
    # cache; confirm against the SQLAlchemy dialect documentation.
    supports_statement_cache = True


"""
This function was blatantly stolen from venv/lib/python3.11/site-packages/alembic/ddl/postgresql.py
You shouldn't need to add any of these; see TimescaledbDDLCompiler.patch_postgres_compilers.

@compiles(PostgresqlColumnType, "timescaledb")
def visit_column_type(
element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
"USING %s" % element.using if element.using else "",
)

"""
4 changes: 3 additions & 1 deletion tests/alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ prepend_sys_path = .
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.

# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os

# the output encoding used when revision files
# are written from script.py.mako
Expand Down
3 changes: 2 additions & 1 deletion tests/async/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
@pytest_asyncio.fixture
def async_engine():
yield create_async_engine(
DATABASE_URL.set(drivername='timescaledb+asyncpg')
DATABASE_URL.set(drivername='timescaledb+asyncpg'),
echo=True
)


Expand Down
1 change: 1 addition & 0 deletions tests/async/test_hypertable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ async def test_is_not_hypertable(self, async_session):
"""
)
)).scalar_one()
#assert False
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.fixture
def engine():
yield create_engine(DATABASE_URL)
yield create_engine(DATABASE_URL, echo=True)


@pytest.fixture
Expand All @@ -20,9 +20,13 @@ def session(engine):
yield session


_factory_session = None

@pytest.fixture(autouse=True)
def setup(engine):
FactorySession.configure(bind=engine)
global _factory_session
if _factory_session is None:
_factory_session = FactorySession.configure(bind=engine)
Base.metadata.create_all(bind=engine)
yield
Base.metadata.drop_all(bind=engine)
Expand Down
6 changes: 3 additions & 3 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from sqlalchemy.orm import declarative_base

DATABASE_URL = URL.create(
host=os.environ.get('POSTGRES_HOST', '0.0.0.0'),
host=os.environ.get('POSTGRES_HOST', 'localhost'),
port=os.environ.get('POSTGRES_PORT', 5432),
username=os.environ.get('POSTGRES_USER', 'user'),
username=os.environ.get('POSTGRES_USER', 'postgres'),
password=os.environ.get('POSTGRES_PASSWORD', 'password'),
database=os.environ.get('POSTGRES_DB', 'database'),
database=os.environ.get('POSTGRES_DB', 'test_timescaledb'),
drivername=os.environ.get('DRIVERNAME', 'timescaledb')
)

Expand Down
1 change: 1 addition & 0 deletions tests/test_alembic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def setup_class(self):
self.migration_versions_path = os.path.join(
os.path.dirname(__file__), 'migrations', 'versions'
)
self.config.set_main_option("version_locations",self.migration_versions_path)

def test_create_revision(self, engine):
Base.metadata.drop_all(bind=engine)
Expand Down