From 653e2aa5213054ce05df160487a36bbf13871c5b Mon Sep 17 00:00:00 2001 From: iDer Date: Fri, 6 Aug 2021 15:16:58 +0800 Subject: [PATCH 1/3] style:file header and function param --- generator.go | 4 ++-- internal/template/tmpl.go | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/generator.go b/generator.go index db81ba85..c7b4c07b 100644 --- a/generator.go +++ b/generator.go @@ -79,8 +79,8 @@ func (g *Generator) UseDB(db *gorm.DB) { } // GenerateModel catch table info from db, return a BaseStruct -func (g *Generator) GenerateModel(name string) *check.BaseStruct { - structs, err := check.GenBaseStructs(g.db, g.Config.ModelPkgName, name) +func (g *Generator) GenerateModel(tableName string) *check.BaseStruct { + structs, err := check.GenBaseStructs(g.db, g.Config.ModelPkgName, tableName) if err != nil { log.Fatalf("check struct error: %s", err) } diff --git a/internal/template/tmpl.go b/internal/template/tmpl.go index 64b696a2..706711cd 100644 --- a/internal/template/tmpl.go +++ b/internal/template/tmpl.go @@ -1,9 +1,9 @@ package template const HeaderTmpl = ` -// Code generated by gorm/gen. DO NOT EDIT. -// Code generated by gorm/gen. DO NOT EDIT. -// Code generated by gorm/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. package {{.}} @@ -224,6 +224,9 @@ func Use(db *gorm.DB) *DB { // ModelTemplate used as a variable because it cannot load template file after packed, params still can pass file const ModelTemplate = ` +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. package {{.StructInfo.Package}} import "time" From 5db669cba42b06081bf1744e4bcfaaab07d74064 Mon Sep 17 00:00:00 2001 From: iDer Date: Fri, 6 Aug 2021 17:53:44 +0800 Subject: [PATCH 2/3] feat: specify module name --- generator.go | 6 +-- generator_test.go | 2 +- internal/check/export.go | 3 +- internal/check/gen_structs.go | 72 +++++++++++++++++++++-------------- internal/check/utils.go | 16 ++++++++ 5 files changed, 65 insertions(+), 34 deletions(-) diff --git a/generator.go b/generator.go index c7b4c07b..a494c8e6 100644 --- a/generator.go +++ b/generator.go @@ -79,12 +79,12 @@ func (g *Generator) UseDB(db *gorm.DB) { } // GenerateModel catch table info from db, return a BaseStruct -func (g *Generator) GenerateModel(tableName string) *check.BaseStruct { - structs, err := check.GenBaseStructs(g.db, g.Config.ModelPkgName, tableName) +func (g *Generator) GenerateModel(tableName string, modelName string) *check.BaseStruct { + s, err := check.GenBaseStructs(g.db, g.Config.ModelPkgName, tableName, modelName) if err != nil { log.Fatalf("check struct error: %s", err) } - return structs[0] + return s } // ApplyBasic specify models which will implement basic method diff --git a/generator_test.go b/generator_test.go index a6957144..fc5914ee 100644 --- a/generator_test.go +++ b/generator_test.go @@ -15,4 +15,4 @@ func TestConfig(t *testing.T) { queryPkgName: "query", } -} \ No newline at end of file +} diff --git a/internal/check/export.go b/internal/check/export.go index 647e4c0e..46bd54bb 100644 --- a/internal/check/export.go +++ b/internal/check/export.go @@ -3,7 +3,6 @@ package check import ( "fmt" "reflect" - "strings" "gorm.io/gorm" @@ -31,7 +30,7 @@ func CheckStructs(db *gorm.DB, structs ...interface{}) (bases []*BaseStruct, err base := &BaseStruct{ S: GetSimpleName(name), StructName: name, - NewStructName: strings.ToLower(name), + NewStructName: uncaptialize(name), StructInfo: parser.Param{Type: name, Package: getPackageName(structType.String())}, Source: Struct, db: db, diff --git a/internal/check/gen_structs.go b/internal/check/gen_structs.go index 0dc9eb34..6d37b1c2 100644 --- a/internal/check/gen_structs.go +++ b/internal/check/gen_structs.go @@ -57,43 +57,46 @@ var dataType = map[string]string{ } // GenBaseStructs generate db model by table name -func GenBaseStructs(db *gorm.DB, pkg string, tableName ...string) (bases []*BaseStruct, err error) { +func GenBaseStructs(db *gorm.DB, pkg string, tableName, modelName string) (bases *BaseStruct, err error) { if isDBUndefined(db) { return nil, fmt.Errorf("gen config db is undefined") } + if !isModelNameValid(modelName) { + return nil, fmt.Errorf("model name %q is invalid", modelName) + } if pkg == "" { pkg = ModelPkg } singular := singularModel(db.Config) dbName := getSchemaName(db) - for _, tb := range tableName { - columns, err := getTbColumns(db, dbName, tb) - if err != nil { - return nil, err - } - var base BaseStruct - base.Source = TableName - base.GenBaseStruct = true - base.TableName = tb - base.StructName = convertToModelName(singular, tb) - base.StructInfo = parser.Param{Type: base.StructName, Package: pkg} - for _, field := range columns { - mt := dataType[field.DataType] - base.Members = append(base.Members, &Member{ - Name: nameToCamelCase(field.ColumnName), - Type: mt, - ModelType: mt, - ColumnName: field.ColumnName, - ColumnComment: field.ColumnComment, - }) - } - - base.NewStructName = strings.ToLower(base.StructName) - base.S = string(base.NewStructName[0]) - _ = base.check() - bases = append(bases, &base) + columns, err := getTbColumns(db, dbName, tableName) + if err != nil { + return nil, err } - return + var base BaseStruct + base.Source = TableName + base.GenBaseStruct = true + base.TableName = tableName + base.StructName = convertToModelName(singular, tableName) + if modelName != "" { + base.StructName = captialize(modelName) + } + base.NewStructName = uncaptialize(base.StructName) + base.S = string(base.NewStructName[0]) + base.StructInfo = parser.Param{Type: base.StructName, Package: pkg} + for _, field := range columns { + mt := dataType[field.DataType] + base.Members = append(base.Members, &Member{ + Name: nameToCamelCase(field.ColumnName), + Type: mt, + ModelType: mt, + ColumnName: field.ColumnName, + ColumnComment: field.ColumnComment, + }) + } + + _ = base.check() + return &base, nil } //Mysql @@ -145,3 +148,16 @@ func singularModel(conf *gorm.Config) bool { } return false } + +// get mysql db' name +var modelNameReg = regexp.MustCompile(`^\w+$`) + +func isModelNameValid(name string) bool { + if name == "" { + return true + } + if !modelNameReg.MatchString(name) { + return false + } + return (name[0] > '9' || name[0] < '0') && name[0] != '_' +} diff --git a/internal/check/utils.go b/internal/check/utils.go index 87e10c22..72f4a14b 100644 --- a/internal/check/utils.go +++ b/internal/check/utils.go @@ -89,3 +89,19 @@ func getStructName(t string) string { list := strings.Split(t, ".") return list[len(list)-1] } + +func uncaptialize(s string) string { + if s == "" { + return "" + } + + return strings.ToLower(s[:1]) + s[1:] +} + +func captialize(s string) string { + if s == "" { + return "" + } + + return strings.ToUpper(s[:1]) + s[1:] +} From 22b89b0c0ed699c8ec39e2c68f3c0a96ff24ffaa Mon Sep 17 00:00:00 2001 From: riverchu Date: Fri, 6 Aug 2021 19:57:27 +0800 Subject: [PATCH 3/3] fix: distinct error --- do.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/do.go b/do.go index fb5f339e..0eca6c6f 100644 --- a/do.go +++ b/do.go @@ -176,7 +176,7 @@ func (d *DO) Order(columns ...field.Expr) Dao { func (d *DO) Distinct(columns ...field.Expr) Dao { Emit(methodDistinct) - return NewDO(d.db.Distinct(toInterfaceSlice(toColNames(d.db.Statement, columns...)))) + return NewDO(d.db.Distinct(toInterfaceSlice(toColNames(d.db.Statement, columns...))...)) } func (d *DO) Omit(columns ...field.Expr) Dao {