@@ -15,15 +15,18 @@ class SqlServerContainer(DbContainer):
15
15
>>> import sqlalchemy
16
16
>>> from testcontainers.mssql import SqlServerContainer
17
17
18
- >>> with SqlServerContainer() as mssql:
19
- ... engine = sqlalchemy.create_engine(mssql.get_connection_url())
20
- ... with engine.begin() as connection:
21
- ... result = connection.execute(sqlalchemy.text("select @@VERSION"))
18
+ >>> with SqlServerContainer() as mssql:
19
+ ... engine = sqlalchemy.create_engine(mssql.get_connection_url())
20
+ ... result = engine.execute(sqlalchemy.text("select @@VERSION"))
21
+ Notes
22
+ -----
23
+ Requires `ODBC Driver 17 for SQL Server <https://docs.microsoft.com/en-us/sql/connect/odbc/
24
+ linux-mac/installing-the-microsoft-odbc-driver-for-sql-server>`_.
22
25
"""
23
26
24
27
def __init__ (self , image : str = "mcr.microsoft.com/mssql/server:2019-latest" ,
25
28
username : str = "SA" , password : Optional [str ] = None , port : int = 1433 ,
26
- dbname : str = "tempdb" , dialect : str = 'mssql+pymssql' , ** kwargs ) -> None :
29
+ dbname : str = "tempdb" , dialect : str = 'mssql+pymssql' , driver : str = "ODBC Driver 17 for SQL Server" , ** kwargs ) -> None :
27
30
raise_for_deprecated_parameter (kwargs , "user" , "username" )
28
31
super (SqlServerContainer , self ).__init__ (image , ** kwargs )
29
32
@@ -34,6 +37,7 @@ def __init__(self, image: str = "mcr.microsoft.com/mssql/server:2019-latest",
34
37
self .username = username
35
38
self .dbname = dbname
36
39
self .dialect = dialect
40
+ self .driver = driver
37
41
38
42
def _configure (self ) -> None :
39
43
self .with_env ("SA_PASSWORD" , self .password )
@@ -42,7 +46,10 @@ def _configure(self) -> None:
42
46
self .with_env ("ACCEPT_EULA" , 'Y' )
43
47
44
48
def get_connection_url (self ) -> str :
45
- return super ()._create_connection_url (
49
+ base_url = super (SqlServerContainer , self )._create_connection_url (
46
50
dialect = self .dialect , username = self .username , password = self .password ,
47
- dbname = self .dbname , port = self .port
51
+ db_name = self .dbname , port = self .port
48
52
)
53
+ url = base_url + f"?driver={ '+' .join (self .driver .split (' ' ))} "
54
+ return url
55
+
0 commit comments