diff --git a/magefile.go b/magefile.go index 5ede255..36d4b7b 100644 --- a/magefile.go +++ b/magefile.go @@ -15,7 +15,7 @@ import ( // Default target to run when none is specified // If not set, running mage will list available targets -// var Default = Build +var Default = Build // A build step that requires additional params, or platform specific steps for example func Build() error { diff --git a/templates/model.tmpl b/templates/model.tmpl index 90fa6e5..e491e76 100644 --- a/templates/model.tmpl +++ b/templates/model.tmpl @@ -223,29 +223,42 @@ func {{ $struct }}By{{ range $i, $col := $key.Columns }}{{ $col.Name | structify // GetAll{{ $struct }}s retrieves all rows from '{{ .Name }}' as a slice of {{ $struct }}. // // Generated from table '{{ .Name }}'. -func GetAll{{ $struct }}s(db DB, wheres ...patcher.Wherer) ([]*{{ $struct }}, error) { +func GetAll{{ $struct }}s(db DB, filters ...any) ([]*{{ $struct }}, error) { t := prometheus.NewTimer(DatabaseLatency.WithLabelValues("get_all_" + {{ $struct | structify }}TableName)) defer t.ObserveDuration() args := make([]any, 0) builder := new(strings.Builder) - builder.WriteString("SELECT {{ range $i, $column := $.Table.Columns }}{{ if $i }}, {{ end }}`{{ $column.Name }}`{{ end }}") - builder.WriteString(" FROM {{ $.Table.Name }} t") - - if len(wheres) > 0 { - builder.WriteString(" WHERE ") - for i, where := range wheres { - if i > 0 { - wtStr := patcher.WhereTypeAnd // default to AND - wt, ok := where.(patcher.WhereTyper) - if ok && wt.WhereType().IsValid() { - wtStr = wt.WhereType() - } - builder.WriteString(string(wtStr) + " ") + builder.WriteString("SELECT {{ range $i, $column := $.Table.Columns }}{{ if $i }}, {{ end }}`t.{{ $column.Name }}`{{ end }}") + + if len(filters) > 0 { + for _, filter := range filters { + if joiner := filter.(patcher.Joiner); joiner != nil { + joinSql, joinArgs := joiner.Join() + builder.WriteString(joinSql) + args = append(args, joinArgs...) + } + } + } + + builder.WriteString("\nFROM {{ $.Table.Name }} t") + + if len(filters) > 0 { + builder.WriteString("\nWHERE\n") + for i, filter := range filters { + if where := filter.(patcher.Wherer); where != nil { + if i > 0 { + wtStr := patcher.WhereTypeAnd + if wt, ok := filter.(patcher.WhereTyper); ok { + wtStr = wt.WhereType() + } + builder.WriteString(string(" " + wtStr + " ")) + } + whereSql, whereArgs := where.Where() + builder.WriteString(whereSql) + builder.WriteString("\n") + args = append(args, whereArgs...) } - whereStr, whereArgs := where.Where() - builder.WriteString(whereStr) - args = append(args, whereArgs...) } }