Skip to content

Commit

Permalink
support "range over function" in Go 1.22
Browse files Browse the repository at this point in the history
  • Loading branch information
lqs committed May 30, 2024
1 parent 33af741 commit 9acd4cc
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
5 changes: 5 additions & 0 deletions cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ import (
"time"
)

// Scanner is the interface that wraps the Scan method.
type Scanner interface {
Scan(dest ...interface{}) error
}

// Cursor is the interface of a row cursor.
type Cursor interface {
Next() bool
Expand Down
26 changes: 26 additions & 0 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ type toSelectFinal interface {
FetchExactlyOne(out ...interface{}) error
FetchAll(dest ...interface{}) (rows int, err error)
FetchCursor() (Cursor, error)
FetchSeq() func(yield func(row Scanner) bool) // use with "range over function" in Go 1.22
}

type join struct {
Expand Down Expand Up @@ -164,6 +165,31 @@ type selectStatus struct {
lock string
}

type errorScanner struct {
err error
}

func (e errorScanner) Scan(dest ...interface{}) error {
return e.err
}

func (s selectStatus) FetchSeq() func(yield func(row Scanner) bool) {
return func(yield func(row Scanner) bool) {
cursor, err := s.FetchCursor()
if err != nil {
yield(errorScanner{err})
return
}

defer cursor.Close()
for cursor.Next() {
if !yield(cursor) {
break
}
}
}
}

type unionSelectStatus struct {
base selectBase
all bool
Expand Down
28 changes: 28 additions & 0 deletions select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,34 @@ func TestFetchAll(t *testing.T) {
}
}

func TestRangeFunc(t *testing.T) {
db := newMockDatabase()

oldColumnCount := sharedMockConn.columnCount
sharedMockConn.columnCount = 2
defer func() {
sharedMockConn.columnCount = oldColumnCount
}()

count := 0
seq := db.Select(field1, field2).From(Table1).FetchSeq()

// for row := range db.Select(field1, field2).From(Table1).FetchSeq() {}
seq(func(row Scanner) bool {
var f1 string
var f2 int
if err := row.Scan(&f1, &f2); err != nil {
t.Error(err)
}
count++
return true
})

if count != 10 {
t.Error(count)
}
}

func TestLock(t *testing.T) {
db := newMockDatabase()
table1 := NewTable("table1")
Expand Down

0 comments on commit 9acd4cc

Please sign in to comment.