From e0c4067bcc036d2352f0799e69e51096581f3a82 Mon Sep 17 00:00:00 2001 From: r1v3r Date: Mon, 29 Aug 2022 17:48:30 +0800 Subject: [PATCH] fix:data race(#588) --- generator.go | 12 ++++++------ import.go | 22 +++++++++++++++------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/generator.go b/generator.go index 6efe5bd3..b6865cb1 100644 --- a/generator.go +++ b/generator.go @@ -287,7 +287,7 @@ func (g *Generator) generateQueryFile() (err error) { pool.Wait() go func(info *genInfo) { defer pool.Done() - err = g.generateSingleQueryFile(info) + err := g.generateSingleQueryFile(info) if err != nil { errChan <- err } @@ -310,7 +310,7 @@ func (g *Generator) generateQueryFile() (err error) { var buf bytes.Buffer err = render(tmpl.Header, &buf, map[string]interface{}{ "Package": g.queryPkgName, - "ImportPkgPaths": importList.Add(g.importPkgPaths...).Output(), + "ImportPkgPaths": importList.Add(g.importPkgPaths...).Paths(), }) if err != nil { return err @@ -339,7 +339,7 @@ func (g *Generator) generateQueryFile() (err error) { err = render(tmpl.Header, &buf, map[string]interface{}{ "Package": g.queryPkgName, - "ImportPkgPaths": unitTestImportList.Add(g.importPkgPaths...).Output(), + "ImportPkgPaths": unitTestImportList.Add(g.importPkgPaths...).Paths(), }) if err != nil { g.db.Logger.Error(context.Background(), "generate query unit test fail: %s", err) @@ -376,7 +376,7 @@ func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) { } err = render(tmpl.Header, &buf, map[string]interface{}{ "Package": g.queryPkgName, - "ImportPkgPaths": importList.Add(structPkgPath).Add(getImportPkgPaths(data)...).Output(), + "ImportPkgPaths": importList.Add(structPkgPath).Add(getImportPkgPaths(data)...).Paths(), }) if err != nil { return err @@ -426,7 +426,7 @@ func (g *Generator) generateQueryUnitTestFile(data *genInfo) (err error) { } err = render(tmpl.Header, &buf, map[string]interface{}{ "Package": g.queryPkgName, - "ImportPkgPaths": unitTestImportList.Add(structPkgPath).Add(data.ImportPkgPaths...).Output(), + "ImportPkgPaths": unitTestImportList.Add(structPkgPath).Add(data.ImportPkgPaths...).Paths(), }) if err != nil { return err @@ -474,7 +474,7 @@ func (g *Generator) generateModelFile() error { defer pool.Done() var buf bytes.Buffer - err = render(tmpl.Model, &buf, data) + err := render(tmpl.Model, &buf, data) if err != nil { errChan <- err return diff --git a/import.go b/import.go index f2cabd39..7ee4a5c0 100644 --- a/import.go +++ b/import.go @@ -3,7 +3,7 @@ package gen import "strings" var ( - importList = importPkgS{}.Add( + importList = new(importPkgS).Add( "context", "database/sql", "strings", @@ -18,7 +18,7 @@ var ( "", "gorm.io/plugin/dbresolver", ) - unitTestImportList = importPkgS{}.Add( + unitTestImportList = new(importPkgS).Add( "context", "fmt", "strconv", @@ -29,18 +29,23 @@ var ( ) ) -type importPkgS struct{ paths []string } +type importPkgS struct { + paths []string +} func (ip importPkgS) Add(paths ...string) *importPkgS { + purePaths := make([]string, 0, len(paths)+1) for _, p := range paths { p = strings.TrimSpace(p) if p == "" { - ip.paths = append(ip.paths, p) + purePaths = append(purePaths, p) continue } + if p[len(p)-1] != '"' { p = `"` + p + `"` } + var exists bool for _, existsP := range ip.paths { if p == existsP { @@ -49,11 +54,14 @@ func (ip importPkgS) Add(paths ...string) *importPkgS { } } if !exists { - ip.paths = append(ip.paths, p) + purePaths = append(purePaths, p) } } - ip.paths = append(ip.paths, "") + purePaths = append(purePaths, "") + + ip.paths = append(ip.paths, purePaths...) + return &ip } -func (ip *importPkgS) Output() []string { return ip.paths } +func (ip importPkgS) Paths() []string { return ip.paths }