Skip to content

Commit e69492a

Browse files
committed
Initial commit
1 parent 5df6883 commit e69492a

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed

dialect.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package gorm
2+
3+
import (
4+
"fmt"
5+
"reflect"
6+
"strings"
7+
"time"
8+
9+
"github.com/ngorm/common"
10+
"github.com/ngorm/ngorm/model"
11+
)
12+
13+
type postgres struct {
14+
common.Dialect
15+
}
16+
17+
func (postgres) GetName() string {
18+
return "postgres"
19+
}
20+
21+
func (postgres) BindVar(i int) string {
22+
return fmt.Sprintf("$%v", i)
23+
}
24+
25+
func (postgres) DataTypeOf(field *model.StructField) (string, error) {
26+
dataValue, sqlType, size, additionalType :=
27+
model.ParseFieldStructForDialect(field)
28+
if sqlType == "" {
29+
switch dataValue.Kind() {
30+
case reflect.Bool:
31+
sqlType = "boolean"
32+
case reflect.Int, reflect.Int8,
33+
reflect.Int16, reflect.Int32,
34+
reflect.Uint, reflect.Uint8,
35+
reflect.Uint16, reflect.Uint32,
36+
reflect.Uintptr:
37+
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
38+
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
39+
sqlType = "serial"
40+
} else {
41+
sqlType = "integer"
42+
}
43+
case reflect.Int64, reflect.Uint64:
44+
if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey {
45+
field.TagSettings["AUTO_INCREMENT"] = "AUTO_INCREMENT"
46+
sqlType = "bigserial"
47+
} else {
48+
sqlType = "bigint"
49+
}
50+
case reflect.Float32, reflect.Float64:
51+
sqlType = "numeric"
52+
case reflect.String:
53+
if _, ok := field.TagSettings["SIZE"]; !ok {
54+
size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different
55+
}
56+
57+
if size > 0 && size < 65532 {
58+
sqlType = fmt.Sprintf("varchar(%d)", size)
59+
} else {
60+
sqlType = "text"
61+
}
62+
case reflect.Struct:
63+
if _, ok := dataValue.Interface().(time.Time); ok {
64+
sqlType = "timestamp with time zone"
65+
}
66+
case reflect.Map:
67+
if dataValue.Type().Name() == "Hstore" {
68+
sqlType = "hstore"
69+
}
70+
default:
71+
if isByteArrayOrSlice(dataValue) {
72+
sqlType = "bytea"
73+
} else if isUUID(dataValue) {
74+
sqlType = "uuid"
75+
}
76+
}
77+
}
78+
79+
if sqlType == "" {
80+
return "", fmt.Errorf("invalid sql type %s (%s) for postgres",
81+
dataValue.Type().Name(), dataValue.Kind().String())
82+
}
83+
84+
if strings.TrimSpace(additionalType) == "" {
85+
return sqlType, nil
86+
}
87+
return fmt.Sprintf("%v %v", sqlType, additionalType), nil
88+
}
89+
90+
func (s postgres) HasIndex(tableName string, indexName string) bool {
91+
var count int
92+
s.DB.QueryRow(
93+
"SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2",
94+
tableName, indexName).Scan(&count)
95+
return count > 0
96+
}
97+
98+
func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool {
99+
var count int
100+
query := `
101+
SELECT Count(con.conname)
102+
FROM pg_constraint con
103+
WHERE $1 :: regclass :: oid = con.conrelid
104+
AND con.conname = $2
105+
AND con.contype = 'f'
106+
`
107+
s.DB.QueryRow(query, tableName, foreignKeyName).Scan(&count)
108+
return count > 0
109+
}
110+
111+
func (s postgres) HasTable(tableName string) bool {
112+
var count int
113+
query := `
114+
SELECT Count(*)
115+
FROM information_schema.tables
116+
WHERE table_name = $1
117+
AND table_type = 'BASE TABLE'
118+
`
119+
s.DB.QueryRow(query, tableName).Scan(&count)
120+
return count > 0
121+
}
122+
123+
func (s postgres) HasColumn(tableName string, columnName string) bool {
124+
var count int
125+
query := `
126+
SELECT Count(*)
127+
FROM information_schema.columns
128+
WHERE table_name = $1
129+
AND column_name = $2
130+
`
131+
s.DB.QueryRow(query, tableName, columnName).Scan(&count)
132+
return count > 0
133+
}
134+
135+
func (s postgres) CurrentDatabase() (name string) {
136+
s.DB.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name)
137+
return
138+
}
139+
140+
func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string {
141+
return fmt.Sprintf("RETURNING %v.%v", tableName, key)
142+
}
143+
144+
func (postgres) SupportLastInsertID() bool {
145+
return false
146+
}
147+
148+
func isByteArrayOrSlice(value reflect.Value) bool {
149+
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
150+
}
151+
152+
func isUUID(value reflect.Value) bool {
153+
if value.Kind() != reflect.Array || value.Type().Len() != 16 {
154+
return false
155+
}
156+
typename := value.Type().Name()
157+
lower := strings.ToLower(typename)
158+
return "uuid" == lower || "guid" == lower
159+
}

0 commit comments

Comments
 (0)