Skip to content

Commit

Permalink
feat: field set value
Browse files Browse the repository at this point in the history
  • Loading branch information
tr1v3r committed Sep 28, 2021
1 parent 5c0d457 commit ec9d643
Show file tree
Hide file tree
Showing 13 changed files with 194 additions and 39 deletions.
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ err // returns error

##### Create record with selected fields

Create a record and assgin a value to the fields specified.
Create a record and assign a value to the fields specified.

```go
u := query.Use(db).User
Expand Down Expand Up @@ -1191,17 +1191,17 @@ initialize struct with more attributes if record not found, those `Attrs` won’
u := query.Use(db).User

// User not found, initialize it with give conditions and Attrs
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Attrs(u.Age.Eq(20)).FirstOrInit()
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Attrs(u.Age.Value(20)).FirstOrInit()
// SELECT * FROM USERS WHERE name = 'non_existing' ORDER BY id LIMIT 1;
// user -> User{Name: "non_existing", Age: 20}

// User not found, initialize it with give conditions and Attrs
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Attrs(u.Age.Eq(20)).FirstOrInit()
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Attrs(u.Age.Value(20)).FirstOrInit()
// SELECT * FROM USERS WHERE name = 'non_existing' ORDER BY id LIMIT 1;
// user -> User{Name: "non_existing", Age: 20}

// Found user with `name` = `modi`, attributes will be ignored
user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Attrs(u.Age.Eq(20)).FirstOrInit()
user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Attrs(u.Age.Value(20)).FirstOrInit()
// SELECT * FROM USERS WHERE name = modi' ORDER BY id LIMIT 1;
// user -> User{ID: 1, Name: "modi", Age: 17}
```
Expand All @@ -1210,11 +1210,11 @@ user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Attrs(u.Age.Eq(20)).Fir

```go
// User not found, initialize it with give conditions and Assign attributes
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Assign(u.Age.Eq(20)).FirstOrInit()
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Assign(u.Age.Value(20)).FirstOrInit()
// user -> User{Name: "non_existing", Age: 20}

// Found user with `name` = `modi`, update it with Assign attributes
user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Assign(u.Age.Eq(20)).FirstOrInit()
user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Assign(u.Age.Value(20)).FirstOrInit()
// SELECT * FROM USERS WHERE name = modi' ORDER BY id LIMIT 1;
// user -> User{ID: 111, Name: "modi", Age: 20}
```
Expand Down Expand Up @@ -1242,13 +1242,13 @@ Create struct with more attributes if record not found, those `Attrs` won’t be
u := query.Use(db).User

// User not found, create it with give conditions and Attrs
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Attrs(u.Age.Eq(20)).FirstOrCreate()
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Attrs(u.Age.Value(20)).FirstOrCreate()
// SELECT * FROM users WHERE name = 'non_existing' ORDER BY id LIMIT 1;
// INSERT INTO "users" (name, age) VALUES ("non_existing", 20);
// user -> User{ID: 112, Name: "non_existing", Age: 20}

// Found user with `name` = `modi`, attributes will be ignored
user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Attrs(u.Age.Eq(20)).FirstOrCreate()
user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Attrs(u.Age.Value(20)).FirstOrCreate()
// SELECT * FROM users WHERE name = 'modi' ORDER BY id LIMIT 1;
// user -> User{ID: 111, Name: "modi", Age: 18}
```
Expand All @@ -1259,13 +1259,13 @@ user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Attrs(u.Age.Eq(20)).Fir
u := query.Use(db).User

// User not found, initialize it with give conditions and Assign attributes
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Assign(u.Age.Eq(20)).FirstOrCreate()
user, err := u.WithContext(ctx).Where(u.Name.Eq("non_existing")).Assign(u.Age.Value(20)).FirstOrCreate()
// SELECT * FROM users WHERE name = 'non_existing' ORDER BY id LIMIT 1;
// INSERT INTO "users" (name, age) VALUES ("non_existing", 20);
// user -> User{ID: 112, Name: "non_existing", Age: 20}

// Found user with `name` = `modi`, update it with Assign attributes
user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Assign(u.Age.Eq(20)).FirstOrCreate(&user)
user, err := u.WithContext(ctx).Where(u.Name.Eq("modi")).Assign(u.Age.Value(20)).FirstOrCreate(&user)
// SELECT * FROM users WHERE name = 'modi' ORDER BY id LIMIT 1;
// UPDATE users SET age=20 WHERE id = 111;
// user -> User{ID: 111, Name: "modi", Age: 20}
Expand Down
44 changes: 28 additions & 16 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,22 @@ func (d *DO) join(table schema.Tabler, joinType clause.JoinType, conds []field.E
return d.getInstance(d.db.Clauses(from))
}

func (d *DO) Attrs(attrs ...field.Expr) Dao {
return d.getInstance(d.db.Attrs(toExpression(attrs...)))
func (d *DO) Attrs(attrs ...field.AssignExpr) Dao {
return d.getInstance(d.db.Attrs(d.attrsValue(attrs)...))
}

func (d *DO) Assign(attrs ...field.Expr) Dao {
return d.getInstance(d.db.Assign(toExpressionInterface(attrs...)...))
func (d *DO) Assign(attrs ...field.AssignExpr) Dao {
return d.getInstance(d.db.Assign(d.attrsValue(attrs)...))
}

func (d *DO) attrsValue(attrs []field.AssignExpr) []interface{} {
values := make([]interface{}, 0, len(attrs))
for _, attr := range attrs {
if expr, ok := attr.AssignExpr().(clause.Eq); ok {
values = append(values, expr)
}
}
return values
}

func (d *DO) Joins(field field.RelationField) Dao {
Expand Down Expand Up @@ -465,8 +475,8 @@ func (d *DO) Update(column field.Expr, value interface{}) (info resultInfo, err

var result *gorm.DB
switch value := value.(type) {
case field.Expr:
result = tx.Update(columnStr, value.RawExpr())
case field.AssignExpr:
result = tx.Update(columnStr, value.AssignExpr())
case subQuery:
result = tx.Update(columnStr, value.underlyingDB())
default:
Expand All @@ -475,8 +485,8 @@ func (d *DO) Update(column field.Expr, value interface{}) (info resultInfo, err
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

func (d *DO) UpdateSimple(columns ...field.Expr) (info resultInfo, err error) {
dest, err := parseExprs(d.db.Statement, columns)
func (d *DO) UpdateSimple(columns ...field.AssignExpr) (info resultInfo, err error) {
dest, err := assignMap(d.db.Statement, columns)
if err != nil {
return resultInfo{Error: err}, err
}
Expand Down Expand Up @@ -506,8 +516,8 @@ func (d *DO) UpdateColumn(column field.Expr, value interface{}) (info resultInfo
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

func (d *DO) UpdateColumnSimple(columns ...field.Expr) (info resultInfo, err error) {
dest, err := parseExprs(d.db.Statement, columns)
func (d *DO) UpdateColumnSimple(columns ...field.AssignExpr) (info resultInfo, err error) {
dest, err := assignMap(d.db.Statement, columns)
if err != nil {
return resultInfo{Error: err}, err
}
Expand Down Expand Up @@ -639,14 +649,16 @@ func toInterfaceSlice(value interface{}) []interface{} {
}
}

func parseExprs(stmt *gorm.Statement, exprs []field.Expr) (map[string]interface{}, error) {
func assignMap(stmt *gorm.Statement, exprs []field.AssignExpr) (map[string]interface{}, error) {
dest := make(map[string]interface{}, len(exprs))
for _, e := range exprs {
expr, ok := e.RawExpr().(clause.Expression)
if !ok {
return nil, ErrInvalidExpression
for _, expr := range exprs {
target := expr.BuildColumn(stmt, field.WithoutQuote).String()
switch e := expr.AssignExpr().(type) {
case clause.Expr:
dest[target] = e
case clause.Eq:
dest[target] = e.Value
}
dest[e.BuildColumn(stmt, field.WithoutQuote).String()] = expr
}
return dest, nil
}
Expand Down
3 changes: 0 additions & 3 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ package gen
import "errors"

var (
// ErrInvalidExpression invalid Expression
ErrInvalidExpression = errors.New("invalid expression")

// ErrEmptyCondition empty condition
ErrEmptyCondition = errors.New("empty condition")
)
2 changes: 0 additions & 2 deletions field/association.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"gorm.io/gorm/schema"
)

var s = clause.Associations

type RelationshipType schema.RelationshipType

const (
Expand Down
8 changes: 8 additions & 0 deletions field/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ func (field Bool) BitAnd(value bool) Expr {
func (field Bool) BitOr(value bool) Expr {
return Bool{field.bitOr(value)}
}

func (field Bool) Value(value bool) AssignExpr {
return field.value(value)
}

func (field Bool) Zero() AssignExpr {
return field.value(false)
}
20 changes: 18 additions & 2 deletions field/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ import (

var _ Expr = new(Field)

type AssignExpr interface {
Expr

AssignExpr() expression
}

// Expr a query expression about field
type Expr interface {
As(alias string) Expr
Expand Down Expand Up @@ -40,6 +46,10 @@ type expr struct {
func (e expr) BeCond() interface{} { return e.expression() }
func (expr) CondError() error { return nil }

func (e expr) AssignExpr() expression {
return e.expression()
}

func (e expr) expression() clause.Expression {
if e.e == nil {
return clause.NamedExpr{SQL: "?", Vars: []interface{}{e.col}}
Expand Down Expand Up @@ -144,13 +154,15 @@ func (e expr) Sum() Float64 {
return Float64{e.setE(clause.Expr{SQL: "SUM(?)", Vars: []interface{}{e.RawExpr()}})}
}

func (e expr) Null() AssignExpr {
return e.setE(clause.Eq{Column: e.col.Name, Value: nil})
}

func (e expr) WithTable(table string) Expr {
e.col.Table = table
return e
}

// TODO add value assign: Set(value)/SetNull()/SetZero()

// ======================== comparison between columns ========================
func (e expr) EqCol(col Expr) Expr {
return e.setE(clause.Expr{SQL: "? = ?", Vars: []interface{}{e.RawExpr(), col.RawExpr()}})
Expand Down Expand Up @@ -190,6 +202,10 @@ func (e expr) Desc() Expr {
}

// ======================== general experssion ========================
func (e expr) value(value interface{}) AssignExpr {
return e.setE(clause.Eq{Column: e.col.Name, Value: value})
}

func (e expr) between(values []interface{}) expr {
return e.setE(clause.Expr{SQL: "? BETWEEN ? AND ?", Vars: append([]interface{}{e.RawExpr()}, values...)})
}
Expand Down
4 changes: 4 additions & 0 deletions field/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ func (field Field) Like(value ScanValuer) Expr {
return expr{e: clause.Like{Column: field.RawExpr(), Value: value}}
}

func (field Field) Value(value ScanValuer) AssignExpr {
return field.value(value)
}

func (field Field) toSlice(values ...ScanValuer) []interface{} {
slice := make([]interface{}, len(values))
for i, v := range values {
Expand Down
16 changes: 16 additions & 0 deletions field/float.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ func (field Float64) Floor() Int {
return Int{field.floor()}
}

func (field Float64) Value(value float64) AssignExpr {
return field.value(value)
}

func (field Float64) Zero() AssignExpr {
return field.value(0)
}

func (field Float64) toSlice(values ...float64) []interface{} {
slice := make([]interface{}, len(values))
for i, v := range values {
Expand Down Expand Up @@ -158,6 +166,14 @@ func (field Float32) Floor() Int {
return Int{field.floor()}
}

func (field Float32) Value(value float32) AssignExpr {
return field.value(value)
}

func (field Float32) Zero() AssignExpr {
return field.value(0)
}

func (field Float32) toSlice(values ...float32) []interface{} {
slice := make([]interface{}, len(values))
for i, v := range values {
Expand Down
Loading

0 comments on commit ec9d643

Please sign in to comment.