Skip to content

Commit

Permalink
Fix Select subquery with DollarPlaceholder (#298)
Browse files Browse the repository at this point in the history
Fixes #286
  • Loading branch information
lann authored Oct 13, 2021
1 parent cd1fe0a commit 84ae2bc
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 12 deletions.
2 changes: 1 addition & 1 deletion case.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (b *sqlizerBuffer) WriteSql(item Sqlizer) {

var str string
var args []interface{}
str, args, b.err = item.ToSql()
str, args, b.err = nestedToSql(item)

if b.err != nil {
return
Expand Down
10 changes: 9 additions & 1 deletion part.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,17 @@ func (p part) ToSql() (sql string, args []interface{}, err error) {
return
}

func nestedToSql(s Sqlizer) (string, []interface{}, error) {
if raw, ok := s.(rawSqlizer); ok {
return raw.toSqlRaw()
} else {
return s.ToSql()
}
}

func appendToSql(parts []Sqlizer, w io.Writer, sep string, args []interface{}) ([]interface{}, error) {
for i, p := range parts {
partSql, partArgs, err := p.ToSql()
partSql, partArgs, err := nestedToSql(p)
if err != nil {
return nil, err
} else if len(partSql) == 0 {
Expand Down
16 changes: 6 additions & 10 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func (d *selectData) QueryRow() RowScanner {
}

func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) {
sqlStr, args, err = d.toSql()
sqlStr, args, err = d.toSqlRaw()
if err != nil {
return
}
Expand All @@ -62,10 +62,6 @@ func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) {
}

func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
return d.toSql()
}

func (d *selectData) toSql() (sqlStr string, args []interface{}, err error) {
if len(d.Columns) == 0 {
err = fmt.Errorf("select statements must have at least one result column")
return
Expand Down Expand Up @@ -222,6 +218,11 @@ func (b SelectBuilder) ToSql() (string, []interface{}, error) {
return data.ToSql()
}

func (b SelectBuilder) toSqlRaw() (string, []interface{}, error) {
data := builder.GetStruct(b).(selectData)
return data.toSqlRaw()
}

// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b SelectBuilder) MustSql() (string, []interface{}) {
Expand All @@ -232,11 +233,6 @@ func (b SelectBuilder) MustSql() (string, []interface{}) {
return sql, args
}

func (b SelectBuilder) toSqlRaw() (string, []interface{}, error) {
data := builder.GetStruct(b).(selectData)
return data.toSqlRaw()
}

// Prefix adds an expression to the beginning of the query
func (b SelectBuilder) Prefix(sql string, args ...interface{}) SelectBuilder {
return b.PrefixExpr(Expr(sql, args...))
Expand Down
17 changes: 17 additions & 0 deletions select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,23 @@ func TestSelectWithEmptyStringWhereClause(t *testing.T) {
assert.Equal(t, "SELECT * FROM users", sql)
}

func TestSelectSubqueryPlaceholderNumbering(t *testing.T) {
subquery := Select("a").Where("b = ?", 1).PlaceholderFormat(Dollar)
with := subquery.Prefix("WITH a AS (").Suffix(")")

sql, args, err := Select("*").
PrefixExpr(with).
FromSelect(subquery, "q").
Where("c = ?", 2).
PlaceholderFormat(Dollar).
ToSql()
assert.NoError(t, err)

expectedSql := "WITH a AS ( SELECT a WHERE b = $1 ) SELECT * FROM (SELECT a WHERE b = $2) AS q WHERE c = $3"
assert.Equal(t, expectedSql, sql)
assert.Equal(t, []interface{}{1, 1, 2}, args)
}

func ExampleSelect() {
Select("id", "created", "first_name").From("users") // ... continue building up your query

Expand Down

0 comments on commit 84ae2bc

Please sign in to comment.