Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type Config struct {
OutPath string // query code path
OutFile string // query code file name, default: gen.go
ModelPkgPath string // generated model code's package name
QueryPkgPath string // generated model code's package name
WithUnitTest bool // generate unit test for query code

// generate model global configuration
Expand Down Expand Up @@ -120,6 +121,9 @@ func (cfg *Config) Revise() (err error) {
if strings.TrimSpace(cfg.ModelPkgPath) == "" {
cfg.ModelPkgPath = model.DefaultModelPkg
}
if strings.TrimSpace(cfg.QueryPkgPath) == "" {
cfg.QueryPkgPath = model.DefaultQueryPkg
}

cfg.OutPath, err = filepath.Abs(cfg.OutPath)
if err != nil {
Expand All @@ -129,11 +133,11 @@ func (cfg *Config) Revise() (err error) {
cfg.OutPath = fmt.Sprintf(".%squery%s", string(os.PathSeparator), string(os.PathSeparator))
}
if cfg.OutFile == "" {
cfg.OutFile = filepath.Join(cfg.OutPath, "gen.go")
cfg.OutFile = filepath.Join(cfg.OutPath, cfg.QueryPkgPath, "gen.go")
} else if !strings.Contains(cfg.OutFile, string(os.PathSeparator)) {
cfg.OutFile = filepath.Join(cfg.OutPath, cfg.OutFile)
cfg.OutFile = filepath.Join(cfg.OutPath, cfg.QueryPkgPath, cfg.OutFile)
}
cfg.queryPkgName = filepath.Base(cfg.OutPath)
cfg.queryPkgName = filepath.Base(cfg.QueryPkgPath)

if cfg.db == nil {
cfg.db, _ = gorm.Open(tests.DummyDialector{})
Expand Down
43 changes: 20 additions & 23 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,9 @@ func (g *Generator) generateQueryFile() (err error) {
return nil
}

if err = os.MkdirAll(g.OutPath, os.ModePerm); err != nil {
return fmt.Errorf("make dir outpath(%s) fail: %s", g.OutPath, err)
queryOutPath := g.getQueryOutputPath()
if err = os.MkdirAll(queryOutPath, os.ModePerm); err != nil {
return fmt.Errorf("create query pkg path(%s) fail: %s", queryOutPath, err)
}

errChan := make(chan error)
Expand Down Expand Up @@ -379,6 +380,10 @@ func (g *Generator) generateQueryFile() (err error) {
return nil
}

func (g *Generator) getQueryOutputPath() (outPath string) {
return filepath.Join(g.OutPath, g.QueryPkgPath) + string(os.PathSeparator)
}

// generateSingleQueryFile generate query code and save to file
func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) {
var buf bytes.Buffer
Expand Down Expand Up @@ -425,8 +430,10 @@ func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) {
return err
}

defer g.info(fmt.Sprintf("generate query file: %s%s%s.gen.go", g.OutPath, string(os.PathSeparator), data.FileName))
return g.output(fmt.Sprintf("%s%s%s.gen.go", g.OutPath, string(os.PathSeparator), data.FileName), buf.Bytes())
outputPath := filepath.Join(g.OutPath, g.QueryPkgPath)

defer g.info(fmt.Sprintf("generate query file: %s%s%s.gen.go", outputPath, string(os.PathSeparator), data.FileName))
return g.output(fmt.Sprintf("%s%s%s.gen.go", outputPath, string(os.PathSeparator), data.FileName), buf.Bytes())
}

// generateQueryUnitTestFile generate unit test file for query
Expand Down Expand Up @@ -457,8 +464,10 @@ func (g *Generator) generateQueryUnitTestFile(data *genInfo) (err error) {
}
}

defer g.info(fmt.Sprintf("generate unit test file: %s%s%s.gen_test.go", g.OutPath, string(os.PathSeparator), data.FileName))
return g.output(fmt.Sprintf("%s%s%s.gen_test.go", g.OutPath, string(os.PathSeparator), data.FileName), buf.Bytes())
outputPath := filepath.Join(g.OutPath, g.QueryPkgPath)

defer g.info(fmt.Sprintf("generate unit test file: %s%s%s.gen_test.go", outputPath, string(os.PathSeparator), data.FileName))
return g.output(fmt.Sprintf("%s%s%s.gen_test.go", outputPath, string(os.PathSeparator), data.FileName), buf.Bytes())
}

// generateModelFile generate model structures and save to file
Expand All @@ -467,12 +476,8 @@ func (g *Generator) generateModelFile() error {
return nil
}

modelOutPath, err := g.getModelOutputPath()
if err != nil {
return err
}

if err = os.MkdirAll(modelOutPath, os.ModePerm); err != nil {
modelOutPath := g.getModelOutputPath()
if err := os.MkdirAll(modelOutPath, os.ModePerm); err != nil {
return fmt.Errorf("create model pkg path(%s) fail: %s", modelOutPath, err)
}

Expand Down Expand Up @@ -512,24 +517,16 @@ func (g *Generator) generateModelFile() error {
}(data)
}
select {
case err = <-errChan:
case err := <-errChan:
return err
case <-pool.AsyncWaitAll():
g.fillModelPkgPath(modelOutPath)
}
return nil
}

func (g *Generator) getModelOutputPath() (outPath string, err error) {
if strings.Contains(g.ModelPkgPath, string(os.PathSeparator)) {
outPath, err = filepath.Abs(g.ModelPkgPath)
if err != nil {
return "", fmt.Errorf("cannot parse model pkg path: %w", err)
}
} else {
outPath = filepath.Join(filepath.Dir(g.OutPath), g.ModelPkgPath)
}
return outPath + string(os.PathSeparator), nil
func (g *Generator) getModelOutputPath() (outPath string) {
return filepath.Join(g.OutPath, g.ModelPkgPath) + string(os.PathSeparator)
}

func (g *Generator) fillModelPkgPath(filePath string) {
Expand Down
3 changes: 1 addition & 2 deletions generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ func TestConfig(t *testing.T) {
OutFile: "",

ModelPkgPath: "models",

queryPkgName: "query",
QueryPkgPath: "query",
}
}

Expand Down
2 changes: 2 additions & 0 deletions internal/model/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
const (
// DefaultModelPkg ...
DefaultModelPkg = "model"
// DefaultQueryPkg ...
DefaultQueryPkg = "query"
)

// Status sql status
Expand Down
12 changes: 6 additions & 6 deletions tests/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var _ = os.Setenv("GORM_DIALECT", "mysql")
var generateCase = map[string]func(dir string) *gen.Generator{
generateDirPrefix + "dal_1": func(dir string) *gen.Generator {
g := gen.NewGenerator(gen.Config{
OutPath: dir + "/query",
OutPath: dir,
Mode: gen.WithDefaultQuery,
})
g.UseDB(DB)
Expand All @@ -34,7 +34,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{
},
generateDirPrefix + "dal_2": func(dir string) *gen.Generator {
g := gen.NewGenerator(gen.Config{
OutPath: dir + "/query",
OutPath: dir,
Mode: gen.WithDefaultQuery,

WithUnitTest: true,
Expand All @@ -50,7 +50,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{
},
generateDirPrefix + "dal_3": func(dir string) *gen.Generator {
g := gen.NewGenerator(gen.Config{
OutPath: dir + "/query",
OutPath: dir,
Mode: gen.WithDefaultQuery | gen.WithQueryInterface,

WithUnitTest: true,
Expand All @@ -70,7 +70,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{
},
generateDirPrefix + "dal_4": func(dir string) *gen.Generator {
g := gen.NewGenerator(gen.Config{
OutPath: dir + "/query",
OutPath: dir,
Mode: gen.WithDefaultQuery | gen.WithQueryInterface,

WithUnitTest: true,
Expand All @@ -88,7 +88,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{
},
generateDirPrefix + "dal_5": func(dir string) *gen.Generator {
g := gen.NewGenerator(gen.Config{
OutPath: dir + "/query",
OutPath: dir,
Mode: gen.WithDefaultQuery | gen.WithQueryInterface,

WithUnitTest: true,
Expand All @@ -104,7 +104,7 @@ var generateCase = map[string]func(dir string) *gen.Generator{
},
generateDirPrefix + "dal_6": func(dir string) *gen.Generator {
g := gen.NewGenerator(gen.Config{
OutPath: dir + "/query",
OutPath: dir,
Mode: gen.WithDefaultQuery | gen.WithQueryInterface,

WithUnitTest: true,
Expand Down
22 changes: 16 additions & 6 deletions tools/gentool/README.ZH_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
## 使用方式

```shell
gentool -h

gentool -h

Usage of gentool:
-db string
input mysql or postgres or sqlite or sqlserver. consult[https://gorm.io/docs/connecting_to_the_database.html] (default "mysql")
Expand All @@ -30,11 +30,13 @@
-fieldWithTypeTag
generate field with gorm column type tag
-modelPkgName string
generated model code's package name
generated model code's package name (default "model")
-queryPkgName string
generated query code's package name (default "query")
-outFile string
query code file name, default: gen.go
-outPath string
specify a directory for output (default "./dao/query")
specify a directory for output (default "./dao")
-tables string
enter the required data table or leave it blank
-onlyModel
Expand Down Expand Up @@ -84,9 +86,17 @@ default ""

#### modelPkgName

默认值是数据表名称。
默认为:model

生成的model代码的包名称。
设置“outPath”后的路径。

#### queryPkgName

默认为:query

生成的model代码的包名称。
设置“outPath”后的路径。

#### outFile

Expand Down
18 changes: 14 additions & 4 deletions tools/gentool/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@ Install GEN as a binary tool
-fieldWithTypeTag
generate field with gorm column type tag
-modelPkgName string
generated model code's package name
generated model code's package name (default "model")
-queryPkgName string
generated query code's package name (default "query")
-outFile string
query code file name, default: gen.go
-outPath string
specify a directory for output (default "./dao/query")
specify a directory for output (default "./dao")
-tables string
enter the required data table or leave it blank
-onlyModel
Expand Down Expand Up @@ -82,17 +84,25 @@ generate field with gorm column type tag

#### modelPkgName

defalut table name.
default "model"

generated model code's package name.
set the path after "outPath".

#### queryPkgName

default "query"

generated query code's package name.
set the path after "outPath".

#### outFile

query code file name, default: gen.go

#### outPath

specify a directory for output (default "./dao/query")
specify a directory for output (default "./dao")

#### tables

Expand Down
6 changes: 4 additions & 2 deletions tools/gentool/gen.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ database:
# only generate models (without query file)
onlyModel : false
# specify a directory for output
outPath : "./dao/query"
outPath : "./dao"
# query code file name, default: gen.go
outFile : ""
# generate unit test for query code
withUnitTest : false
# generated model code's package name
modelPkgName : ""
modelPkgName : "model"
# generated query code's package name
queryPkgName : "query"
# generate with pointer when field is nullable
fieldNullable : false
# generate with pointer when field has default value
Expand Down
8 changes: 7 additions & 1 deletion tools/gentool/gentool.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const (
dbClickHouse DBType = "clickhouse"
)
const (
defaultQueryPath = "./dao/query"
defaultQueryPath = "./dao"
)

// CmdParams is command line parameters
Expand All @@ -42,6 +42,7 @@ type CmdParams struct {
OutFile string `yaml:"outFile"` // query code file name, default: gen.go
WithUnitTest bool `yaml:"withUnitTest"` // generate unit test for query code
ModelPkgName string `yaml:"modelPkgName"` // generated model code's package name
QueryPkgName string `yaml:"queryPkgName"` // generated query code's package name
FieldNullable bool `yaml:"fieldNullable"` // generate with pointer when field is nullable
FieldCoverable bool `yaml:"fieldCoverable"` // generate with pointer when field has default value
FieldWithIndexTag bool `yaml:"fieldWithIndexTag"` // generate field with gorm index tag
Expand Down Expand Up @@ -149,6 +150,7 @@ func argParse() *CmdParams {
outFile := flag.String("outFile", "", "query code file name, default: gen.go")
withUnitTest := flag.Bool("withUnitTest", false, "generate unit test for query code")
modelPkgName := flag.String("modelPkgName", "", "generated model code's package name")
queryPkgName := flag.String("queryPkgName", "", "generated query code's package name")
fieldNullable := flag.Bool("fieldNullable", false, "generate with pointer when field is nullable")
fieldCoverable := flag.Bool("fieldCoverable", false, "generate with pointer when field has default value")
fieldWithIndexTag := flag.Bool("fieldWithIndexTag", false, "generate field with gorm index tag")
Expand Down Expand Up @@ -186,6 +188,9 @@ func argParse() *CmdParams {
if *modelPkgName != "" {
cmdParse.ModelPkgName = *modelPkgName
}
if *queryPkgName != "" {
cmdParse.QueryPkgName = *queryPkgName
}
if *fieldNullable {
cmdParse.FieldNullable = *fieldNullable
}
Expand Down Expand Up @@ -220,6 +225,7 @@ func main() {
OutPath: config.OutPath,
OutFile: config.OutFile,
ModelPkgPath: config.ModelPkgName,
QueryPkgPath: config.QueryPkgName,
WithUnitTest: config.WithUnitTest,
FieldNullable: config.FieldNullable,
FieldCoverable: config.FieldCoverable,
Expand Down