Skip to content

Compression annotation support #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 99 additions & 22 deletions sqlalchemy_timescaledb/dialect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from sqlalchemy import schema, event, DDL
import textwrap
from typing import Optional, Mapping, Any

from sqlalchemy import schema, event, DDL, Table, Dialect, ExecutableDDLElement
from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.interfaces import SchemaTranslateMapType
from sqlalchemy.ext import compiler
from sqlalchemy_utils.view import CreateView, compile_create_materialized_view

try:
import alembic
Expand All @@ -14,9 +20,46 @@ class TimescaledbImpl(postgresql.PostgresqlImpl):
__dialect__ = 'timescaledb'


def _get_interval(value):
if isinstance(value, str):
return f"INTERVAL '{value}'"
elif isinstance(value, int):
return str(value)
else:
return "NULL"


def _create_map(mapping: dict):
return ", ".join([f'{key} => {value}' for key, value in mapping.items()])


# Override sqlalchemy_utils' CreateView compilation for the 'timescaledb'
# dialect so the dialect DDL compiler's own hook runs instead.
@compiler.compiles(CreateView, 'timescaledb')
def compile_create_view(create, compiler, **kw):
    # Delegate to the DDL compiler's visit_create_view, which compiles the
    # materialized view and may attach follow-up after_create DDL events.
    return compiler.visit_create_view(create, **kw)

class TimescaledbDDLCompiler(PGDDLCompiler):
def post_create_table(self, table):

def visit_create_view(self, create, **kw):
    """Compile CREATE MATERIALIZED VIEW and, when the view declares a
    'timescaledb_continuous' option, register an after_create listener
    that adds a continuous-aggregate policy for it.
    """
    ret = compile_create_materialized_view(create, self, **kw)
    view = create.element
    # Options supplied via the view's kwargs (sqlalchemy_utils view object).
    continuous = view.kwargs.get('timescaledb_continuous', {})
    if continuous:
        # Defer add_continuous_aggregate_policy() until after the view
        # itself exists, and run it only on the timescaledb dialect.
        event.listen(
            view,
            'after_create',
            self.ddl_add_continuous(
                view.name, continuous
            ).execute_if(
                dialect='timescaledb'
            )
        )
    return ret

def visit_create_table(self, create, **kw):
ret = super().visit_create_table(create, **kw)
table = create.element
hypertable = table.kwargs.get('timescaledb_hypertable', {})
compress = table.kwargs.get('timescaledb_compress', {})

if hypertable:
event.listen(
Expand All @@ -29,29 +72,61 @@ def post_create_table(self, table):
)
)

return super().post_create_table(table)
if compress:
event.listen(
table,
'after_create',
self.ddl_compress(
table.name, compress
).execute_if(
dialect='timescaledb'
)
)
event.listen(
table,
'after_create',
self.ddl_compression_policy(
table.name, compress
).execute_if(
dialect='timescaledb'
)
)

return ret

@staticmethod
def ddl_hypertable(table_name, hypertable):
    """Build the DDL that converts *table_name* into a hypertable.

    Args:
        table_name: name of the table to convert.
        hypertable: options mapping; requires 'time_column_name' and
            honours an optional 'chunk_time_interval' (interval string or
            integer, default '7 days').

    Returns:
        DDL: a ``SELECT create_hypertable(...)`` statement;
        ``if_not_exists => TRUE`` makes re-execution harmless.
    """
    # NOTE(review): the rendered diff retained the superseded implementation
    # (inline interval parsing + an early return) ahead of the new one,
    # leaving the new code unreachable; only the new implementation is kept.
    time_column_name = hypertable['time_column_name']
    chunk_time_interval = _get_interval(
        hypertable.get('chunk_time_interval', '7 days')
    )

    parameters = _create_map(dict(
        chunk_time_interval=chunk_time_interval,
        if_not_exists="TRUE",
    ))
    return DDL(textwrap.dedent(f"""SELECT create_hypertable('{table_name}','{time_column_name}',{parameters})"""))

@staticmethod
def ddl_compress(table_name, compress):
    """Build the ALTER TABLE statement enabling compression on *table_name*.

    The *compress* mapping must provide 'compress_segmentby'; its value is
    passed through to ``timescaledb.compress_segmentby``.
    """
    segment_columns = compress['compress_segmentby']
    statement = textwrap.dedent(f"""
        ALTER TABLE {table_name} SET (timescaledb.compress, timescaledb.compress_segmentby = '{segment_columns}')
    """)
    return DDL(statement)

@staticmethod
def ddl_compression_policy(table_name, compress):
    """Build the DDL adding a compression policy for *table_name*.

    Reads 'compression_policy_compress_after' (default '7 days') and
    'compression_policy_schedule_interval' (rendered as NULL when absent)
    from the *compress* mapping.
    """
    arguments = {
        'compress_after': _get_interval(
            compress.get('compression_policy_compress_after', '7 days')
        ),
        'schedule_interval': _get_interval(
            compress.get('compression_policy_schedule_interval', None)
        ),
    }
    statement = f"SELECT add_compression_policy('{table_name}', {_create_map(arguments)})"
    return DDL(textwrap.dedent(statement))

@staticmethod
def ddl_add_continuous(table_name, continuous):
    """Build the DDL adding a continuous-aggregate policy for *table_name*.

    All three policy intervals are optional in the *continuous* mapping;
    each absent key is rendered as SQL NULL.
    """
    arguments = {
        'start_offset': _get_interval(
            continuous.get('continuous_aggregate_policy_start_offset', None)
        ),
        'end_offset': _get_interval(
            continuous.get('continuous_aggregate_policy_end_offset', None)
        ),
        'schedule_interval': _get_interval(
            continuous.get('continuous_aggregate_policy_schedule_interval', None)
        ),
    }
    statement = f"SELECT add_continuous_aggregate_policy('{table_name}', {_create_map(arguments)})"
    return DDL(textwrap.dedent(statement))


class TimescaledbDialect:
Expand All @@ -60,7 +135,9 @@ class TimescaledbDialect:
construct_arguments = [
(
schema.Table, {
"hypertable": {}
"hypertable": {},
"compress": {},
"continuous": {},
}
)
]
Expand Down