Skip to content

Commit 6ebcfa4

Browse files
committed
fix parseDriver function to return correct Driver when using MySQL
1 parent faa1c9d commit 6ebcfa4

File tree

5 files changed

+21
-13
lines changed

5 files changed

+21
-13
lines changed

internal/codegen/golang/driver.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package golang
22

3+
import "github.com/sqlc-dev/sqlc/internal/config"
4+
35
type SQLDriver string
46

57
const (
@@ -15,14 +17,22 @@ const (
1517
SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
1618
)
1719

18-
func parseDriver(sqlPackage string) SQLDriver {
20+
func parseDriver(sqlPackage string, engine string) SQLDriver {
1921
switch sqlPackage {
2022
case SQLPackagePGXV4:
2123
return SQLDriverPGXV4
2224
case SQLPackagePGXV5:
2325
return SQLDriverPGXV5
2426
default:
25-
return SQLDriverLibPQ
27+
switch engine {
28+
case string(config.EnginePostgreSQL):
29+
return SQLDriverLibPQ
30+
case string(config.EngineMySQL):
31+
return SQLDriverGoSQLDriverMySQL
32+
default:
33+
// TODO: return Driver for SQLite
34+
return SQLDriverLibPQ
35+
}
2636
}
2737
}
2838

internal/codegen/golang/gen.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
123123
Enums: enums,
124124
Structs: structs,
125125
}
126-
127126
golang := req.Settings.Go
128127
tctx := tmplCtx{
129128
EmitInterface: golang.EmitInterface,
@@ -137,23 +136,22 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie
137136
EmitAllEnumValues: golang.EmitAllEnumValues,
138137
UsesCopyFrom: usesCopyFrom(queries),
139138
UsesBatch: usesBatch(queries),
140-
SQLDriver: parseDriver(golang.SqlPackage),
139+
SQLDriver: parseDriver(golang.SqlPackage, req.Settings.Engine),
141140
Q: "`",
142141
Package: golang.Package,
143142
Enums: enums,
144143
Structs: structs,
145144
SqlcVersion: req.SqlcVersion,
146145
}
147146

148-
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && golang.SqlDriver != SQLDriverGoSQLDriverMySQL {
147+
if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && tctx.SQLDriver != SQLDriverGoSQLDriverMySQL {
149148
return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql")
150149
}
151150

152-
if tctx.UsesCopyFrom && golang.SqlDriver == SQLDriverGoSQLDriverMySQL {
151+
if tctx.UsesCopyFrom && tctx.SQLDriver == SQLDriverGoSQLDriverMySQL {
153152
if err := checkNoTimesForMySQLCopyFrom(queries); err != nil {
154153
return nil, err
155154
}
156-
tctx.SQLDriver = SQLDriverGoSQLDriverMySQL
157155
}
158156

159157
if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() {

internal/codegen/golang/imports.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ func (i *importer) dbImports() fileImports {
116116
{Path: "context"},
117117
}
118118

119-
sqlpkg := parseDriver(i.Settings.Go.SqlPackage)
119+
sqlpkg := parseDriver(i.Settings.Go.SqlPackage, i.Settings.Engine)
120120
switch sqlpkg {
121121
case SQLDriverPGXV4:
122122
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgconn"})
@@ -160,7 +160,7 @@ func buildImports(settings *plugin.Settings, queries []Query, uses func(string)
160160
std["database/sql"] = struct{}{}
161161
}
162162

163-
sqlpkg := parseDriver(settings.Go.SqlPackage)
163+
sqlpkg := parseDriver(settings.Go.SqlPackage, settings.Engine)
164164
for _, q := range queries {
165165
if q.Cmd == metadata.CmdExecResult {
166166
switch sqlpkg {
@@ -374,7 +374,7 @@ func (i *importer) queryImports(filename string) fileImports {
374374
std["context"] = struct{}{}
375375
}
376376

377-
sqlpkg := parseDriver(i.Settings.Go.SqlPackage)
377+
sqlpkg := parseDriver(i.Settings.Go.SqlPackage, i.Settings.Engine)
378378
if sqlcSliceScan() {
379379
std["strings"] = struct{}{}
380380
}
@@ -459,7 +459,7 @@ func (i *importer) batchImports() fileImports {
459459

460460
std["context"] = struct{}{}
461461
std["errors"] = struct{}{}
462-
sqlpkg := parseDriver(i.Settings.Go.SqlPackage)
462+
sqlpkg := parseDriver(i.Settings.Go.SqlPackage, i.Settings.Engine)
463463
switch sqlpkg {
464464
case SQLDriverPGXV4:
465465
pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{}

internal/codegen/golang/postgresql_type.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func parseIdentifierString(name string) (*plugin.Identifier, error) {
3636
func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) string {
3737
columnType := sdk.DataType(col.Type)
3838
notNull := col.NotNull || col.IsArray
39-
driver := parseDriver(req.Settings.Go.SqlPackage)
39+
driver := parseDriver(req.Settings.Go.SqlPackage, req.Settings.Engine)
4040
emitPointersForNull := driver.IsPGX() && req.Settings.Go.EmitPointersForNullTypes
4141

4242
switch columnType {

internal/codegen/golang/result.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
208208
Comments: query.Comments,
209209
Table: query.InsertIntoTable,
210210
}
211-
sqlpkg := parseDriver(req.Settings.Go.SqlPackage)
211+
sqlpkg := parseDriver(req.Settings.Go.SqlPackage, req.Settings.Engine)
212212

213213
qpl := int(*req.Settings.Go.QueryParameterLimit)
214214

0 commit comments

Comments
 (0)