diff --git a/internal/generate/clause_test.go b/internal/generate/clause_test.go index c1258313..2eb3236b 100644 --- a/internal/generate/clause_test.go +++ b/internal/generate/clause_test.go @@ -66,8 +66,8 @@ func TestClause(t *testing.T) { GenerateResult: []string{ "generateSQL.WriteString(\"select * from users \")", "var whereSQL0 strings.Builder", - "params[\"id\"] = id", - "whereSQL0.WriteString(\"id>@id \")", + "params = append(params,id)", + "whereSQL0.WriteString(\"id>? \")", "helper.JoinWhereBuilder(&generateSQL,whereSQL0)", }, }, @@ -87,8 +87,8 @@ func TestClause(t *testing.T) { "generateSQL.WriteString(\"select * from users \")", "var whereSQL0 strings.Builder", "if id > 0 {", - "params[\"id\"] = id", - "whereSQL0.WriteString(\"id>@id \")", + "params = append(params,id)", + "whereSQL0.WriteString(\"id>? \")", "}", "helper.JoinWhereBuilder(&generateSQL,whereSQL0)", }, @@ -116,17 +116,17 @@ func TestClause(t *testing.T) { "generateSQL.WriteString(\"update users \")", "var setSQL0 strings.Builder", "if name != \"\" {", - "params[\"name\"] = name", - "setSQL0.WriteString(\"name=@name \")", + "params = append(params,name)", + "setSQL0.WriteString(\"name=? \")", "}", "setSQL0.WriteString(\", \")", "if id>0 {", - "params[\"id\"] = id", - "setSQL0.WriteString(\"id=@id \")", + "params = append(params,id)", + "setSQL0.WriteString(\"id=? \")", "}", "helper.JoinSetBuilder(&generateSQL,setSQL0)", - "params[\"id\"] = id", - "generateSQL.WriteString(\"where id=@id \")", + "params = append(params,id)", + "generateSQL.WriteString(\"where id=? \")", }, }, { @@ -135,7 +135,7 @@ func TestClause(t *testing.T) { "\"select * from \"", "\"users\"", "where", - "for _index, name := range names", + "for _, name := range names", "\"name=\"", "name", "end", @@ -144,9 +144,9 @@ func TestClause(t *testing.T) { GenerateResult: []string{ "generateSQL.WriteString(\"select * from users \")", "var whereSQL0 strings.Builder", - "for _index, name := range names{", - "params[\"nameForWhereSQL0_\"+strconv.Itoa(_index)]=name", - "whereSQL0.WriteString(\"name=@nameForWhereSQL0_\"+strconv.Itoa(_index)+\" \")", + "for _, name := range names{", + "params = append(params,name)", + "whereSQL0.WriteString(\"name=? \")", "}", "helper.JoinWhereBuilder(&generateSQL,whereSQL0)", }, diff --git a/internal/generate/interface.go b/internal/generate/interface.go index 09e09ea4..43eeef86 100644 --- a/internal/generate/interface.go +++ b/internal/generate/interface.go @@ -145,8 +145,11 @@ func (m *InterfaceMethod) checkParams(params []parser.Param) (err error) { switch { case param.Package == "UNDEFINED": param.Package = m.Package - case param.IsMap() || param.IsGenM() || param.IsError() || param.IsNull(): + case param.IsError() || param.IsNull(): return fmt.Errorf("type error on interface [%s] param: [%s]", m.InterfaceName, param.Name) + case param.IsGenM(): + param.Type = "map[string]interface{}" + param.Package = "" case param.IsGenT(): param.Type = m.OriginStruct.Type param.Package = m.OriginStruct.Package @@ -185,7 +188,6 @@ func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) { param.SetName("result") param.Type = m.OriginStruct.Type param.Package = m.OriginStruct.Package - param.IsPointer = true m.ResultData = param case param.IsInterface(): return fmt.Errorf("query method can not return interface in [%s.%s]", m.InterfaceName, m.MethodName) diff --git a/internal/generate/section.go b/internal/generate/section.go index 56d6ab47..d01b4233 100644 --- a/internal/generate/section.go +++ b/internal/generate/section.go @@ -60,19 +60,6 @@ func (s *Section) appendTmpl(value string) { s.Tmpls = append(s.Tmpls, value) } -func (s *Section) isInForValue(value string) (ForRange, bool) { - valueList := strings.Split(value, ".") - for _, v := range s.forValue { - if v.value == valueList[0] { - if len(valueList) > 1 { - v.suffix = "." + strings.Join(valueList[1:], ".") - } - return v, true - } - } - return ForRange{}, false -} - func (s *Section) hasSameName(value string) bool { for _, p := range s.members { if p.Type == model.FOR && p.ForRange.value == value { @@ -438,15 +425,8 @@ func (s *Section) parseSQL(name string) (res SQLClause) { case model.VARIABLE: res.Value = append(res.Value, c.Value) case model.DATA: - forRange, isInForRange := s.isInForValue(c.Value) - if isInForRange { - s.appendTmpl(forRange.appendDataToParams(c.Value, name)) - c.Value = forRange.DataValue(c.Value, name) - } else { - s.appendTmpl(c.AddDataToParamMap()) - c.Value = strconv.Quote("@" + c.SQLParamName()) - } - res.Value = append(res.Value, c.Value) + s.appendTmpl(fmt.Sprintf("params = append(params,%s)", c.Value)) + res.Value = append(res.Value, "\"?\"") default: s.SubIndex() return @@ -460,28 +440,24 @@ func (s *Section) parseSQL(name string) (res SQLClause) { // checkSQLVar check sql variable by for loops value and external params func (s *Section) checkSQLVar(param string, status model.Status, method *InterfaceMethod) (result section, err error) { - paramName := strings.Split(param, ".")[0] - for index, part := range s.members { - if part.Type == model.FOR && part.ForRange.value == paramName { - switch status { - case model.DATA: - method.HasForParams = true - if part.ForRange.index == "_" { - s.members[index].SetForRangeKey("_index") - } - case model.VARIABLE: - param = fmt.Sprintf("%s.Quote(%s)", method.S, param) - } - result = section{ - Type: status, - Value: param, - } - return + if status == model.VARIABLE && param == "table" { + result = section{ + Type: model.SQL, + Value: strconv.Quote(method.Table), } - + return } - - return method.checkSQLVarByParams(param, status) + if status == model.DATA { + method.HasForParams = true + } + if status == model.VARIABLE { + param = fmt.Sprintf("%s.Quote(%s)", method.S, param) + } + result = section{ + Type: status, + Value: param, + } + return } // GetName ... @@ -581,15 +557,6 @@ func (s *section) sectionType(str string) error { return nil } -func (s *section) SetForRangeKey(key string) { - s.ForRange.index = key - s.Value = s.String() -} - -func (s *section) AddDataToParamMap() string { - return fmt.Sprintf("params[%q] = %s", s.SQLParamName(), s.Value) -} - func (s *section) SQLParamName() string { return strings.Replace(s.Value, ".", "", -1) } @@ -605,16 +572,3 @@ type ForRange struct { func (f *ForRange) String() string { return fmt.Sprintf("for %s, %s := range %s", f.index, f.value, f.rangeList) } - -func (f *ForRange) mapIndexName(prefix, dataName, clauseName string) string { - return fmt.Sprintf("\"%s%sFor%s_\"+strconv.Itoa(%s)", prefix, strings.Replace(dataName, ".", "", -1), strings.Title(clauseName), f.index) -} - -// DataValue return data value -func (f *ForRange) DataValue(dataName, clauseName string) string { - return f.mapIndexName("@", dataName, clauseName) -} - -func (f *ForRange) appendDataToParams(dataName, clauseName string) string { - return fmt.Sprintf("params[%s]=%s%s", f.mapIndexName("", dataName, clauseName), f.value, f.suffix) -} diff --git a/internal/parser/export.go b/internal/parser/export.go index 9d99d4c0..4753484a 100644 --- a/internal/parser/export.go +++ b/internal/parser/export.go @@ -36,18 +36,16 @@ func GetInterfacePath(v interface{}) (paths []*InterfacePath, err error) { path.Name = n } + ctx := build.Default + var p *build.Package + if strings.Split(arg.String(), ".")[0] == "main" { - _, file, _, ok := runtime.Caller(3) - if ok { - path.Files = append(path.Files, file) - } - paths = append(paths, &path) - continue + _, file, _, _ := runtime.Caller(3) + p, err = ctx.ImportDir(filepath.Dir(file), build.ImportComment) + } else { + p, err = ctx.Import(arg.PkgPath(), "", build.ImportComment) } - ctx := build.Default - var p *build.Package - p, err = ctx.Import(arg.PkgPath(), "", build.ImportComment) if err != nil { return } diff --git a/internal/template/method.go b/internal/template/method.go index 1feddbfe..d99629c1 100644 --- a/internal/template/method.go +++ b/internal/template/method.go @@ -5,7 +5,7 @@ const DIYMethod = ` //{{.DocComment }} func ({{.S}} {{.TargetStruct}}Do){{.FuncSign}}{ - {{if .HasSQLData}}params :=make(map[string]interface{},0) + {{if .HasSQLData}}var params []interface{} {{end}}var generateSQL strings.Builder {{range $line:=.Section.Tmpls}}{{$line}} @@ -13,11 +13,8 @@ func ({{.S}} {{.TargetStruct}}Do){{.FuncSign}}{ {{if .HasNeedNewResult}}result ={{if .ResultData.IsMap}}make{{else}}new{{end}}({{if ne .ResultData.Package ""}}{{.ResultData.Package}}.{{end}}{{.ResultData.Type}}){{end}} {{if or .ReturnRowsAffected .ReturnError}}var executeSQL *gorm.DB - {{end}}{{if .HasSQLData}}if len(params)>0{ - {{if or .ReturnRowsAffected .ReturnError}}executeSQL{{else}}_{{end}}= {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String(){{if .HasSQLData}},params{{end}}){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}} - }else{ - {{if or .ReturnRowsAffected .ReturnError}}executeSQL{{else}}_{{end}}= {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String()){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}} - }{{else}}{{if or .ReturnRowsAffected .ReturnError}}executeSQL{{else}}_{{end}}= {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String()){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}}{{end}} + {{end}} + {{if or .ReturnRowsAffected .ReturnError}}executeSQL{{else}}_{{end}} = {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String(){{if .HasSQLData}},params...{{end}}){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}} {{if .ReturnRowsAffected}}rowsAffected = executeSQL.RowsAffected {{end}}{{if .ReturnError}}err = executeSQL.Error {{end}}return