Skip to content

Commit

Permalink
Merge pull request #46 from go-gorm/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
tr1v3r authored Aug 22, 2021
2 parents 00f9bdd + 736bfe0 commit b6cc0af
Show file tree
Hide file tree
Showing 16 changed files with 1,976 additions and 702 deletions.
1,021 changes: 1,018 additions & 3 deletions README.md

Large diffs are not rendered by default.

109 changes: 74 additions & 35 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gen

import (
"database/sql"
"reflect"
"strings"

"gorm.io/gorm"
Expand Down Expand Up @@ -126,6 +127,9 @@ func (d *DO) build(opts ...stmtOpt) *gorm.Statement {
return stmt
}

// underlyingDO return self
func (d *DO) underlyingDO() *DO { return d }

// Debug return a DO with db in debug mode
func (d *DO) Debug() Dao {
return NewDO(d.db.Debug())
Expand Down Expand Up @@ -161,7 +165,23 @@ func (d *DO) Where(conds ...Condition) Dao {
}

func (d *DO) Order(columns ...field.Expr) Dao {
return NewDO(d.db.Clauses(clause.OrderBy{Expression: clause.CommaExpression{Exprs: toExpression(columns)}}))
// lazy build Columns
// if c, ok := d.db.Statement.Clauses[clause.OrderBy{}.Name()]; ok {
// if order, ok := c.Expression.(clause.OrderBy); ok {
// if expr, ok := order.Expression.(clause.CommaExpression); ok {
// expr.Exprs = append(expr.Exprs, toExpression(columns)...)
// return NewDO(d.db.Clauses(clause.OrderBy{Expression: expr}))
// }
// }
// }
// return NewDO(d.db.Clauses(clause.OrderBy{Expression: clause.CommaExpression{Exprs: toExpression(columns)}}))

// eager build Columns
orderArray := make([]string, len(columns))
for i, c := range columns {
orderArray[i] = c.BuildExpr(d.db.Statement)
}
return NewDO(d.db.Order(strings.Join(orderArray, ",")))
}

func (d *DO) Distinct(columns ...field.Expr) Dao {
Expand Down Expand Up @@ -250,34 +270,48 @@ func (d *DO) Save(value interface{}) error {
return d.db.Save(value).Error
}

func (d *DO) First(dest interface{}, conds ...field.Expr) error {
return d.db.Clauses(toExpression(conds)...).First(dest).Error
func (d *DO) First() (result interface{}, err error) {
return d.singleQuery(d.db.First)
}

func (d *DO) Take(dest interface{}, conds ...field.Expr) error {
return d.db.Clauses(toExpression(conds)...).Take(dest).Error
func (d *DO) Take() (result interface{}, err error) {
return d.singleQuery(d.db.Take)
}

func (d *DO) Last(dest interface{}, conds ...field.Expr) error {
return d.db.Clauses(toExpression(conds)...).Last(dest).Error
func (d *DO) Last() (result interface{}, err error) {
return d.singleQuery(d.db.Last)
}

func (d *DO) Find(dest interface{}, conds ...field.Expr) error {
return d.db.Clauses(toExpression(conds)...).Find(dest).Error
func (d *DO) singleQuery(query func(dest interface{}, conds ...interface{}) *gorm.DB) (result interface{}, err error) {
result = d.newResult()
if err := query(result).Error; err != nil {
return nil, err
}
return result, nil
}

func (d *DO) FindInBatches(dest interface{}, batchSize int, fc func(tx Dao, batch int) error) error {
return d.db.FindInBatches(dest, batchSize, func(tx *gorm.DB, batch int) error { return fc(NewDO(tx), batch) }).Error
func (d *DO) Find() (results interface{}, err error) {
return d.multiQuery(d.db.Find)
}

func (d *DO) FirstOrInit(dest interface{}, conds ...field.Expr) error {
return d.db.Clauses(toExpression(conds)...).FirstOrInit(dest).Error
func (d *DO) multiQuery(query func(dest interface{}, conds ...interface{}) *gorm.DB) (results interface{}, err error) {
resultsPtr := d.newResultSlicePointer()
err = query(resultsPtr).Error
return reflect.Indirect(reflect.ValueOf(resultsPtr)).Interface(), err
}

func (d *DO) FirstOrCreate(dest interface{}, conds ...field.Expr) error {
return d.db.Clauses(toExpression(conds)...).FirstOrCreate(dest).Error
func (d *DO) FindInBatches(dest interface{}, batchSize int, fc func(tx Dao, batch int) error) error {
return d.db.FindInBatches(dest, batchSize, func(tx *gorm.DB, batch int) error { return fc(NewDO(tx), batch) }).Error
}

// func (d *DO) FirstOrInit(dest interface{}, conds ...field.Expr) error {
// return d.db.Clauses(toExpression(conds)...).FirstOrInit(dest).Error
// }

// func (d *DO) FirstOrCreate(dest interface{}, conds ...field.Expr) error {
// return d.db.Clauses(toExpression(conds)...).FirstOrCreate(dest).Error
// }

func (d *DO) Update(column field.Expr, value interface{}) error {
switch expr := column.RawExpr().(type) {
case clause.Expression:
Expand All @@ -287,8 +321,8 @@ func (d *DO) Update(column field.Expr, value interface{}) error {
switch value := value.(type) {
case field.Expr:
return d.db.Update(column.Column().Name, value.RawExpr()).Error
case *DO:
return d.db.Update(column.Column().Name, value.db).Error
case subQuery:
return d.db.Update(column.Column().Name, value.UnderlyingDB()).Error
default:
return d.db.Update(column.Column().Name, value).Error
}
Expand All @@ -307,8 +341,8 @@ func (d *DO) UpdateColumn(column field.Expr, value interface{}) error {
switch value := value.(type) {
case field.Expr:
return d.db.UpdateColumn(column.Column().Name, value.RawExpr()).Error
case *DO:
return d.db.UpdateColumn(column.Column().Name, value.db).Error
case subQuery:
return d.db.UpdateColumn(column.Column().Name, value.UnderlyingDB()).Error
default:
return d.db.UpdateColumn(column.Column().Name, value).Error
}
Expand All @@ -318,12 +352,12 @@ func (d *DO) UpdateColumns(values interface{}) error {
return d.db.UpdateColumns(values).Error
}

func (d *DO) Delete(value interface{}, conds ...field.Expr) error {
return d.db.Clauses(toExpression(conds)...).Delete(value).Error
func (d *DO) Delete() error {
return d.db.Delete(d.db.Statement.Model).Error
}

func (d *DO) Count(count *int64) error {
return d.db.Count(count).Error
func (d *DO) Count() (count int64, err error) {
return count, d.db.Count(&count).Error
}

func (d *DO) Row() *sql.Row {
Expand Down Expand Up @@ -370,12 +404,16 @@ func (d *DO) RollBackTo(name string) Dao {
return NewDO(d.db.RollbackTo(name))
}

func toExpression(conds []field.Expr) []clause.Expression {
exprs := make([]clause.Expression, len(conds))
for i, cond := range conds {
exprs[i] = cond
}
return exprs
func (d *DO) newResult() interface{} {
return reflect.New(d.getModel()).Interface()
}

func (d *DO) newResultSlicePointer() interface{} {
return reflect.New(reflect.SliceOf(reflect.PtrTo(d.getModel()))).Interface()
}

func (d *DO) getModel() reflect.Type {
return reflect.Indirect(reflect.ValueOf(d.db.Statement.Model)).Type()
}

func hintToExpression(hs []Hint) []clause.Expression {
Expand All @@ -390,8 +428,8 @@ func condToExpression(conds []Condition) []clause.Expression {
exprs := make([]clause.Expression, 0, len(conds))
for _, cond := range conds {
switch cond := cond.(type) {
case *DO:
exprs = append(exprs, cond.buildCondition()...)
case subQuery:
exprs = append(exprs, cond.underlyingDO().buildCondition()...)
default:
exprs = append(exprs, cond)
}
Expand Down Expand Up @@ -452,7 +490,7 @@ func toInterfaceSlice(value interface{}) []interface{} {
// Table(u.Select(u.ID, u.Name).Where(u.Age.Gt(18))).Select()
// the above usage is equivalent to SQL statement:
// SELECT * FROM (SELECT `id`, `name` FROM `users_info` WHERE `age` > ?)"
func Table(subQueries ...Dao) Dao {
func Table(subQueries ...subQuery) Dao {
if len(subQueries) == 0 {
return NewDO(nil)
}
Expand All @@ -462,15 +500,16 @@ func Table(subQueries ...Dao) Dao {
for i, query := range subQueries {
tablePlaceholder[i] = "(?)"

do := query.(*DO)
do := query.underlyingDO()
tableExprs[i] = do.db
if do.alias != "" {
tablePlaceholder[i] += " AS " + do.Quote(do.alias)
}
}

db := subQueries[0].(*DO).db
return NewDO(db.Session(&gorm.Session{NewDB: true}).Table(strings.Join(tablePlaceholder, ", "), tableExprs...))
return NewDO(subQueries[0].underlyingDO().db.
Session(&gorm.Session{NewDB: true}).
Table(strings.Join(tablePlaceholder, ", "), tableExprs...))
}

// ======================== sub query method ========================
Expand Down
131 changes: 8 additions & 123 deletions do_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,131 +4,12 @@ import (
"reflect"
"strings"
"testing"
"time"

"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/utils/tests"
"gorm.io/hints"

"gorm.io/gen/field"
)

var db, _ = gorm.Open(tests.DummyDialector{}, nil)

func init() {
db = db.Debug()

callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "ORDER BY", "LIMIT"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "ORDER BY", "LIMIT"},
})
}

// UserRaw user data struct
type UserRaw struct {
ID uint `gorm:"primary_key"`
Name string
Age int
Score float64
Address string
Famous bool
RegisterAt time.Time
}

func (UserRaw) TableName() string {
return "users_info"
}

// StudentRaw student data struct
type StudentRaw struct {
ID int64 `gorm:"primary_key"`
Name string
Age int
Instructor int64 //导师
}

func (StudentRaw) TableName() string {
return "student"
}

// Teacher teacher data struct
type TeacherRaw struct {
ID int64 `gorm:"primary_key"`
Name string
}

func (TeacherRaw) TableName() string {
return "teacher"
}

type User struct {
DO

ID field.Uint
Name field.String
Age field.Int
Score field.Float64
Address field.String
Famous field.Bool
RegisterAt field.Time
}

var u = func() *User {
u := User{
ID: field.NewUint("", "id"),
Name: field.NewString("", "name"),
Age: field.NewInt("", "age"),
Score: field.NewFloat64("", "score"),
Address: field.NewString("", "address"),
Famous: field.NewBool("", "famous"),
RegisterAt: field.NewTime("", "register_at"),
}
u.UseDB(db.Session(&gorm.Session{DryRun: true}))
u.UseModel(UserRaw{})
return &u
}()

type Student struct {
DO

ID field.Int64
Name field.String
Age field.Int
Instructor field.Int64
}

var student = func() *Student {
s := Student{
ID: field.NewInt64("student", "id"),
Name: field.NewString("student", "name"),
Age: field.NewInt("student", "age"),
Instructor: field.NewInt64("student", "instructor"),
}
s.UseDB(db.Session(&gorm.Session{DryRun: true}))
s.UseModel(StudentRaw{})
return &s
}()

type Teacher struct {
DO

ID field.Int64
Name field.String
}

var teacher = func() *Teacher {
t := Teacher{
ID: field.NewInt64("teacher", "id"),
Name: field.NewString("teacher", "name"),
}
t.UseDB(db.Session(&gorm.Session{DryRun: true}))
t.UseModel(TeacherRaw{})
return &t
}()

func checkBuildExpr(t *testing.T, e Dao, opts []stmtOpt, result string, vars []interface{}) {
stmt := e.(*DO).build(opts...)
func checkBuildExpr(t *testing.T, e subQuery, opts []stmtOpt, result string, vars []interface{}) {
stmt := e.underlyingDO().build(opts...)

sql := strings.TrimSpace(stmt.SQL.String())
if sql != result {
Expand All @@ -142,7 +23,7 @@ func checkBuildExpr(t *testing.T, e Dao, opts []stmtOpt, result string, vars []i

func TestDO_methods(t *testing.T) {
testcases := []struct {
Expr Dao
Expr subQuery
Opts []stmtOpt
ExpectedVars []interface{}
Result string
Expand Down Expand Up @@ -199,7 +80,11 @@ func TestDO_methods(t *testing.T) {
},
{
Expr: u.Order(u.ID.Desc(), u.Age),
Result: "ORDER BY `id` DESC, `age`",
Result: "ORDER BY `id` DESC,`age`",
},
{
Expr: u.Order(u.ID.Desc()).Order(u.Age),
Result: "ORDER BY `id` DESC,`age`",
},
{
Expr: u.Hints(hints.New("hint")).Select(),
Expand Down
Loading

0 comments on commit b6cc0af

Please sign in to comment.