Skip to content

Commit 8b8ee55

Browse files
committed
Merge pull request #12 from aronkyle-lightspeedretail/Adding-DDL-Support
Adding ddl support for redshift special syntax
2 parents 0a153f0 + a6eba04 commit 8b8ee55

File tree

4 files changed

+334
-4
lines changed

4 files changed

+334
-4
lines changed

redshift_sqlalchemy/dialect.py

+121-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,131 @@
1+
from sqlalchemy import schema, util, exc
2+
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
13
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
24
from sqlalchemy.engine import reflection
3-
from sqlalchemy import util, exc
4-
from sqlalchemy.types import VARCHAR, NullType
55
from sqlalchemy.ext.compiler import compiles
6-
from sqlalchemy.sql.expression import Executable, ClauseElement
7-
from sqlalchemy.sql.expression import BindParameter
6+
from sqlalchemy.sql.expression import BindParameter, Executable, ClauseElement
7+
from sqlalchemy.types import VARCHAR, NullType
8+
9+
10+
class RedShiftDDLCompiler(PGDDLCompiler):
    ''' Handles Redshift specific create table syntax.

    Users can specify the DISTSTYLE, DISTKEY, SORTKEY and ENCODE properties per
    table and per column.

    Table level properties can be set using the dialect specific syntax. For
    example, to specify a distkey and style you apply the following ::

        table = Table('t1', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_diststyle="KEY",
                      redshift_distkey="id",
                      redshift_sortkey=["id", "name"]
                      )

    A single sortkey can be applied without a wrapping list ::

        table = Table('t1', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_sortkey="id"
                      )

    Column level special syntax can also be applied using the column info
    dictionary. For example, we can specify the encode for a column ::

        table = Table('t1', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('name', String, info={"encode": "lzo"})
                      )

    We can also specify the distkey and sortkey options ::

        table = Table('t1', metadata,
                      Column('id', Integer, primary_key=True),
                      Column('name', String,
                             info={"distkey": True, "sortkey": True})
                      )

    '''

    # Types treated as "one sortkey name" rather than a collection of names.
    # On Python 2 this has to be ``basestring``: checking against ``str``
    # alone lets a ``unicode`` name fall through and be joined per character.
    try:
        _string_types = (basestring,)
    except NameError:  # Python 3 has no ``basestring``
        _string_types = (str,)

    def post_create_table(self, table):
        ''' Build the table attributes appended after the closing
        parenthesis of CREATE TABLE.

        Reads the ``diststyle``, ``distkey`` and ``sortkey`` entries of the
        table's ``redshift`` dialect options and returns the matching
        `` DISTSTYLE ...``, `` DISTKEY (...)`` and `` SORTKEY (...)``
        fragments, in that order. Returns an empty string when none are set.

        Raises ``CompileError`` when diststyle is not EVEN, KEY or ALL
        (compared case-insensitively).
        '''
        text = ""
        info = table.dialect_options['redshift']

        diststyle = info.get('diststyle')
        if diststyle:
            diststyle = diststyle.upper()
            if diststyle not in ('EVEN', 'KEY', 'ALL'):
                raise exc.CompileError(
                    u"diststyle {0} is invalid".format(diststyle))
            text += " DISTSTYLE " + diststyle

        distkey = info.get('distkey')
        if distkey:
            # Quote through the preparer so names that need quoting are
            # rendered the same way as in the column definitions.
            text += " DISTKEY ({0})".format(self.preparer.quote(distkey))

        sortkey = info.get('sortkey')
        if sortkey:
            # A bare string is a single key; anything else is taken to be
            # an iterable of key names.
            if isinstance(sortkey, self._string_types):
                keys = (sortkey,)
            else:
                keys = sortkey
            text += " SORTKEY ({0})".format(
                ", ".join(self.preparer.quote(key) for key in keys))
        return text

    def get_column_specification(self, column, **kwargs):
        ''' Render one column definition of CREATE TABLE.

        Emits, in order: the formatted column name, its compiled type, the
        Redshift attributes taken from ``column.info``, a DEFAULT clause if
        the column has one, and NOT NULL for non-nullable columns.
        ``kwargs`` is accepted but not used.
        '''
        # aron - Apr 21, 2014: Redshift doesn't support serial types, so
        # this method does not emit them.
        colspec = self.preparer.format_column(column)
        colspec += " " + self.dialect.type_compiler.process(column.type)

        colspec += self._fetch_redshift_column_attributes(column)

        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"
        return colspec

    def _fetch_redshift_column_attributes(self, column):
        ''' Return the `` ENCODE <value>``, `` DISTKEY`` and `` SORTKEY``
        fragments requested through ``column.info``.

        ``encode`` is emitted verbatim; ``distkey`` and ``sortkey`` are
        treated as flags and emitted when truthy. Returns an empty string
        if the column has no ``info`` attribute or none of the keys is set.
        '''
        text = ""
        if not hasattr(column, 'info'):
            return text
        info = column.info

        encode = info.get('encode')
        if encode:
            text += " ENCODE " + encode

        if info.get('distkey'):
            text += " DISTKEY"

        if info.get('sortkey'):
            text += " SORTKEY"
        return text
9110

10111
class RedshiftDialect(PGDialect_psycopg2):
11112
name = 'redshift'
113+
ddl_compiler = RedShiftDDLCompiler
114+
115+
construct_arguments = [
116+
(schema.Index, {
117+
"using": False,
118+
"where": None,
119+
"ops": {}
120+
}),
121+
(schema.Table, {
122+
"ignore_search_path": False,
123+
'diststyle': None,
124+
'distkey': None,
125+
'sortkey': None
126+
}),
127+
]
128+
12129
@reflection.cache
13130
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
14131
"""

setup.cfg

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
2+
norecursedirs = .svn .git _build tmp* venv redshift_sqlalchemy
3+

setup.py

+11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
from setuptools import setup
2+
from setuptools.command.test import test as TestCommand
3+
import sys
4+
5+
class PyTest(TestCommand):
    """``setup.py test`` command that delegates to pytest."""

    def run_tests(self):
        """Run pytest with this command's ``test_args`` and exit the
        process with pytest's return code."""
        # Imported inside the method, not at module level; presumably
        # because pytest is only listed in tests_require and may not be
        # installed when setup.py itself is loaded.
        import pytest
        sys.exit(pytest.main(self.test_args))
210

311
setup(
412
name='redshift-sqlalchemy',
@@ -11,6 +19,9 @@
1119
url='https://github.com/binarydud/redshift_sqlalchemy',
1220
packages=['redshift_sqlalchemy'],
1321
install_requires=['psycopg2>=2.5', 'SQLAlchemy>=0.8.0'],
22+
tests_require=['pytest>=2.5.2'],
23+
test_suite="tests",
24+
cmdclass = {'test': PyTest},
1425
include_package_data=True,
1526
zip_safe=False,
1627
classifiers=[

tests/test_ddl_compiler.py

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import difflib
2+
3+
from pytest import fixture
4+
from sqlalchemy import Table, Column, Integer, String, MetaData
5+
from sqlalchemy.exc import CompileError
6+
from sqlalchemy.schema import CreateTable
7+
8+
from redshift_sqlalchemy.dialect import RedShiftDDLCompiler, RedshiftDialect
9+
10+
11+
class TestDDLCompiler(object):
    """Checks the CREATE TABLE text produced by RedShiftDDLCompiler for the
    table-level and column-level Redshift options."""

    @fixture
    def compiler(self):
        """A RedShiftDDLCompiler bound to a RedshiftDialect, with no
        statement attached."""
        compiler = RedShiftDDLCompiler(RedshiftDialect(), None)
        return compiler

    def _compare_strings(self, expected, actual):
        """Return a per-character diff of ``expected`` and ``actual`` for use
        as an assertion message.

        Each character is listed with its code point in hex so that
        whitespace differences (tabs, newlines) are visible.
        """
        assert expected is not None, "Expected was None"
        assert actual is not None, "Actual was None"

        def describe(text):
            # ord() works on both Python 2 and 3; the 'hex' codec used
            # previously exists only on Python 2. ndiff needs strings, so
            # each character is rendered to one line of text.
            return [u"{0!r} {1:02x}".format(c, ord(c)) for c in text]

        return u"-expected, +actual\n" + u"\n".join(
            difflib.ndiff(describe(expected), describe(actual)))

    def test_create_table_simple(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String))

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n)\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

    def test_create_table_with_diststyle(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_diststyle="EVEN")

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"DISTSTYLE EVEN\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

    def test_invalid_diststyle(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_diststyle="NOTEVEN")

        create_table = CreateTable(table)
        # An unknown diststyle must be rejected at compile time.
        try:
            compiler.process(create_table)
        except CompileError:
            pass
        else:
            assert False, "Error expected"

    def test_create_table_with_distkey(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_distkey="id")

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"DISTKEY (id)\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

    def test_create_table_with_sortkey(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_sortkey="id")

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"SORTKEY (id)\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

    def test_create_table_with_multiple_sortkeys(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_sortkey=["id", "name"])

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"SORTKEY (id, name)\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

    def test_create_table_all_together(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_diststyle="KEY",
                      redshift_distkey="id",
                      redshift_sortkey=["id", "name"])

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"DISTSTYLE KEY DISTKEY (id) SORTKEY (id, name)\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

    def test_create_column_with_sortkey(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True,
                             info=dict(sortkey=True)),
                      Column('name', String)
                      )

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER SORTKEY NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n)\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

    def test_create_column_with_distkey(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True,
                             info=dict(distkey=True)),
                      Column('name', String)
                      )

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER DISTKEY NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n)\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

    def test_create_column_with_encoding(self, compiler):
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True,
                             info=dict(encode="LZO")),
                      Column('name', String)
                      )

        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER ENCODE LZO NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n)\n\n")
        assert expected == actual, self._compare_strings(expected, actual)

0 commit comments

Comments
 (0)