Skip to content

Commit 15fcc29

Browse files
committed
Initalize
0 parents  commit 15fcc29

File tree

5 files changed

+283
-0
lines changed

5 files changed

+283
-0
lines changed

License

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
The MIT License (MIT)
2+
3+
Copyright (c) 2013-NOW Jinzhu <[email protected]>
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in
13+
all copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21+
THE SOFTWARE.

README.md

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# GORM PostgreSQL Driver
2+
3+
## USAGE
4+
5+
```go
6+
import (
7+
"gorm.io/driver/postgres"
8+
"gorm.io/gorm"
9+
)
10+
11+
// https://github.com/lib/pq
12+
dsn := "user=gorm password=gorm DB.name=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
13+
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
14+
```
15+
16+
Checkout [https://gorm.io](https://gorm.io) for details.

go.mod

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module gorm.io/driver/postgres
2+
3+
go 1.14
4+
5+
require github.com/lib/pq v1.6.0

migrator.go

+139
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package postgres
2+
3+
import (
4+
"fmt"
5+
6+
"gorm.io/gorm"
7+
"gorm.io/gorm/clause"
8+
"gorm.io/gorm/migrator"
9+
"gorm.io/gorm/schema"
10+
)
11+
12+
type Migrator struct {
13+
migrator.Migrator
14+
}
15+
16+
func (m Migrator) CurrentDatabase() (name string) {
17+
m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name)
18+
return
19+
}
20+
21+
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
22+
for _, opt := range opts {
23+
str := stmt.Quote(opt.DBName)
24+
if opt.Expression != "" {
25+
str = opt.Expression
26+
}
27+
28+
if opt.Collate != "" {
29+
str += " COLLATE " + opt.Collate
30+
}
31+
32+
if opt.Sort != "" {
33+
str += " " + opt.Sort
34+
}
35+
results = append(results, clause.Expr{SQL: str})
36+
}
37+
return
38+
}
39+
40+
func (m Migrator) HasIndex(value interface{}, name string) bool {
41+
var count int64
42+
m.RunWithValue(value, func(stmt *gorm.Statement) error {
43+
if idx := stmt.Schema.LookIndex(name); idx != nil {
44+
name = idx.Name
45+
}
46+
47+
return m.DB.Raw(
48+
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = CURRENT_SCHEMA()", stmt.Table, name,
49+
).Row().Scan(&count)
50+
})
51+
52+
return count > 0
53+
}
54+
55+
func (m Migrator) CreateIndex(value interface{}, name string) error {
56+
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
57+
if idx := stmt.Schema.LookIndex(name); idx != nil {
58+
opts := m.BuildIndexOptions(idx.Fields, stmt)
59+
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
60+
61+
createIndexSQL := "CREATE "
62+
if idx.Class != "" {
63+
createIndexSQL += idx.Class + " "
64+
}
65+
createIndexSQL += "INDEX ?"
66+
67+
if idx.Type != "" {
68+
createIndexSQL += " USING " + idx.Type
69+
}
70+
createIndexSQL += " ON ??"
71+
72+
if idx.Where != "" {
73+
createIndexSQL += " WHERE " + idx.Where
74+
}
75+
76+
return m.DB.Exec(createIndexSQL, values...).Error
77+
}
78+
79+
return fmt.Errorf("failed to create index with name %v", name)
80+
})
81+
}
82+
83+
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
84+
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
85+
return m.DB.Exec(
86+
"ALTER INDEX ? RENAME TO ?",
87+
clause.Column{Name: oldName}, clause.Column{Name: newName},
88+
).Error
89+
})
90+
}
91+
92+
func (m Migrator) DropIndex(value interface{}, name string) error {
93+
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
94+
if idx := stmt.Schema.LookIndex(name); idx != nil {
95+
name = idx.Name
96+
}
97+
98+
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
99+
})
100+
}
101+
102+
func (m Migrator) HasTable(value interface{}) bool {
103+
var count int64
104+
m.RunWithValue(value, func(stmt *gorm.Statement) error {
105+
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND table_type = ?", stmt.Table, "BASE TABLE").Row().Scan(&count)
106+
})
107+
108+
return count > 0
109+
}
110+
111+
func (m Migrator) DropTable(values ...interface{}) error {
112+
values = m.ReorderModels(values, false)
113+
tx := m.DB.Session(&gorm.Session{})
114+
for i := len(values) - 1; i >= 0; i-- {
115+
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
116+
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", clause.Table{Name: stmt.Table}).Error
117+
}); err != nil {
118+
return err
119+
}
120+
}
121+
return nil
122+
}
123+
124+
func (m Migrator) HasColumn(value interface{}, field string) bool {
125+
var count int64
126+
m.RunWithValue(value, func(stmt *gorm.Statement) error {
127+
name := field
128+
if field := stmt.Schema.LookUpField(field); field != nil {
129+
name = field.DBName
130+
}
131+
132+
return m.DB.Raw(
133+
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?",
134+
stmt.Table, name,
135+
).Row().Scan(&count)
136+
})
137+
138+
return count > 0
139+
}

postgres.go

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package postgres
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
"regexp"
7+
"strconv"
8+
9+
"gorm.io/gorm"
10+
"gorm.io/gorm/callbacks"
11+
"gorm.io/gorm/clause"
12+
"gorm.io/gorm/logger"
13+
"gorm.io/gorm/migrator"
14+
"gorm.io/gorm/schema"
15+
_ "github.com/lib/pq"
16+
)
17+
18+
type Dialector struct {
19+
DSN string
20+
}
21+
22+
func Open(dsn string) gorm.Dialector {
23+
return &Dialector{DSN: dsn}
24+
}
25+
26+
func (dialector Dialector) Name() string {
27+
return "postgres"
28+
}
29+
30+
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
31+
// register callbacks
32+
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
33+
WithReturning: true,
34+
})
35+
db.ConnPool, err = sql.Open("postgres", dialector.DSN)
36+
return
37+
}
38+
39+
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
40+
return Migrator{migrator.Migrator{Config: migrator.Config{
41+
DB: db,
42+
Dialector: dialector,
43+
CreateIndexAfterCreateTable: true,
44+
}}}
45+
}
46+
47+
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
48+
writer.WriteByte('$')
49+
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
50+
}
51+
52+
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
53+
writer.WriteByte('"')
54+
writer.WriteString(str)
55+
writer.WriteByte('"')
56+
}
57+
58+
var numericPlaceholder = regexp.MustCompile("\\$(\\d+)")
59+
60+
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
61+
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
62+
}
63+
64+
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
65+
switch field.DataType {
66+
case schema.Bool:
67+
return "boolean"
68+
case schema.Int, schema.Uint:
69+
if field.AutoIncrement || field == field.Schema.PrioritizedPrimaryField {
70+
switch {
71+
case field.Size < 16:
72+
return "smallserial"
73+
case field.Size < 31:
74+
return "serial"
75+
default:
76+
return "bigserial"
77+
}
78+
} else {
79+
switch {
80+
case field.Size < 16:
81+
return "smallint"
82+
case field.Size < 31:
83+
return "integer"
84+
default:
85+
return "bigint"
86+
}
87+
}
88+
case schema.Float:
89+
return "decimal"
90+
case schema.String:
91+
if field.Size > 0 {
92+
return fmt.Sprintf("varchar(%d)", field.Size)
93+
}
94+
return "text"
95+
case schema.Time:
96+
return "timestamptz"
97+
case schema.Bytes:
98+
return "bytea"
99+
}
100+
101+
return ""
102+
}

0 commit comments

Comments
 (0)