Skip to content

Commit b33bd31

Browse files
committed
Adding redshift create table special syntax.
Adding diststyle support. Adding distkey support. Adding sortkey support. Adding encode support.
1 parent 2bb1dab commit b33bd31

File tree

2 files changed

+280
-4
lines changed

2 files changed

+280
-4
lines changed

redshift_sqlalchemy/dialect.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,91 @@
1+
from sqlalchemy import schema, util, exc
2+
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
13
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
24
from sqlalchemy.engine import reflection
3-
from sqlalchemy import util, exc
4-
from sqlalchemy.types import VARCHAR, NullType
55
from sqlalchemy.ext.compiler import compiles
6-
from sqlalchemy.sql.expression import Executable, ClauseElement
7-
from sqlalchemy.sql.expression import BindParameter
6+
from sqlalchemy.sql.expression import BindParameter, Executable, ClauseElement
7+
from sqlalchemy.types import VARCHAR, NullType
8+
9+
10+
class RedShiftDDLCompiler(PGDDLCompiler):
    """PostgreSQL DDL compiler extended with Redshift ``CREATE TABLE`` syntax.

    Table-level attributes (``diststyle``, ``distkey``, ``sortkey``) are read
    from ``table.dialect_options['redshift']``. Column-level attributes
    (``encode``, ``distkey``, ``sortkey``) are read from ``Column.info``.
    """

    def post_create_table(self, table):
        """Build the table attributes emitted after the column list.

        :param table: the table being created; its
            ``dialect_options['redshift']`` mapping is consulted for
            ``diststyle``, ``distkey`` and ``sortkey``.
        :returns: a string such as
            ``" DISTSTYLE KEY DISTKEY (id) SORTKEY (id, name)"``, or ``""``
            when none of the options is set.
        :raises sqlalchemy.exc.CompileError: if ``diststyle`` is not one of
            ``EVEN``, ``KEY`` or ``ALL`` (compared case-insensitively).
        """
        # ``basestring`` only exists on Python 2; fall back to ``str`` so a
        # single sort key given as a string is also recognised on Python 3.
        try:
            string_types = basestring
        except NameError:
            string_types = str

        text = ""
        info = table.dialect_options['redshift']

        diststyle = info.get('diststyle', None)
        if diststyle:
            diststyle = diststyle.upper()
            if diststyle not in ('EVEN', 'KEY', 'ALL'):
                raise exc.CompileError(
                    u"diststyle {0} is invalid".format(diststyle))
            text += " DISTSTYLE " + diststyle

        distkey = info.get('distkey', None)
        if distkey:
            # Quote the name the same way column names are quoted, so
            # identifiers that need quoting still match the column list.
            text += " DISTKEY ({0})".format(self.preparer.quote(distkey))

        sortkey = info.get('sortkey', None)
        if sortkey:
            # Accept either one column name or an iterable of column names.
            if isinstance(sortkey, string_types):
                keys = (sortkey,)
            else:
                keys = sortkey
            quoted = [self.preparer.quote(key) for key in keys]
            text += " SORTKEY ({0})".format(", ".join(quoted))
        return text

    def get_column_specification(self, column, **kwargs):
        """Return the DDL fragment describing a single column.

        The fragment is assembled as: column name, type, Redshift column
        attributes, ``DEFAULT`` clause (when a default string exists) and
        ``NOT NULL`` (when the column is not nullable). Redshift doesn't
        support serial types, so no serial handling is performed here.

        :param column: the column to describe.
        :param kwargs: accepted for interface compatibility; not used here.
        :returns: the column specification string.
        """
        colspec = self.preparer.format_column(column)
        colspec += " " + self.dialect.type_compiler.process(column.type)

        colspec += self._fetch_redshift_column_attributes(column)

        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if not column.nullable:
            colspec += " NOT NULL"
        return colspec

    def _fetch_redshift_column_attributes(self, column):
        """Return the Redshift column attributes found in ``column.info``.

        Recognised ``info`` keys: ``encode`` (emitted as ``ENCODE <value>``),
        ``distkey`` and ``sortkey`` (each emitted as a bare keyword when
        truthy).

        :param column: the column whose ``info`` mapping is inspected.
        :returns: the attribute text with a leading space per attribute, or
            ``""`` when the column has no ``info`` or no recognised key set.
        """
        text = ""
        if not hasattr(column, 'info'):
            return text
        info = column.info

        encode = info.get('encode', None)
        if encode:
            text += " ENCODE " + encode

        distkey = info.get('distkey', None)
        if distkey:
            text += " DISTKEY"

        sortkey = info.get('sortkey', None)
        if sortkey:
            text += " SORTKEY"
        return text
970

1071
class RedshiftDialect(PGDialect_psycopg2):
1172
name = 'redshift'
73+
ddl_compiler = RedShiftDDLCompiler
74+
75+
construct_arguments = [
76+
(schema.Index, {
77+
"using": False,
78+
"where": None,
79+
"ops": {}
80+
}),
81+
(schema.Table, {
82+
"ignore_search_path": False,
83+
'diststyle': None,
84+
'distkey': None,
85+
'sortkey': None
86+
}),
87+
]
88+
1289
@reflection.cache
1390
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
1491
"""

tests/test_ddl_compiler.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import difflib
2+
3+
from pytest import fixture
4+
from sqlalchemy import Table, Column, Integer, String, MetaData
5+
from sqlalchemy.exc import CompileError
6+
from sqlalchemy.schema import CreateTable
7+
8+
from redshift_sqlalchemy.dialect import RedShiftDDLCompiler, RedshiftDialect
9+
10+
11+
class TestDDLCompiler(object):
    """Checks the ``CREATE TABLE`` text produced by ``RedShiftDDLCompiler``."""

    @fixture
    def compiler(self):
        """Provide a DDL compiler bound to a fresh ``RedshiftDialect``."""
        compiler = RedShiftDDLCompiler(RedshiftDialect(), None)
        return compiler

    def _compare_strings(self, expected, actual):
        """Return a character-level diff used as an assertion message.

        Each character is rendered as its ``repr`` plus its code point in
        hex, so whitespace differences (tabs, newlines) are visible.

        :param expected: the expected string; must not be ``None``.
        :param actual: the produced string; must not be ``None``.
        :returns: a text block starting with ``-expected, +actual``.
        """
        assert expected is not None, "Expected was None"
        assert actual is not None, "Actual was None"

        def describe(text):
            # ndiff works on sequences of strings, so emit one string per
            # character; ord() behaves the same on Python 2 and 3.
            return [u"{0!r} (0x{1:02x})".format(ch, ord(ch)) for ch in text]

        diff = difflib.ndiff(describe(expected), describe(actual))
        return u"-expected, +actual\n" + u"\n".join(diff)

    def _assert_create_table(self, compiler, table, expected):
        """Compile ``CREATE TABLE`` for *table* and compare with *expected*."""
        create_table = CreateTable(table)
        actual = compiler.process(create_table)
        assert expected == actual, self._compare_strings(expected, actual)

    def test_create_table_simple(self, compiler):
        """A table without Redshift options compiles to plain DDL."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String))

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n)\n\n")
        self._assert_create_table(compiler, table, expected)

    def test_create_table_with_diststyle(self, compiler):
        """``redshift_diststyle`` adds a DISTSTYLE clause."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_diststyle="EVEN")

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"DISTSTYLE EVEN\n\n")
        self._assert_create_table(compiler, table, expected)

    def test_invalid_diststyle(self, compiler):
        """An unknown diststyle value raises ``CompileError``."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_diststyle="NOTEVEN")

        create_table = CreateTable(table)
        try:
            compiler.process(create_table)
        except CompileError:
            pass
        else:
            assert False, "Error expected"

    def test_create_table_with_distkey(self, compiler):
        """``redshift_distkey`` adds a DISTKEY clause."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_distkey="id")

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"DISTKEY (id)\n\n")
        self._assert_create_table(compiler, table, expected)

    def test_create_table_with_sortkey(self, compiler):
        """A single string ``redshift_sortkey`` adds a SORTKEY clause."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_sortkey="id")

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"SORTKEY (id)\n\n")
        self._assert_create_table(compiler, table, expected)

    def test_create_table_with_multiple_sortkeys(self, compiler):
        """A list ``redshift_sortkey`` is joined into one SORTKEY clause."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_sortkey=["id", "name"])

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"SORTKEY (id, name)\n\n")
        self._assert_create_table(compiler, table, expected)

    def test_create_table_all_together(self, compiler):
        """Diststyle, distkey and sortkey are emitted in that order."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True),
                      Column('name', String),
                      redshift_diststyle="KEY",
                      redshift_distkey="id",
                      redshift_sortkey=["id", "name"])

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n) "
                    u"DISTSTYLE KEY DISTKEY (id) SORTKEY (id, name)\n\n")
        self._assert_create_table(compiler, table, expected)

    def test_create_column_with_sortkey(self, compiler):
        """``info['sortkey']`` on a column adds a column-level SORTKEY."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True,
                             info=dict(sortkey=True)),
                      Column('name', String))

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER SORTKEY NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n)\n\n")
        self._assert_create_table(compiler, table, expected)

    def test_create_column_with_distkey(self, compiler):
        """``info['distkey']`` on a column adds a column-level DISTKEY."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True,
                             info=dict(distkey=True)),
                      Column('name', String))

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER DISTKEY NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n)\n\n")
        self._assert_create_table(compiler, table, expected)

    def test_create_column_with_encoding(self, compiler):
        """``info['encode']`` on a column adds an ENCODE clause."""
        table = Table('t1',
                      MetaData(),
                      Column('id', Integer, primary_key=True,
                             info=dict(encode="LZO")),
                      Column('name', String))

        expected = (u"\nCREATE TABLE t1 ("
                    u"\n\tid INTEGER ENCODE LZO NOT NULL, "
                    u"\n\tname VARCHAR, "
                    u"\n\tPRIMARY KEY (id)\n)\n\n")
        self._assert_create_table(compiler, table, expected)

0 commit comments

Comments
 (0)