diff --git a/internal/generate/query.go b/internal/generate/query.go index f282c066..460aa77d 100644 --- a/internal/generate/query.go +++ b/internal/generate/query.go @@ -179,9 +179,9 @@ func (b *QueryStructMeta) ReviseDIYMethod() error { if tableName == nil { methods = append(methods, parser.DefaultMethodTableName(b.ModelStructName)) } else { - //e.g. return "@@table" => return TableNameUser + // e.g. return "@@table" => return TableNameUser tableName.Body = strings.ReplaceAll(tableName.Body, "\"@@table\"", "TableName"+b.ModelStructName) - //e.g. return "t_@@table" => return "t_user" + // e.g. return "t_@@table" => return "t_user" tableName.Body = strings.ReplaceAll(tableName.Body, "@@table", b.TableName) } b.ModelMethods = methods @@ -194,7 +194,7 @@ func (b *QueryStructMeta) ReviseDIYMethod() error { func (b *QueryStructMeta) addMethodFromAddMethodOpt(methods ...interface{}) *QueryStructMeta { for _, method := range methods { - modelMethods, err := parser.GetModelMethod(method, 5) + modelMethods, err := parser.GetModelMethod(method) if err != nil { panic("add diy method err:" + err.Error()) } diff --git a/internal/parser/export.go b/internal/parser/export.go index 4753484a..901b78f8 100644 --- a/internal/parser/export.go +++ b/internal/parser/export.go @@ -74,7 +74,7 @@ func fileExists(path string) bool { } // GetModelMethod get diy methods -func GetModelMethod(v interface{}, skip int) (method *DIYMethods, err error) { +func GetModelMethod(v interface{}) (method *DIYMethods, err error) { method = new(DIYMethods) // get diy method info by input value, must input a function or a struct @@ -98,7 +98,15 @@ func GetModelMethod(v interface{}, skip int) (method *DIYMethods, err error) { // if struct in main file ctx := build.Default if method.pkgPath == "main" { - _, file, _, _ := runtime.Caller(skip) + var skip int + var file string + for { + _, file, _, _ = runtime.Caller(skip) + if !(strings.Contains(file, "gorm/gen/generator.go") || strings.Contains(file, "gorm/gen/internal")) || file == "" { + break + } + skip++ + } p, err = ctx.ImportDir(filepath.Dir(file), build.ImportComment) } else { p, err = ctx.Import(method.pkgPath, "", build.ImportComment) diff --git a/tests/.expect/dal_5/model/users.gen.go b/tests/.expect/dal_5/model/users.gen.go new file mode 100644 index 00000000..d65d4f54 --- /dev/null +++ b/tests/.expect/dal_5/model/users.gen.go @@ -0,0 +1,39 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameUser = "users" + +// User mapped from table +type User struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` + CreatedAt *time.Time `gorm:"column:created_at" json:"-"` + Name *string `gorm:"column:name;index:idx_name,priority:1;comment:oneline" json:"-"` + Address *string `gorm:"column:address" json:"-"` + RegisterTime *time.Time `gorm:"column:register_time" json:"-"` + Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` + CompanyID *int64 `gorm:"column:company_id;default:666" json:"-"` + PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` +} + +func (m *User) IsEmpty() bool { + if m == nil { + return true + } + return m.ID == 0 +} + +func (m *User) GetID() int { + return int(m.ID) +} + +// TableName User's table name +func (*User) TableName() string { + return TableNameUser +} diff --git a/tests/.expect/dal_5/query/gen.go b/tests/.expect/dal_5/query/gen.go new file mode 100644 index 00000000..a19d931b --- /dev/null +++ b/tests/.expect/dal_5/query/gen.go @@ -0,0 +1,103 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + "database/sql" + + "gorm.io/gorm" + + "gorm.io/gen" + + "gorm.io/plugin/dbresolver" +) + +var ( + Q = new(Query) + User *user +) + +func SetDefault(db *gorm.DB, opts ...gen.DOOption) { + *Q = *Use(db, opts...) + User = &Q.User +} + +func Use(db *gorm.DB, opts ...gen.DOOption) *Query { + return &Query{ + db: db, + User: newUser(db, opts...), + } +} + +type Query struct { + db *gorm.DB + + User user +} + +func (q *Query) Available() bool { return q.db != nil } + +func (q *Query) clone(db *gorm.DB) *Query { + return &Query{ + db: db, + User: q.User.clone(db), + } +} + +func (q *Query) ReadDB() *Query { + return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) +} + +func (q *Query) WriteDB() *Query { + return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) +} + +func (q *Query) ReplaceDB(db *gorm.DB) *Query { + return &Query{ + db: db, + User: q.User.replaceDB(db), + } +} + +type queryCtx struct { + User IUserDo +} + +func (q *Query) WithContext(ctx context.Context) *queryCtx { + return &queryCtx{ + User: q.User.WithContext(ctx), + } +} + +func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { + return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) +} + +func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { + tx := q.db.Begin(opts...) + return &QueryTx{Query: q.clone(tx), Error: tx.Error} +} + +type QueryTx struct { + *Query + Error error +} + +func (q *QueryTx) Commit() error { + return q.db.Commit().Error +} + +func (q *QueryTx) Rollback() error { + return q.db.Rollback().Error +} + +func (q *QueryTx) SavePoint(name string) error { + return q.db.SavePoint(name).Error +} + +func (q *QueryTx) RollbackTo(name string) error { + return q.db.RollbackTo(name).Error +} diff --git a/tests/.expect/dal_5/query/gen_test.go b/tests/.expect/dal_5/query/gen_test.go new file mode 100644 index 00000000..40cdc69b --- /dev/null +++ b/tests/.expect/dal_5/query/gen_test.go @@ -0,0 +1,118 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + "fmt" + "reflect" + "sync" + "testing" + + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +type Input struct { + Args []interface{} +} + +type Expectation struct { + Ret []interface{} +} + +type TestCase struct { + Input + Expectation +} + +const dbName = "gen_test.db" + +var db *gorm.DB +var once sync.Once + +func init() { + InitializeDB() + db.AutoMigrate(&_another{}) +} + +func InitializeDB() { + once.Do(func() { + var err error + db, err = gorm.Open(sqlite.Open(dbName), &gorm.Config{}) + if err != nil { + panic(fmt.Errorf("open sqlite %q fail: %w", dbName, err)) + } + }) +} + +func assert(t *testing.T, methodName string, res, exp interface{}) { + if !reflect.DeepEqual(res, exp) { + t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) + } +} + +type _another struct { + ID uint64 `gorm:"primaryKey"` +} + +func (*_another) TableName() string { return "another_for_unit_test" } + +func Test_Available(t *testing.T) { + if !Use(db).Available() { + t.Errorf("query.Available() == false") + } +} + +func Test_WithContext(t *testing.T) { + query := Use(db) + if !query.Available() { + t.Errorf("query Use(db) fail: query.Available() == false") + } + + type Content string + var key, value Content = "gen_tag", "unit_test" + qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) + + for _, ctx := range []context.Context{ + qCtx.User.UnderlyingDB().Statement.Context, + } { + if v := ctx.Value(key); v != value { + t.Errorf("get value from context fail, expect %q, got %q", value, v) + } + } +} + +func Test_Transaction(t *testing.T) { + query := Use(db) + if !query.Available() { + t.Errorf("query Use(db) fail: query.Available() == false") + } + + err := query.Transaction(func(tx *Query) error { return nil }) + if err != nil { + t.Errorf("query.Transaction execute fail: %s", err) + } + + tx := query.Begin() + + err = tx.SavePoint("point") + if err != nil { + t.Errorf("query tx SavePoint fail: %s", err) + } + err = tx.RollbackTo("point") + if err != nil { + t.Errorf("query tx RollbackTo fail: %s", err) + } + err = tx.Commit() + if err != nil { + t.Errorf("query tx Commit fail: %s", err) + } + + err = query.Begin().Rollback() + if err != nil { + t.Errorf("query tx Rollback fail: %s", err) + } +} diff --git a/tests/.expect/dal_5/query/users.gen.go b/tests/.expect/dal_5/query/users.gen.go new file mode 100644 index 00000000..3d2ffb34 --- /dev/null +++ b/tests/.expect/dal_5/query/users.gen.go @@ -0,0 +1,423 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "gorm.io/gen/tests/.gen/dal_5/model" +) + +func newUser(db *gorm.DB, opts ...gen.DOOption) user { + _user := user{} + + _user.userDo.UseDB(db, opts...) + _user.userDo.UseModel(&model.User{}) + + tableName := _user.userDo.TableName() + _user.ALL = field.NewAsterisk(tableName) + _user.ID = field.NewInt64(tableName, "id") + _user.CreatedAt = field.NewTime(tableName, "created_at") + _user.Name = field.NewString(tableName, "name") + _user.Address = field.NewString(tableName, "address") + _user.RegisterTime = field.NewTime(tableName, "register_time") + _user.Alive = field.NewBool(tableName, "alive") + _user.CompanyID = field.NewInt64(tableName, "company_id") + _user.PrivateURL = field.NewString(tableName, "private_url") + + _user.fillFieldMap() + + return _user +} + +type user struct { + userDo userDo + + ALL field.Asterisk + ID field.Int64 + CreatedAt field.Time + Name field.String // oneline + Address field.String + RegisterTime field.Time + /* + multiline + line1 + line2 + */ + Alive field.Bool + CompanyID field.Int64 + PrivateURL field.String + + fieldMap map[string]field.Expr +} + +func (u user) Table(newTableName string) *user { + u.userDo.UseTable(newTableName) + return u.updateTableName(newTableName) +} + +func (u user) As(alias string) *user { + u.userDo.DO = *(u.userDo.As(alias).(*gen.DO)) + return u.updateTableName(alias) +} + +func (u *user) updateTableName(table string) *user { + u.ALL = field.NewAsterisk(table) + u.ID = field.NewInt64(table, "id") + u.CreatedAt = field.NewTime(table, "created_at") + u.Name = field.NewString(table, "name") + u.Address = field.NewString(table, "address") + u.RegisterTime = field.NewTime(table, "register_time") + u.Alive = field.NewBool(table, "alive") + u.CompanyID = field.NewInt64(table, "company_id") + u.PrivateURL = field.NewString(table, "private_url") + + u.fillFieldMap() + + return u +} + +func (u *user) WithContext(ctx context.Context) IUserDo { return u.userDo.WithContext(ctx) } + +func (u user) TableName() string { return u.userDo.TableName() } + +func (u user) Alias() string { return u.userDo.Alias() } + +func (u *user) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := u.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (u *user) fillFieldMap() { + u.fieldMap = make(map[string]field.Expr, 8) + u.fieldMap["id"] = u.ID + u.fieldMap["created_at"] = u.CreatedAt + u.fieldMap["name"] = u.Name + u.fieldMap["address"] = u.Address + u.fieldMap["register_time"] = u.RegisterTime + u.fieldMap["alive"] = u.Alive + u.fieldMap["company_id"] = u.CompanyID + u.fieldMap["private_url"] = u.PrivateURL +} + +func (u user) clone(db *gorm.DB) user { + u.userDo.ReplaceConnPool(db.Statement.ConnPool) + return u +} + +func (u user) replaceDB(db *gorm.DB) user { + u.userDo.ReplaceDB(db) + return u +} + +type userDo struct{ gen.DO } + +type IUserDo interface { + gen.SubQuery + Debug() IUserDo + WithContext(ctx context.Context) IUserDo + WithResult(fc func(tx gen.Dao)) gen.ResultInfo + ReplaceDB(db *gorm.DB) + ReadDB() IUserDo + WriteDB() IUserDo + As(alias string) gen.Dao + Session(config *gorm.Session) IUserDo + Columns(cols ...field.Expr) gen.Columns + Clauses(conds ...clause.Expression) IUserDo + Not(conds ...gen.Condition) IUserDo + Or(conds ...gen.Condition) IUserDo + Select(conds ...field.Expr) IUserDo + Where(conds ...gen.Condition) IUserDo + Order(conds ...field.Expr) IUserDo + Distinct(cols ...field.Expr) IUserDo + Omit(cols ...field.Expr) IUserDo + Join(table schema.Tabler, on ...field.Expr) IUserDo + LeftJoin(table schema.Tabler, on ...field.Expr) IUserDo + RightJoin(table schema.Tabler, on ...field.Expr) IUserDo + Group(cols ...field.Expr) IUserDo + Having(conds ...gen.Condition) IUserDo + Limit(limit int) IUserDo + Offset(offset int) IUserDo + Count() (count int64, err error) + Scopes(funcs ...func(gen.Dao) gen.Dao) IUserDo + Unscoped() IUserDo + Create(values ...*model.User) error + CreateInBatches(values []*model.User, batchSize int) error + Save(values ...*model.User) error + First() (*model.User, error) + Take() (*model.User, error) + Last() (*model.User, error) + Find() ([]*model.User, error) + FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.User, err error) + FindInBatches(result *[]*model.User, batchSize int, fc func(tx gen.Dao, batch int) error) error + Pluck(column field.Expr, dest interface{}) error + Delete(...*model.User) (info gen.ResultInfo, err error) + Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error) + UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error) + Updates(value interface{}) (info gen.ResultInfo, err error) + UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error) + UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error) + UpdateColumns(value interface{}) (info gen.ResultInfo, err error) + UpdateFrom(q gen.SubQuery) gen.Dao + Attrs(attrs ...field.AssignExpr) IUserDo + Assign(attrs ...field.AssignExpr) IUserDo + Joins(fields ...field.RelationField) IUserDo + Preload(fields ...field.RelationField) IUserDo + FirstOrInit() (*model.User, error) + FirstOrCreate() (*model.User, error) + FindByPage(offset int, limit int) (result []*model.User, count int64, err error) + ScanByPage(result interface{}, offset int, limit int) (count int64, err error) + Scan(result interface{}) (err error) + Returning(value interface{}, columns ...string) IUserDo + UnderlyingDB() *gorm.DB + schema.Tabler +} + +func (u userDo) Debug() IUserDo { + return u.withDO(u.DO.Debug()) +} + +func (u userDo) WithContext(ctx context.Context) IUserDo { + return u.withDO(u.DO.WithContext(ctx)) +} + +func (u userDo) ReadDB() IUserDo { + return u.Clauses(dbresolver.Read) +} + +func (u userDo) WriteDB() IUserDo { + return u.Clauses(dbresolver.Write) +} + +func (u userDo) Session(config *gorm.Session) IUserDo { + return u.withDO(u.DO.Session(config)) +} + +func (u userDo) Clauses(conds ...clause.Expression) IUserDo { + return u.withDO(u.DO.Clauses(conds...)) +} + +func (u userDo) Returning(value interface{}, columns ...string) IUserDo { + return u.withDO(u.DO.Returning(value, columns...)) +} + +func (u userDo) Not(conds ...gen.Condition) IUserDo { + return u.withDO(u.DO.Not(conds...)) +} + +func (u userDo) Or(conds ...gen.Condition) IUserDo { + return u.withDO(u.DO.Or(conds...)) +} + +func (u userDo) Select(conds ...field.Expr) IUserDo { + return u.withDO(u.DO.Select(conds...)) +} + +func (u userDo) Where(conds ...gen.Condition) IUserDo { + return u.withDO(u.DO.Where(conds...)) +} + +func (u userDo) Exists(subquery interface{ UnderlyingDB() *gorm.DB }) IUserDo { + return u.Where(field.CompareSubQuery(field.ExistsOp, nil, subquery.UnderlyingDB())) +} + +func (u userDo) Order(conds ...field.Expr) IUserDo { + return u.withDO(u.DO.Order(conds...)) +} + +func (u userDo) Distinct(cols ...field.Expr) IUserDo { + return u.withDO(u.DO.Distinct(cols...)) +} + +func (u userDo) Omit(cols ...field.Expr) IUserDo { + return u.withDO(u.DO.Omit(cols...)) +} + +func (u userDo) Join(table schema.Tabler, on ...field.Expr) IUserDo { + return u.withDO(u.DO.Join(table, on...)) +} + +func (u userDo) LeftJoin(table schema.Tabler, on ...field.Expr) IUserDo { + return u.withDO(u.DO.LeftJoin(table, on...)) +} + +func (u userDo) RightJoin(table schema.Tabler, on ...field.Expr) IUserDo { + return u.withDO(u.DO.RightJoin(table, on...)) +} + +func (u userDo) Group(cols ...field.Expr) IUserDo { + return u.withDO(u.DO.Group(cols...)) +} + +func (u userDo) Having(conds ...gen.Condition) IUserDo { + return u.withDO(u.DO.Having(conds...)) +} + +func (u userDo) Limit(limit int) IUserDo { + return u.withDO(u.DO.Limit(limit)) +} + +func (u userDo) Offset(offset int) IUserDo { + return u.withDO(u.DO.Offset(offset)) +} + +func (u userDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IUserDo { + return u.withDO(u.DO.Scopes(funcs...)) +} + +func (u userDo) Unscoped() IUserDo { + return u.withDO(u.DO.Unscoped()) +} + +func (u userDo) Create(values ...*model.User) error { + if len(values) == 0 { + return nil + } + return u.DO.Create(values) +} + +func (u userDo) CreateInBatches(values []*model.User, batchSize int) error { + return u.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (u userDo) Save(values ...*model.User) error { + if len(values) == 0 { + return nil + } + return u.DO.Save(values) +} + +func (u userDo) First() (*model.User, error) { + if result, err := u.DO.First(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) Take() (*model.User, error) { + if result, err := u.DO.Take(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) Last() (*model.User, error) { + if result, err := u.DO.Last(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) Find() ([]*model.User, error) { + result, err := u.DO.Find() + return result.([]*model.User), err +} + +func (u userDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.User, err error) { + buf := make([]*model.User, 0, batchSize) + err = u.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (u userDo) FindInBatches(result *[]*model.User, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return u.DO.FindInBatches(result, batchSize, fc) +} + +func (u userDo) Attrs(attrs ...field.AssignExpr) IUserDo { + return u.withDO(u.DO.Attrs(attrs...)) +} + +func (u userDo) Assign(attrs ...field.AssignExpr) IUserDo { + return u.withDO(u.DO.Assign(attrs...)) +} + +func (u userDo) Joins(fields ...field.RelationField) IUserDo { + for _, _f := range fields { + u = *u.withDO(u.DO.Joins(_f)) + } + return &u +} + +func (u userDo) Preload(fields ...field.RelationField) IUserDo { + for _, _f := range fields { + u = *u.withDO(u.DO.Preload(_f)) + } + return &u +} + +func (u userDo) FirstOrInit() (*model.User, error) { + if result, err := u.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) FirstOrCreate() (*model.User, error) { + if result, err := u.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) FindByPage(offset int, limit int) (result []*model.User, count int64, err error) { + result, err = u.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = u.Offset(-1).Limit(-1).Count() + return +} + +func (u userDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = u.Count() + if err != nil { + return + } + + err = u.Offset(offset).Limit(limit).Scan(result) + return +} + +func (u userDo) Scan(result interface{}) (err error) { + return u.DO.Scan(result) +} + +func (u userDo) Delete(models ...*model.User) (result gen.ResultInfo, err error) { + return u.DO.Delete(models) +} + +func (u *userDo) withDO(do gen.Dao) *userDo { + u.DO = *do.(*gen.DO) + return u +} diff --git a/tests/.expect/dal_5/query/users.gen_test.go b/tests/.expect/dal_5/query/users.gen_test.go new file mode 100644 index 00000000..6906dad8 --- /dev/null +++ b/tests/.expect/dal_5/query/users.gen_test.go @@ -0,0 +1,145 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + "fmt" + "testing" + + "gorm.io/gen" + "gorm.io/gen/field" + "gorm.io/gen/tests/.gen/dal_5/model" + "gorm.io/gorm/clause" +) + +func init() { + InitializeDB() + err := db.AutoMigrate(&model.User{}) + if err != nil { + fmt.Printf("Error: AutoMigrate(&model.User{}) fail: %s", err) + } +} + +func Test_userQuery(t *testing.T) { + user := newUser(db) + user = *user.As(user.TableName()) + _do := user.WithContext(context.Background()).Debug() + + primaryKey := field.NewString(user.TableName(), clause.PrimaryKey) + _, err := _do.Unscoped().Where(primaryKey.IsNotNull()).Delete() + if err != nil { + t.Error("clean table fail:", err) + return + } + + _, ok := user.GetFieldByName("") + if ok { + t.Error("GetFieldByName(\"\") from user success") + } + + err = _do.Create(&model.User{}) + if err != nil { + t.Error("create item in table fail:", err) + } + + err = _do.Save(&model.User{}) + if err != nil { + t.Error("create item in table fail:", err) + } + + err = _do.CreateInBatches([]*model.User{{}, {}}, 10) + if err != nil { + t.Error("create item in table fail:", err) + } + + _, err = _do.Select(user.ALL).Take() + if err != nil { + t.Error("Take() on table fail:", err) + } + + _, err = _do.First() + if err != nil { + t.Error("First() on table fail:", err) + } + + _, err = _do.Last() + if err != nil { + t.Error("First() on table fail:", err) + } + + _, err = _do.Where(primaryKey.IsNotNull()).FindInBatch(10, func(tx gen.Dao, batch int) error { return nil }) + if err != nil { + t.Error("FindInBatch() on table fail:", err) + } + + err = _do.Where(primaryKey.IsNotNull()).FindInBatches(&[]*model.User{}, 10, func(tx gen.Dao, batch int) error { return nil }) + if err != nil { + t.Error("FindInBatches() on table fail:", err) + } + + _, err = _do.Select(user.ALL).Where(primaryKey.IsNotNull()).Order(primaryKey.Desc()).Find() + if err != nil { + t.Error("Find() on table fail:", err) + } + + _, err = _do.Distinct(primaryKey).Take() + if err != nil { + t.Error("select Distinct() on table fail:", err) + } + + _, err = _do.Select(user.ALL).Omit(primaryKey).Take() + if err != nil { + t.Error("Omit() on table fail:", err) + } + + _, err = _do.Group(primaryKey).Find() + if err != nil { + t.Error("Group() on table fail:", err) + } + + _, err = _do.Scopes(func(dao gen.Dao) gen.Dao { return dao.Where(primaryKey.IsNotNull()) }).Find() + if err != nil { + t.Error("Scopes() on table fail:", err) + } + + _, _, err = _do.FindByPage(0, 1) + if err != nil { + t.Error("FindByPage() on table fail:", err) + } + + _, err = _do.ScanByPage(&model.User{}, 0, 1) + if err != nil { + t.Error("ScanByPage() on table fail:", err) + } + + _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrInit() + if err != nil { + t.Error("FirstOrInit() on table fail:", err) + } + + _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrCreate() + if err != nil { + t.Error("FirstOrCreate() on table fail:", err) + } + + var _a _another + var _aPK = field.NewString(_a.TableName(), clause.PrimaryKey) + + err = _do.Join(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) + if err != nil { + t.Error("Join() on table fail:", err) + } + + err = _do.LeftJoin(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) + if err != nil { + t.Error("LeftJoin() on table fail:", err) + } + + _, err = _do.Not().Or().Clauses().Take() + if err != nil { + t.Error("Not/Or/Clauses on table fail:", err) + } +} diff --git a/tests/.expect/dal_6/model/users.gen.go b/tests/.expect/dal_6/model/users.gen.go new file mode 100644 index 00000000..d65d4f54 --- /dev/null +++ b/tests/.expect/dal_6/model/users.gen.go @@ -0,0 +1,39 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package model + +import ( + "time" +) + +const TableNameUser = "users" + +// User mapped from table +type User struct { + ID int64 `gorm:"column:id;primaryKey;autoIncrement:true" json:"-"` + CreatedAt *time.Time `gorm:"column:created_at" json:"-"` + Name *string `gorm:"column:name;index:idx_name,priority:1;comment:oneline" json:"-"` + Address *string `gorm:"column:address" json:"-"` + RegisterTime *time.Time `gorm:"column:register_time" json:"-"` + Alive *bool `gorm:"column:alive;comment:multiline\nline1\nline2" json:"-"` + CompanyID *int64 `gorm:"column:company_id;default:666" json:"-"` + PrivateURL *string `gorm:"column:private_url;default:https://a.b.c" json:"-"` +} + +func (m *User) IsEmpty() bool { + if m == nil { + return true + } + return m.ID == 0 +} + +func (m *User) GetID() int { + return int(m.ID) +} + +// TableName User's table name +func (*User) TableName() string { + return TableNameUser +} diff --git a/tests/.expect/dal_6/query/gen.go b/tests/.expect/dal_6/query/gen.go new file mode 100644 index 00000000..a19d931b --- /dev/null +++ b/tests/.expect/dal_6/query/gen.go @@ -0,0 +1,103 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + "database/sql" + + "gorm.io/gorm" + + "gorm.io/gen" + + "gorm.io/plugin/dbresolver" +) + +var ( + Q = new(Query) + User *user +) + +func SetDefault(db *gorm.DB, opts ...gen.DOOption) { + *Q = *Use(db, opts...) + User = &Q.User +} + +func Use(db *gorm.DB, opts ...gen.DOOption) *Query { + return &Query{ + db: db, + User: newUser(db, opts...), + } +} + +type Query struct { + db *gorm.DB + + User user +} + +func (q *Query) Available() bool { return q.db != nil } + +func (q *Query) clone(db *gorm.DB) *Query { + return &Query{ + db: db, + User: q.User.clone(db), + } +} + +func (q *Query) ReadDB() *Query { + return q.ReplaceDB(q.db.Clauses(dbresolver.Read)) +} + +func (q *Query) WriteDB() *Query { + return q.ReplaceDB(q.db.Clauses(dbresolver.Write)) +} + +func (q *Query) ReplaceDB(db *gorm.DB) *Query { + return &Query{ + db: db, + User: q.User.replaceDB(db), + } +} + +type queryCtx struct { + User IUserDo +} + +func (q *Query) WithContext(ctx context.Context) *queryCtx { + return &queryCtx{ + User: q.User.WithContext(ctx), + } +} + +func (q *Query) Transaction(fc func(tx *Query) error, opts ...*sql.TxOptions) error { + return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...) +} + +func (q *Query) Begin(opts ...*sql.TxOptions) *QueryTx { + tx := q.db.Begin(opts...) + return &QueryTx{Query: q.clone(tx), Error: tx.Error} +} + +type QueryTx struct { + *Query + Error error +} + +func (q *QueryTx) Commit() error { + return q.db.Commit().Error +} + +func (q *QueryTx) Rollback() error { + return q.db.Rollback().Error +} + +func (q *QueryTx) SavePoint(name string) error { + return q.db.SavePoint(name).Error +} + +func (q *QueryTx) RollbackTo(name string) error { + return q.db.RollbackTo(name).Error +} diff --git a/tests/.expect/dal_6/query/gen_test.go b/tests/.expect/dal_6/query/gen_test.go new file mode 100644 index 00000000..40cdc69b --- /dev/null +++ b/tests/.expect/dal_6/query/gen_test.go @@ -0,0 +1,118 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + "fmt" + "reflect" + "sync" + "testing" + + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +type Input struct { + Args []interface{} +} + +type Expectation struct { + Ret []interface{} +} + +type TestCase struct { + Input + Expectation +} + +const dbName = "gen_test.db" + +var db *gorm.DB +var once sync.Once + +func init() { + InitializeDB() + db.AutoMigrate(&_another{}) +} + +func InitializeDB() { + once.Do(func() { + var err error + db, err = gorm.Open(sqlite.Open(dbName), &gorm.Config{}) + if err != nil { + panic(fmt.Errorf("open sqlite %q fail: %w", dbName, err)) + } + }) +} + +func assert(t *testing.T, methodName string, res, exp interface{}) { + if !reflect.DeepEqual(res, exp) { + t.Errorf("%v() gotResult = %v, want %v", methodName, res, exp) + } +} + +type _another struct { + ID uint64 `gorm:"primaryKey"` +} + +func (*_another) TableName() string { return "another_for_unit_test" } + +func Test_Available(t *testing.T) { + if !Use(db).Available() { + t.Errorf("query.Available() == false") + } +} + +func Test_WithContext(t *testing.T) { + query := Use(db) + if !query.Available() { + t.Errorf("query Use(db) fail: query.Available() == false") + } + + type Content string + var key, value Content = "gen_tag", "unit_test" + qCtx := query.WithContext(context.WithValue(context.Background(), key, value)) + + for _, ctx := range []context.Context{ + qCtx.User.UnderlyingDB().Statement.Context, + } { + if v := ctx.Value(key); v != value { + t.Errorf("get value from context fail, expect %q, got %q", value, v) + } + } +} + +func Test_Transaction(t *testing.T) { + query := Use(db) + if !query.Available() { + t.Errorf("query Use(db) fail: query.Available() == false") + } + + err := query.Transaction(func(tx *Query) error { return nil }) + if err != nil { + t.Errorf("query.Transaction execute fail: %s", err) + } + + tx := query.Begin() + + err = tx.SavePoint("point") + if err != nil { + t.Errorf("query tx SavePoint fail: %s", err) + } + err = tx.RollbackTo("point") + if err != nil { + t.Errorf("query tx RollbackTo fail: %s", err) + } + err = tx.Commit() + if err != nil { + t.Errorf("query tx Commit fail: %s", err) + } + + err = query.Begin().Rollback() + if err != nil { + t.Errorf("query tx Rollback fail: %s", err) + } +} diff --git a/tests/.expect/dal_6/query/users.gen.go b/tests/.expect/dal_6/query/users.gen.go new file mode 100644 index 00000000..f79fff04 --- /dev/null +++ b/tests/.expect/dal_6/query/users.gen.go @@ -0,0 +1,423 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "gorm.io/gen/tests/.gen/dal_6/model" +) + +func newUser(db *gorm.DB, opts ...gen.DOOption) user { + _user := user{} + + _user.userDo.UseDB(db, opts...) + _user.userDo.UseModel(&model.User{}) + + tableName := _user.userDo.TableName() + _user.ALL = field.NewAsterisk(tableName) + _user.ID = field.NewInt64(tableName, "id") + _user.CreatedAt = field.NewTime(tableName, "created_at") + _user.Name = field.NewString(tableName, "name") + _user.Address = field.NewString(tableName, "address") + _user.RegisterTime = field.NewTime(tableName, "register_time") + _user.Alive = field.NewBool(tableName, "alive") + _user.CompanyID = field.NewInt64(tableName, "company_id") + _user.PrivateURL = field.NewString(tableName, "private_url") + + _user.fillFieldMap() + + return _user +} + +type user struct { + userDo userDo + + ALL field.Asterisk + ID field.Int64 + CreatedAt field.Time + Name field.String // oneline + Address field.String + RegisterTime field.Time + /* + multiline + line1 + line2 + */ + Alive field.Bool + CompanyID field.Int64 + PrivateURL field.String + + fieldMap map[string]field.Expr +} + +func (u user) Table(newTableName string) *user { + u.userDo.UseTable(newTableName) + return u.updateTableName(newTableName) +} + +func (u user) As(alias string) *user { + u.userDo.DO = *(u.userDo.As(alias).(*gen.DO)) + return u.updateTableName(alias) +} + +func (u *user) updateTableName(table string) *user { + u.ALL = field.NewAsterisk(table) + u.ID = field.NewInt64(table, "id") + u.CreatedAt = field.NewTime(table, "created_at") + u.Name = field.NewString(table, "name") + u.Address = field.NewString(table, "address") + u.RegisterTime = field.NewTime(table, "register_time") + u.Alive = field.NewBool(table, "alive") + u.CompanyID = field.NewInt64(table, "company_id") + u.PrivateURL = field.NewString(table, "private_url") + + u.fillFieldMap() + + return u +} + +func (u *user) WithContext(ctx context.Context) IUserDo { return u.userDo.WithContext(ctx) } + +func (u user) TableName() string { return u.userDo.TableName() } + +func (u user) Alias() string { return u.userDo.Alias() } + +func (u *user) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := u.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (u *user) fillFieldMap() { + u.fieldMap = make(map[string]field.Expr, 8) + u.fieldMap["id"] = u.ID + u.fieldMap["created_at"] = u.CreatedAt + u.fieldMap["name"] = u.Name + u.fieldMap["address"] = u.Address + u.fieldMap["register_time"] = u.RegisterTime + u.fieldMap["alive"] = u.Alive + u.fieldMap["company_id"] = u.CompanyID + u.fieldMap["private_url"] = u.PrivateURL +} + +func (u user) clone(db *gorm.DB) user { + u.userDo.ReplaceConnPool(db.Statement.ConnPool) + return u +} + +func (u user) replaceDB(db *gorm.DB) user { + u.userDo.ReplaceDB(db) + return u +} + +type userDo struct{ gen.DO } + +type IUserDo interface { + gen.SubQuery + Debug() IUserDo + WithContext(ctx context.Context) IUserDo + WithResult(fc func(tx gen.Dao)) gen.ResultInfo + ReplaceDB(db *gorm.DB) + ReadDB() IUserDo + WriteDB() IUserDo + As(alias string) gen.Dao + Session(config *gorm.Session) IUserDo + Columns(cols ...field.Expr) gen.Columns + Clauses(conds ...clause.Expression) IUserDo + Not(conds ...gen.Condition) IUserDo + Or(conds ...gen.Condition) IUserDo + Select(conds ...field.Expr) IUserDo + Where(conds ...gen.Condition) IUserDo + Order(conds ...field.Expr) IUserDo + Distinct(cols ...field.Expr) IUserDo + Omit(cols ...field.Expr) IUserDo + Join(table schema.Tabler, on ...field.Expr) IUserDo + LeftJoin(table schema.Tabler, on ...field.Expr) IUserDo + RightJoin(table schema.Tabler, on ...field.Expr) IUserDo + Group(cols ...field.Expr) IUserDo + Having(conds ...gen.Condition) IUserDo + Limit(limit int) IUserDo + Offset(offset int) IUserDo + Count() (count int64, err error) + Scopes(funcs ...func(gen.Dao) gen.Dao) IUserDo + Unscoped() IUserDo + Create(values ...*model.User) error + CreateInBatches(values []*model.User, batchSize int) error + Save(values ...*model.User) error + First() (*model.User, error) + Take() (*model.User, error) + Last() (*model.User, error) + Find() ([]*model.User, error) + FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.User, err error) + FindInBatches(result *[]*model.User, batchSize int, fc func(tx gen.Dao, batch int) error) error + Pluck(column field.Expr, dest interface{}) error + Delete(...*model.User) (info gen.ResultInfo, err error) + Update(column field.Expr, value interface{}) (info gen.ResultInfo, err error) + UpdateSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error) + Updates(value interface{}) (info gen.ResultInfo, err error) + UpdateColumn(column field.Expr, value interface{}) (info gen.ResultInfo, err error) + UpdateColumnSimple(columns ...field.AssignExpr) (info gen.ResultInfo, err error) + UpdateColumns(value interface{}) (info gen.ResultInfo, err error) + UpdateFrom(q gen.SubQuery) gen.Dao + Attrs(attrs ...field.AssignExpr) IUserDo + Assign(attrs ...field.AssignExpr) IUserDo + Joins(fields ...field.RelationField) IUserDo + Preload(fields ...field.RelationField) IUserDo + FirstOrInit() (*model.User, error) + FirstOrCreate() (*model.User, error) + FindByPage(offset int, limit int) (result []*model.User, count int64, err error) + ScanByPage(result interface{}, offset int, limit int) (count int64, err error) + Scan(result interface{}) (err error) + Returning(value interface{}, columns ...string) IUserDo + UnderlyingDB() *gorm.DB + schema.Tabler +} + +func (u userDo) Debug() IUserDo { + return u.withDO(u.DO.Debug()) +} + +func (u userDo) WithContext(ctx context.Context) IUserDo { + return u.withDO(u.DO.WithContext(ctx)) +} + +func (u userDo) ReadDB() IUserDo { + return u.Clauses(dbresolver.Read) +} + +func (u userDo) WriteDB() IUserDo { + return u.Clauses(dbresolver.Write) +} + +func (u userDo) Session(config *gorm.Session) IUserDo { + return u.withDO(u.DO.Session(config)) +} + +func (u userDo) Clauses(conds ...clause.Expression) IUserDo { + return u.withDO(u.DO.Clauses(conds...)) +} + +func (u userDo) Returning(value interface{}, columns ...string) IUserDo { + return u.withDO(u.DO.Returning(value, columns...)) +} + +func (u userDo) Not(conds ...gen.Condition) IUserDo { + return u.withDO(u.DO.Not(conds...)) +} + +func (u userDo) Or(conds ...gen.Condition) IUserDo { + return u.withDO(u.DO.Or(conds...)) +} + +func (u userDo) Select(conds ...field.Expr) IUserDo { + return u.withDO(u.DO.Select(conds...)) +} + +func (u userDo) Where(conds ...gen.Condition) IUserDo { + return u.withDO(u.DO.Where(conds...)) +} + +func (u userDo) Exists(subquery interface{ UnderlyingDB() *gorm.DB }) IUserDo { + return u.Where(field.CompareSubQuery(field.ExistsOp, nil, subquery.UnderlyingDB())) +} + +func (u userDo) Order(conds ...field.Expr) IUserDo { + return u.withDO(u.DO.Order(conds...)) +} + +func (u userDo) Distinct(cols ...field.Expr) IUserDo { + return u.withDO(u.DO.Distinct(cols...)) +} + +func (u userDo) Omit(cols ...field.Expr) IUserDo { + return u.withDO(u.DO.Omit(cols...)) +} + +func (u userDo) Join(table schema.Tabler, on ...field.Expr) IUserDo { + return u.withDO(u.DO.Join(table, on...)) +} + +func (u userDo) LeftJoin(table schema.Tabler, on ...field.Expr) IUserDo { + return u.withDO(u.DO.LeftJoin(table, on...)) +} + +func (u userDo) RightJoin(table schema.Tabler, on ...field.Expr) IUserDo { + return u.withDO(u.DO.RightJoin(table, on...)) +} + +func (u userDo) Group(cols ...field.Expr) IUserDo { + return u.withDO(u.DO.Group(cols...)) +} + +func (u userDo) Having(conds ...gen.Condition) IUserDo { + return u.withDO(u.DO.Having(conds...)) +} + +func (u userDo) Limit(limit int) IUserDo { + return u.withDO(u.DO.Limit(limit)) +} + +func (u userDo) Offset(offset int) IUserDo { + return u.withDO(u.DO.Offset(offset)) +} + +func (u userDo) Scopes(funcs ...func(gen.Dao) gen.Dao) IUserDo { + return u.withDO(u.DO.Scopes(funcs...)) +} + +func (u userDo) Unscoped() IUserDo { + return u.withDO(u.DO.Unscoped()) +} + +func (u userDo) Create(values ...*model.User) error { + if len(values) == 0 { + return nil + } + return u.DO.Create(values) +} + +func (u userDo) CreateInBatches(values []*model.User, batchSize int) error { + return u.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (u userDo) Save(values ...*model.User) error { + if len(values) == 0 { + return nil + } + return u.DO.Save(values) +} + +func (u userDo) First() (*model.User, error) { + if result, err := u.DO.First(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) Take() (*model.User, error) { + if result, err := u.DO.Take(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) Last() (*model.User, error) { + if result, err := u.DO.Last(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) Find() ([]*model.User, error) { + result, err := u.DO.Find() + return result.([]*model.User), err +} + +func (u userDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*model.User, err error) { + buf := make([]*model.User, 0, batchSize) + err = u.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (u userDo) FindInBatches(result *[]*model.User, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return u.DO.FindInBatches(result, batchSize, fc) +} + +func (u userDo) Attrs(attrs ...field.AssignExpr) IUserDo { + return u.withDO(u.DO.Attrs(attrs...)) +} + +func (u userDo) Assign(attrs ...field.AssignExpr) IUserDo { + return u.withDO(u.DO.Assign(attrs...)) +} + +func (u userDo) Joins(fields ...field.RelationField) IUserDo { + for _, _f := range fields { + u = *u.withDO(u.DO.Joins(_f)) + } + return &u +} + +func (u userDo) Preload(fields ...field.RelationField) IUserDo { + for _, _f := range fields { + u = *u.withDO(u.DO.Preload(_f)) + } + return &u +} + +func (u userDo) FirstOrInit() (*model.User, error) { + if result, err := u.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) FirstOrCreate() (*model.User, error) { + if result, err := u.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*model.User), nil + } +} + +func (u userDo) FindByPage(offset int, limit int) (result []*model.User, count int64, err error) { + result, err = u.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = u.Offset(-1).Limit(-1).Count() + return +} + +func (u userDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = u.Count() + if err != nil { + return + } + + err = u.Offset(offset).Limit(limit).Scan(result) + return +} + +func (u userDo) Scan(result interface{}) (err error) { + return u.DO.Scan(result) +} + +func (u userDo) Delete(models ...*model.User) (result gen.ResultInfo, err error) { + return u.DO.Delete(models) +} + +func (u *userDo) withDO(do gen.Dao) *userDo { + u.DO = *do.(*gen.DO) + return u +} diff --git a/tests/.expect/dal_6/query/users.gen_test.go b/tests/.expect/dal_6/query/users.gen_test.go new file mode 100644 index 00000000..0a15043b --- /dev/null +++ b/tests/.expect/dal_6/query/users.gen_test.go @@ -0,0 +1,145 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package query + +import ( + "context" + "fmt" + "testing" + + "gorm.io/gen" + "gorm.io/gen/field" + "gorm.io/gen/tests/.gen/dal_6/model" + "gorm.io/gorm/clause" +) + +func init() { + InitializeDB() + err := db.AutoMigrate(&model.User{}) + if err != nil { + fmt.Printf("Error: AutoMigrate(&model.User{}) fail: %s", err) + } +} + +func Test_userQuery(t *testing.T) { + user := newUser(db) + user = *user.As(user.TableName()) + _do := user.WithContext(context.Background()).Debug() + + primaryKey := field.NewString(user.TableName(), clause.PrimaryKey) + _, err := _do.Unscoped().Where(primaryKey.IsNotNull()).Delete() + if err != nil { + t.Error("clean table fail:", err) + return + } + + _, ok := user.GetFieldByName("") + if ok { + t.Error("GetFieldByName(\"\") from user success") + } + + err = _do.Create(&model.User{}) + if err != nil { + t.Error("create item in table fail:", err) + } + + err = _do.Save(&model.User{}) + if err != nil { + t.Error("create item in table fail:", err) + } + + err = _do.CreateInBatches([]*model.User{{}, {}}, 10) + if err != nil { + t.Error("create item in table fail:", err) + } + + _, err = _do.Select(user.ALL).Take() + if err != nil { + t.Error("Take() on table fail:", err) + } + + _, err = _do.First() + if err != nil { + t.Error("First() on table fail:", err) + } + + _, err = _do.Last() + if err != nil { + t.Error("First() on table fail:", err) + } + + _, err = _do.Where(primaryKey.IsNotNull()).FindInBatch(10, func(tx gen.Dao, batch int) error { return nil }) + if err != nil { + t.Error("FindInBatch() on table fail:", err) + } + + err = _do.Where(primaryKey.IsNotNull()).FindInBatches(&[]*model.User{}, 10, func(tx gen.Dao, batch int) error { return nil }) + if err != nil { + t.Error("FindInBatches() on table fail:", err) + } + + _, err = _do.Select(user.ALL).Where(primaryKey.IsNotNull()).Order(primaryKey.Desc()).Find() + if err != nil { + t.Error("Find() on table fail:", err) + } + + _, err = _do.Distinct(primaryKey).Take() + if err != nil { + t.Error("select Distinct() on table fail:", err) + } + + _, err = _do.Select(user.ALL).Omit(primaryKey).Take() + if err != nil { + t.Error("Omit() on table fail:", err) + } + + _, err = _do.Group(primaryKey).Find() + if err != nil { + t.Error("Group() on table fail:", err) + } + + _, err = _do.Scopes(func(dao gen.Dao) gen.Dao { return dao.Where(primaryKey.IsNotNull()) }).Find() + if err != nil { + t.Error("Scopes() on table fail:", err) + } + + _, _, err = _do.FindByPage(0, 1) + if err != nil { + t.Error("FindByPage() on table fail:", err) + } + + _, err = _do.ScanByPage(&model.User{}, 0, 1) + if err != nil { + t.Error("ScanByPage() on table fail:", err) + } + + _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrInit() + if err != nil { + t.Error("FirstOrInit() on table fail:", err) + } + + _, err = _do.Attrs(primaryKey).Assign(primaryKey).FirstOrCreate() + if err != nil { + t.Error("FirstOrCreate() on table fail:", err) + } + + var _a _another + var _aPK = field.NewString(_a.TableName(), clause.PrimaryKey) + + err = _do.Join(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) + if err != nil { + t.Error("Join() on table fail:", err) + } + + err = _do.LeftJoin(&_a, primaryKey.EqCol(_aPK)).Scan(map[string]interface{}{}) + if err != nil { + t.Error("LeftJoin() on table fail:", err) + } + + _, err = _do.Not().Or().Clauses().Take() + if err != nil { + t.Error("Not/Or/Clauses on table fail:", err) + } +} diff --git a/tests/diy_method/method.go b/tests/diy_method/method.go index 54e41312..16c588ee 100644 --- a/tests/diy_method/method.go +++ b/tests/diy_method/method.go @@ -1,12 +1,12 @@ package diy_method import ( - "gorm.io/gen" "time" + + "gorm.io/gen" ) type InsertMethod interface { - // AddUser // // INSERT INTO users (name,age) VALUES (@name,@age) ON DUPLICATE KEY UPDATE age=VALUES(age) @@ -44,7 +44,6 @@ type InsertMethod interface { } type SelectMethod interface { - // FindByID // // select * from users where id=@id @@ -67,7 +66,6 @@ type SelectMethod interface { } type TrimTest interface { - // TestTrim // // select * from @@table where @@ -313,3 +311,18 @@ type TestFor interface { // {{end}} TestForLike(names []string) []gen.T } + +type TestForWithMethod struct { + ID int32 +} + +func (m *TestForWithMethod) IsEmpty() bool { + if m == nil { + return true + } + return m.ID == 0 +} + +func (m *TestForWithMethod) GetID() int { + return int(m.ID) +} diff --git a/tests/generate_test.go b/tests/generate_test.go index dac6232d..9b1bec22 100644 --- a/tests/generate_test.go +++ b/tests/generate_test.go @@ -86,6 +86,38 @@ var generateCase = map[string]func(dir string) *gen.Generator{ }, g.GenerateModel("users")) return g }, + generateDirPrefix + "dal_5": func(dir string) *gen.Generator { + g := gen.NewGenerator(gen.Config{ + OutPath: dir + "/query", + Mode: gen.WithDefaultQuery | gen.WithQueryInterface, + + WithUnitTest: true, + + FieldNullable: true, + FieldCoverable: true, + FieldWithIndexTag: true, + }) + g.UseDB(DB) + g.WithJSONTagNameStrategy(func(c string) string { return "-" }) + g.ApplyBasic(g.GenerateModel("users", gen.WithMethod(diy_method.TestForWithMethod{}))) + return g + }, + generateDirPrefix + "dal_6": func(dir string) *gen.Generator { + g := gen.NewGenerator(gen.Config{ + OutPath: dir + "/query", + Mode: gen.WithDefaultQuery | gen.WithQueryInterface, + + WithUnitTest: true, + + FieldNullable: true, + FieldCoverable: true, + FieldWithIndexTag: true, + }) + g.UseDB(DB) + g.WithJSONTagNameStrategy(func(c string) string { return "-" }) + g.ApplyBasic(g.GenerateModelAs("users", DB.Config.NamingStrategy.SchemaName("users"), gen.WithMethod(diy_method.TestForWithMethod{}))) + return g + }, } func TestGenerate(t *testing.T) {