diff --git a/internal/sqladapter/result.go b/internal/sqladapter/result.go index 1b25e2ea..1c30c357 100644 --- a/internal/sqladapter/result.go +++ b/internal/sqladapter/result.go @@ -213,7 +213,7 @@ func (r *Result) Select(fields ...interface{}) db.Result { // String satisfies fmt.Stringer func (r *Result) String() string { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { panic(err.Error()) } @@ -222,7 +222,7 @@ func (r *Result) String() string { // All dumps all Results into a pointer to an slice of structs or maps. func (r *Result) All(dst interface{}) error { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { r.setErr(err) return err @@ -235,7 +235,7 @@ func (r *Result) All(dst interface{}) error { // One fetches only one Result from the set. func (r *Result) One(dst interface{}) error { one := r.Limit(1).(*Result) - query, err := one.buildPaginator() + query, err := one.Paginator() if err != nil { r.setErr(err) return err @@ -251,7 +251,7 @@ func (r *Result) Next(dst interface{}) bool { defer r.iterMu.Unlock() if r.iter == nil { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { r.setErr(err) return false @@ -309,7 +309,7 @@ func (r *Result) Update(values interface{}) error { } func (r *Result) TotalPages() (uint, error) { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { r.setErr(err) return 0, err @@ -325,7 +325,7 @@ func (r *Result) TotalPages() (uint, error) { } func (r *Result) TotalEntries() (uint64, error) { - query, err := r.buildPaginator() + query, err := r.Paginator() if err != nil { r.setErr(err) return 0, err @@ -391,7 +391,7 @@ func (r *Result) Count() (uint64, error) { return counter.Count, nil } -func (r *Result) buildPaginator() (db.Paginator, error) { +func (r *Result) Paginator() (db.Paginator, error) { if err := r.Err(); err != nil { return nil, err } diff --git a/internal/sqlbuilder/builder.go b/internal/sqlbuilder/builder.go index da7ecb3c..6443d76d 100644 --- a/internal/sqlbuilder/builder.go +++ b/internal/sqlbuilder/builder.go @@ -51,7 +51,11 @@ var defaultMapOptions = MapOptions{ IncludeNil: false, } -type compilable interface { +type hasPaginator interface { + Paginator() (db.Paginator, error) +} + +type isCompilable interface { Compile() (string, error) Arguments() []interface{} } @@ -347,7 +351,17 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err for i := range columns { switch v := columns[i].(type) { - case compilable: + case hasPaginator: + p, err := v.Paginator() + if err != nil { + return nil, nil, err + } + + q, a := Preprocess(p.String(), p.Arguments()) + + f[i] = exql.RawValue("(" + q + ")") + args = append(args, a...) + case isCompilable: c, err := v.Compile() if err != nil { return nil, nil, err diff --git a/internal/sqlbuilder/convert.go b/internal/sqlbuilder/convert.go index 34d0c59a..50f344c9 100644 --- a/internal/sqlbuilder/convert.go +++ b/internal/sqlbuilder/convert.go @@ -122,7 +122,13 @@ func preprocessFn(arg interface{}) (string, []interface{}) { switch t := arg.(type) { case *adapter.RawExpr: return Preprocess(t.Raw(), t.Arguments()) - case compilable: + case hasPaginator: + p, err := t.Paginator() + if err == nil { + return `(` + p.String() + `)`, p.Arguments() + } + panic(err.Error()) + case isCompilable: c, err := t.Compile() if err == nil { return `(` + c + `)`, t.Arguments() diff --git a/internal/testsuite/sql_suite.go b/internal/testsuite/sql_suite.go index 407e484b..011e2df8 100644 --- a/internal/testsuite/sql_suite.go +++ b/internal/testsuite/sql_suite.go @@ -1898,3 +1898,34 @@ func (s *SQLTestSuite) Test_Issue565() { s.Error(err) s.Zero(result.Name) } + +func (s *SQLTestSuite) TestSelectFromSubquery() { + sess := s.Session() + + { + var artists []artistType + q := sess.SQL().SelectFrom( + sess.SQL().SelectFrom("artist").Where(db.Cond{ + "name": db.IsNotNull(), + }), + ).As("_q") + err := q.All(&artists) + s.NoError(err) + + s.NotZero(len(artists)) + } + + { + var artists []artistType + q := sess.SQL().SelectFrom( + sess.Collection("artist").Find(db.Cond{ + "name": db.IsNotNull(), + }), + ).As("_q") + err := q.All(&artists) + s.NoError(err) + + s.NotZero(len(artists)) + } + +}