diff --git a/chotki.go b/chotki.go index 7c2d37a..0ea2b70 100644 --- a/chotki.go +++ b/chotki.go @@ -11,8 +11,8 @@ import ( "github.com/cockroachdb/pebble" "github.com/cockroachdb/pebble/vfs" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" ) type Packet []byte diff --git a/chotki_test.go b/chotki_test.go index 6d83eff..7d7e23b 100644 --- a/chotki_test.go +++ b/chotki_test.go @@ -7,7 +7,7 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" "github.com/stretchr/testify/assert" ) diff --git a/counter.go b/counter.go index 1ff6326..dd48629 100644 --- a/counter.go +++ b/counter.go @@ -2,7 +2,7 @@ package chotki import ( "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) type Counter64 int64 diff --git a/counter_test.go b/counter_test.go index 3933578..fc7460a 100644 --- a/counter_test.go +++ b/counter_test.go @@ -2,7 +2,7 @@ package chotki import ( "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" "github.com/stretchr/testify/assert" "testing" ) diff --git a/examples/object_example.go b/examples/object_example.go index 95890ed..28268b4 100644 --- a/examples/object_example.go +++ b/examples/object_example.go @@ -4,7 +4,7 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) const ExampleName = 1 diff --git a/examples/object_example_test.go b/examples/object_example_test.go index f031923..9bdfb16 100644 --- a/examples/object_example_test.go +++ b/examples/object_example_test.go @@ -7,7 +7,7 @@ import ( "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" "github.com/stretchr/testify/assert" ) diff --git a/examples/objects_test.go b/examples/objects_test.go index c37dd63..7766d30 100644 --- a/examples/objects_test.go +++ b/examples/objects_test.go @@ -6,7 +6,7 @@ import ( "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" "github.com/stretchr/testify/assert" ) diff --git a/go.mod b/go.mod index efc1913..ef675a9 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,6 @@ go 1.21.4 require ( github.com/cockroachdb/pebble v1.1.0 github.com/ergochat/readline v0.1.0 - github.com/learn-decentralized-systems/toyqueue v0.1.5 - github.com/learn-decentralized-systems/toytlv v0.2.1 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.8.4 ) diff --git a/go.sum b/go.sum index 2149c7b..b0eda58 100644 --- a/go.sum +++ b/go.sum @@ -177,10 +177,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/learn-decentralized-systems/toyqueue v0.1.5 h1:X2EQEWj2dyaE5BUkE58aXsMG7mq8Uv6CpiFCErAnCMQ= -github.com/learn-decentralized-systems/toyqueue v0.1.5/go.mod h1:T5PrFDCcxA1O7hb2MAlHYYFA89ry8hvXUuwg+drS1UQ= 
-github.com/learn-decentralized-systems/toytlv v0.2.1 h1:nk+gjjE9JZ659kkbxIlv/H/gF5Wtst1Dbn7KckqdFOQ= -github.com/learn-decentralized-systems/toytlv v0.2.1/go.mod h1:+xzKS/La5vCkdyIdOFDb2NVPGF808tG5n5b3Ufxkorg= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= diff --git a/log0.go b/log0.go index 39e5a4b..fdcb799 100644 --- a/log0.go +++ b/log0.go @@ -2,8 +2,8 @@ package chotki import ( "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" ) const id1 = rdx.ID0 + rdx.ProInc diff --git a/objects.go b/objects.go index 547138b..8fb3f78 100644 --- a/objects.go +++ b/objects.go @@ -4,8 +4,8 @@ import ( "fmt" "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" "github.com/pkg/errors" "unicode/utf8" ) diff --git a/op.go b/op.go index 6b2d0d1..d9286c0 100644 --- a/op.go +++ b/op.go @@ -2,7 +2,7 @@ package chotki import ( "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) func ParsePacket(pack []byte) (lit byte, id, ref rdx.ID, body []byte, err error) { diff --git a/orm.go b/orm.go index 46693ac..1360f9b 100644 --- a/orm.go +++ b/orm.go @@ -4,8 +4,8 @@ import ( "bytes" "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" "sync" "text/template" ) diff --git a/packets.go b/packets.go index 73fb6b8..d95c83c 100644 --- a/packets.go +++ b/packets.go @@ -5,7 +5,7 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) func (cho *Chotki) UpdateVTree(id, ref rdx.ID, pb *pebble.Batch) (err error) { diff --git a/rdx/ELM.go b/rdx/ELM.go index b7290be..d52ec17 100644 --- a/rdx/ELM.go +++ b/rdx/ELM.go @@ -3,7 +3,7 @@ package rdx import ( "bytes" "errors" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" "slices" "sort" ) diff --git a/rdx/ELM_test.go b/rdx/ELM_test.go index c1e74ad..080eef5 100644 --- a/rdx/ELM_test.go +++ b/rdx/ELM_test.go @@ -1,8 +1,8 @@ package rdx import ( - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" "github.com/stretchr/testify/assert" "testing" ) diff --git a/rdx/FIRST.go b/rdx/FIRST.go index 564b8aa..5ed972a 100644 --- a/rdx/FIRST.go +++ b/rdx/FIRST.go @@ -4,9 +4,9 @@ import ( "bytes" "errors" "fmt" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) // Common LWW functions diff --git a/rdx/FIRST_test.go b/rdx/FIRST_test.go index 6ea26ee..bd81c8a 100644 --- 
a/rdx/FIRST_test.go +++ b/rdx/FIRST_test.go @@ -3,7 +3,7 @@ package rdx import ( "testing" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" "github.com/stretchr/testify/assert" ) diff --git a/rdx/NZ.go b/rdx/NZ.go index 50bc30f..d28102b 100644 --- a/rdx/NZ.go +++ b/rdx/NZ.go @@ -2,7 +2,7 @@ package rdx import ( "fmt" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) // N is an increment-only uint64 counter diff --git a/rdx/NZ_test.go b/rdx/NZ_test.go index ade0446..a64d297 100644 --- a/rdx/NZ_test.go +++ b/rdx/NZ_test.go @@ -1,7 +1,7 @@ package rdx import ( - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" "github.com/stretchr/testify/assert" "testing" ) diff --git a/rdx/X.go b/rdx/X.go index d1aea10..1d29da9 100644 --- a/rdx/X.go +++ b/rdx/X.go @@ -2,7 +2,7 @@ package rdx import ( hex2 "encoding/hex" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) func Xparse(rdt byte, val string) (tlv []byte) { diff --git a/rdx/id.go b/rdx/id.go index 0d2c1a0..17fe639 100644 --- a/rdx/id.go +++ b/rdx/id.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) /* diff --git a/rdx/rdx.go b/rdx/rdx.go index 9bc9d4c..32647a9 100644 --- a/rdx/rdx.go +++ b/rdx/rdx.go @@ -2,7 +2,7 @@ package rdx import ( "errors" - "github.com/learn-decentralized-systems/toyqueue" + "github.com/drpcorg/chotki/toyqueue" ) const ( diff --git a/rdx/vv.go b/rdx/vv.go index 90cc716..3510911 100644 --- a/rdx/vv.go +++ b/rdx/vv.go @@ -4,7 +4,7 @@ import ( "errors" "slices" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toytlv" ) // VV is a version vector, max ids seen from each known replica. diff --git a/repl/commands.go b/repl/commands.go index 0267090..a8d2cea 100644 --- a/repl/commands.go +++ b/repl/commands.go @@ -9,8 +9,8 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" ) var HelpCreate = errors.New("create zone/1 {Name:\"Name\",Description:\"long text\"}") diff --git a/repl/repl.go b/repl/repl.go index bdce599..f9f2709 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -6,8 +6,8 @@ import ( "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki" "github.com/drpcorg/chotki/rdx" + "github.com/drpcorg/chotki/toytlv" "github.com/ergochat/readline" - "github.com/learn-decentralized-systems/toytlv" "io" "os" "strings" diff --git a/sync.go b/sync.go index c4b5973..467425b 100644 --- a/sync.go +++ b/sync.go @@ -5,8 +5,8 @@ import ( "fmt" "github.com/cockroachdb/pebble" "github.com/drpcorg/chotki/rdx" - "github.com/learn-decentralized-systems/toyqueue" - "github.com/learn-decentralized-systems/toytlv" + "github.com/drpcorg/chotki/toyqueue" + "github.com/drpcorg/chotki/toytlv" "io" "os" "sync" diff --git a/toyqueue/drainfeed.go b/toyqueue/drainfeed.go new file mode 100644 index 0000000..621ad12 --- /dev/null +++ b/toyqueue/drainfeed.go @@ -0,0 +1,65 @@ +package toyqueue + +import "io" + +// Records (a batch of) as a very universal primitive, especially +// for database/network op/packet processing. Batching allows +// for writev() and other performance optimizations. ALso, if +// you have cryptography, blobs are way handier than structs. 
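+// A minimal usage sketch (values are illustrative only): batch two
+// records, then use the queue.go helpers to measure and split them.
+//
+//	recs := Records{[]byte("op1"), []byte("op2")}
+//	total := recs.TotalLen()                // 6
+//	whole, rem := recs.WholeRecordPrefix(4) // ["op1"], 1 byte left over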
+// Records converts easily to net.Buffers. +type Records [][]byte + +type Feeder interface { + // Feed reads and returns records. + // The EoF convention follows that of io.Reader: + // can either return `records, EoF` or + // `records, nil` followed by `nil/{}, EoF` + Feed() (recs Records, err error) +} + +type FeedCloser interface { + Feeder + io.Closer +} + +type FeedSeeker interface { + Feeder + io.Seeker +} + +type FeedSeekCloser interface { + Feeder + io.Seeker + io.Closer +} + +type Drainer interface { + Drain(recs Records) error +} + +type DrainSeeker interface { + Drainer + io.Seeker +} + +type DrainCloser interface { + Drainer + io.Closer +} + +type DrainSeekCloser interface { + Drainer + io.Seeker + io.Closer +} + +type FeedDrainer interface { + Feeder + Drainer +} + +type FeedDrainCloser interface { + Feeder + Drainer + io.Closer +} diff --git a/toyqueue/queue.go b/toyqueue/queue.go new file mode 100644 index 0000000..7b95f64 --- /dev/null +++ b/toyqueue/queue.go @@ -0,0 +1,153 @@ +package toyqueue + +import ( + "errors" + "sync" +) + +func (recs Records) recrem(total int64) (prelen int, prerem int64) { + for len(recs) > prelen && int64(len(recs[prelen])) <= total { + total -= int64(len(recs[prelen])) + prelen++ + } + prerem = total + return +} + +func (recs Records) WholeRecordPrefix(limit int64) (prefix Records, remainder int64) { + prelen, remainder := recs.recrem(limit) + prefix = recs[:prelen] + return +} + +func (recs Records) ExactSuffix(total int64) (suffix Records) { + prelen, prerem := recs.recrem(total) + suffix = recs[prelen:] + if prerem != 0 { // damages the original, hence copy + edited := make(Records, 1, len(suffix)) + edited[0] = suffix[0][prerem:] + suffix = append(edited, suffix[1:]...) + } + return +} + +func (recs Records) TotalLen() (total int64) { + for _, r := range recs { + total += int64(len(r)) + } + return +} + +type RecordQueue struct { + recs Records + lock sync.Mutex + cond sync.Cond + Limit int +} + +var ErrWouldBlock = errors.New("the queue is over capacity") +var ErrClosed = errors.New("queue is closed") + +func (q *RecordQueue) Drain(recs Records) error { + q.lock.Lock() + was0 := len(q.recs) == 0 + if len(q.recs)+len(recs) > q.Limit { + q.lock.Unlock() + if q.Limit == 0 { + return ErrClosed + } + return ErrWouldBlock + } + q.recs = append(q.recs, recs...) + if was0 && q.cond.L != nil { + q.cond.Broadcast() + } + q.lock.Unlock() + return nil +} + +func (q *RecordQueue) Close() error { + q.Limit = 0 + return nil +} + +func (q *RecordQueue) Feed() (recs Records, err error) { + q.lock.Lock() + if len(q.recs) == 0 { + err = ErrWouldBlock + if q.Limit == 0 { + err = ErrClosed + } + q.lock.Unlock() + return + } + wasfull := len(q.recs) >= q.Limit + recs = q.recs + q.recs = q.recs[len(q.recs):] + if wasfull && q.cond.L != nil { + q.cond.Broadcast() + } + q.lock.Unlock() + return +} + +func (q *RecordQueue) Blocking() FeedDrainCloser { + if q.cond.L == nil { + q.cond.L = &q.lock + } + return &blockingRecordQueue{q} +} + +type blockingRecordQueue struct { + queue *RecordQueue +} + +func (bq *blockingRecordQueue) Close() error { + return bq.queue.Close() +} + +func (bq *blockingRecordQueue) Drain(recs Records) error { + q := bq.queue + q.lock.Lock() + for len(recs) > 0 { + was0 := len(q.recs) == 0 + for q.Limit <= len(q.recs) { + if q.Limit == 0 { + q.lock.Unlock() + return ErrClosed + } + q.cond.Wait() + } + qcap := q.Limit - len(q.recs) + if qcap > len(recs) { + qcap = len(recs) + } + q.recs = append(q.recs, recs[:qcap]...) 
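+		// only what fits was appended above; the rest is kept and the
+		// enclosing loop waits on the condvar until readers free space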
+ recs = recs[qcap:] + if was0 { + q.cond.Broadcast() + } + } + q.lock.Unlock() + return nil +} + +func (bq *blockingRecordQueue) Feed() (recs Records, err error) { + q := bq.queue + q.lock.Lock() + wasfull := len(q.recs) >= q.Limit + for len(q.recs) == 0 { + if q.Limit == 0 { + q.lock.Unlock() + return nil, ErrClosed + } + q.cond.Wait() + } + recs = q.recs + q.recs = q.recs[len(q.recs):] + if wasfull { + q.cond.Broadcast() + } + q.lock.Unlock() + return +} diff --git a/toyqueue/queue_test.go b/toyqueue/queue_test.go new file mode 100644 index 0000000..b9fdd17 --- /dev/null +++ b/toyqueue/queue_test.go @@ -0,0 +1,50 @@ +package toyqueue + +import ( + "encoding/binary" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestBlockingRecordQueue_Drain(t *testing.T) { + const N = 1 << 10 // 8K + const K = 1 << 4 // 16 + + orig := RecordQueue{Limit: 1024} + queue := orig.Blocking() + + for k := 0; k < K; k++ { + go func(k int) { + i := uint64(k) << 32 + for n := uint64(0); n < N; n++ { + var b [8]byte + binary.LittleEndian.PutUint64(b[:], i|n) + err := queue.Drain(Records{b[:]}) + assert.Nil(t, err) + } + }(k) + } + + check := [K]int{} + for i := uint64(0); i < N*K; { + nums, err := queue.Feed() + assert.Nil(t, err) + for _, num := range nums { + assert.Equal(t, 8, len(num)) + j := binary.LittleEndian.Uint64(num) + k := int(j >> 32) + n := int(j & 0xffffffff) + assert.Equal(t, check[k], n) + check[k] = n + 1 + i++ + } + } + + recs := [][]byte{{'a'}} + assert.Nil(t, queue.Close()) + err := queue.Drain(recs) + assert.Equal(t, ErrClosed, err) + _, err2 := queue.Feed() + assert.Equal(t, ErrClosed, err2) + +} diff --git a/toyqueue/util.go b/toyqueue/util.go new file mode 100644 index 0000000..728bbb6 --- /dev/null +++ b/toyqueue/util.go @@ -0,0 +1,46 @@ +package toyqueue + +func Relay(feeder Feeder, drainer Drainer) error { + recs, err := feeder.Feed() + if err != nil { + if len(recs) > 0 { + _ = drainer.Drain(recs) + } + return err + } + err = drainer.Drain(recs) + return err +} + +func Pump(feeder Feeder, drainer Drainer) (err error) { + for err == nil { + err = Relay(feeder, drainer) + } + return +} + +func PumpN(feeder Feeder, drainer Drainer, n int) (err error) { + for err == nil && n > 0 { + err = Relay(feeder, drainer) + n-- + } + return +} + +func PumpThenClose(feed FeedCloser, drain DrainCloser) error { + var ferr, derr error + for ferr == nil && derr == nil { + var recs Records + recs, ferr = feed.Feed() + if len(recs) > 0 { // e.g. Feed() may return data AND EOF + derr = drain.Drain(recs) + } + } + _ = feed.Close() + _ = drain.Close() + if ferr != nil { + return ferr + } else { + return derr + } +} diff --git a/toyqueue/util_test.go b/toyqueue/util_test.go new file mode 100644 index 0000000..8fb0705 --- /dev/null +++ b/toyqueue/util_test.go @@ -0,0 +1,54 @@ +package toyqueue + +import ( + "github.com/stretchr/testify/assert" + "io" + "testing" +) + +type sliceFeedDrainer struct { + data []byte + res []byte +} + +func (fd *sliceFeedDrainer) Close() error { + fd.res = append(fd.res, '(') + fd.res = append(fd.res, fd.data...) + fd.res = append(fd.res, ')') + return nil +} + +func (fd *sliceFeedDrainer) Drain(recs Records) error { + for _, rec := range recs { + fd.data = append(fd.data, rec...) 
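+		// records are simply concatenated; Feed re-chunks the buffer into
+		// single-byte records, at most three per call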
+ } + return nil +} + +func (fd *sliceFeedDrainer) Feed() (recs Records, err error) { + for i := 0; i < 3 && len(fd.data) > 0; i++ { + recs = append(recs, fd.data[0:1]) + fd.data = fd.data[1:] + } + if len(fd.data) == 0 { + err = io.EOF + } + return +} + +func TestPump(t *testing.T) { + sfd := sliceFeedDrainer{ + data: []byte("Hello world"), + } + err := PumpN(&sfd, &sfd, 2) + assert.Nil(t, err) + assert.Equal(t, sfd.data, []byte("worldHello ")) + + fro := sliceFeedDrainer{ + data: []byte("Hello world"), + } + to := sliceFeedDrainer{} + err = PumpThenClose(&fro, &to) + assert.Equal(t, err, io.EOF) + assert.Equal(t, []byte("(Hello world)"), to.res) +} diff --git a/toytlv/tcp.go b/toytlv/tcp.go new file mode 100644 index 0000000..87b8dd2 --- /dev/null +++ b/toytlv/tcp.go @@ -0,0 +1,375 @@ +package toytlv + +import ( + "io" + "log/slog" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/drpcorg/chotki/toyqueue" +) + +const ( + TYPICAL_MTU = 1500 + MAX_OUT_QUEUE_LEN = 1 << 20 // 16MB of pointers is a lot + + MAX_RETRY_PERIOD = time.Minute + MIN_RETRY_PERIOD = time.Second / 2 +) + +type TCPConn struct { + addr string + conn atomic.Pointer[net.Conn] + inout toyqueue.FeedDrainCloser + + wake sync.Cond + outmx sync.Mutex + + Reconnect bool + KeepAlive bool +} + +func (tcp *TCPConn) doRead() { + err := tcp.read() + if err != nil && err != ErrDisconnected { + // TODO: error handling + slog.Error("couldn't read from conn", "err", err) + } +} + +func (tcp *TCPConn) read() (err error) { + var buf []byte + for { + conn := tcp.conn.Load() + if conn == nil { + break + } + + buf, err = AppendRead(buf, *conn, TYPICAL_MTU) + if err != nil { + break + } + var recs toyqueue.Records + recs, buf, err = Split(buf) + if len(recs) == 0 { + time.Sleep(time.Millisecond) + continue + } + if err != nil { + break + } + + err = tcp.inout.Drain(recs) + if err != nil { + break + } + } + + if err != nil { + // TODO: error handling + slog.Error("couldn't read from conn", "err", err) + tcp.Close() + } + return +} + +func (tcp *TCPConn) doWrite() { + var err error + var recs toyqueue.Records + for err == nil { + conn := tcp.conn.Load() + if conn == nil { + break + } + + recs, err = tcp.inout.Feed() + b := net.Buffers(recs) + for len(b) > 0 && err == nil { + _, err = b.WriteTo(*conn) + } + } + if err != nil { + tcp.Close() // TODO err + } +} + +// Write what we believe is a valid ToyTLV frame. 
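+// For example (sketch only), tcp.Write(Record('M', []byte("hi"))) queues one 'M' record.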
+// Provided for io.Writer compatibility +func (tcp *TCPConn) Write(data []byte) (n int, err error) { + err = tcp.Drain(toyqueue.Records{data}) + if err == nil { + n = len(data) + } + return +} + +func (tcp *TCPConn) Drain(recs toyqueue.Records) (err error) { + return tcp.inout.Drain(recs) +} + +func (tcp *TCPConn) Feed() (recs toyqueue.Records, err error) { + return tcp.inout.Feed() +} + +func (tcp *TCPConn) KeepTalking() { + talk_backoff := MIN_RETRY_PERIOD + conn_backoff := MIN_RETRY_PERIOD + + for { + conntime := time.Now() + go tcp.doWrite() + + // TODO: error handling + err := tcp.read() + slog.Error("couldn't read from conn", "err", err) + + if !tcp.Reconnect { + break + } + + atLeast5min := conntime.Add(time.Minute * 5) + if atLeast5min.After(time.Now()) { + talk_backoff *= 2 // connected, tried to talk, failed => wait more + if talk_backoff > MAX_RETRY_PERIOD { + talk_backoff = MAX_RETRY_PERIOD + } + } + + for { + if conn := tcp.conn.Load(); conn == nil { + break + } + + time.Sleep(conn_backoff + talk_backoff) + conn, err := net.Dial("tcp", tcp.addr) + if err != nil { + conn_backoff = conn_backoff * 2 + if conn_backoff > MAX_RETRY_PERIOD/2 { + conn_backoff = MAX_RETRY_PERIOD + } + } else { + tcp.conn.Store(&conn) + conn_backoff = MIN_RETRY_PERIOD + } + } + } +} + +func (tcp *TCPConn) Close() error { + tcp.outmx.Lock() + defer tcp.outmx.Unlock() + + // TODO writer closes on complete | 1 sec expired + if conn := tcp.conn.Swap(nil); conn != nil { + if err := (*conn).Close(); err != nil { + return err + } + + tcp.wake.Broadcast() + } + + return nil +} + +type Jack func(conn net.Conn) toyqueue.FeedDrainCloser + +// A TCP server/client for the use case of real-time async communication. +// Differently from the case of request-response (like HTTP), we do not +// wait for a request, then dedicating a thread to processing, then sending +// back the resulting response. Instead, we constantly fan sendQueue tons of +// tiny messages. That dictates different work patterns than your typical +// HTTP/RPC server as, for example, we cannot let one slow receiver delay +// event transmission to all the other receivers. +type TCPDepot struct { + conns map[string]*TCPConn + listens map[string]net.Listener + conmx sync.Mutex + jack Jack +} + +func (de *TCPDepot) Open(jack Jack) { + de.conmx.Lock() + de.conns = make(map[string]*TCPConn) + de.listens = make(map[string]net.Listener) + de.conmx.Unlock() + de.jack = jack +} + +func (de *TCPDepot) Close() error { + de.conmx.Lock() + defer de.conmx.Unlock() + + for _, lstn := range de.listens { + if err := lstn.Close(); err != nil { + return err + } + } + clear(de.listens) + + for _, con := range de.conns { + if err := con.Close(); err != nil { + return err + } + } + clear(de.conns) + + return nil +} + +// attrib?! 
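+// A minimal usage sketch (addresses and the queue limit are illustrative);
+// the jack decides which queue feeds and drains each connection.
+//
+//	var depot TCPDepot
+//	depot.Open(func(conn net.Conn) toyqueue.FeedDrainCloser {
+//		q := &toyqueue.RecordQueue{Limit: 1024}
+//		return q.Blocking()
+//	})
+//	_ = depot.Listen("127.0.0.1:12345")
+//	_ = depot.Connect("127.0.0.1:12345")
+//	_ = depot.DrainTo(toyqueue.Records{Record('M', []byte("hi"))}, "127.0.0.1:12345")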
+func (de *TCPDepot) Connect(addr string) (err error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return err + } + peer := TCPConn{ + addr: addr, + inout: de.jack(conn), + } + peer.wake.L = &peer.outmx + peer.conn.Store(&conn) + + de.conmx.Lock() + de.conns[addr] = &peer + de.conmx.Unlock() + + go peer.KeepTalking() + + return nil +} + +func (de *TCPDepot) DrainTo(recs toyqueue.Records, addr string) error { + de.conmx.Lock() + conn, ok := de.conns[addr] + de.conmx.Unlock() + if !ok { + return ErrAddressUnknown + } + return conn.Drain(recs) +} + +func (de *TCPDepot) Disconnect(addr string) (err error) { + de.conmx.Lock() + tcp, ok := de.conns[addr] + de.conmx.Unlock() + if !ok { + return ErrAddressUnknown + } + tcp.Close() + de.conmx.Lock() + delete(de.conns, addr) + de.conmx.Unlock() + return nil +} + +func (de *TCPDepot) Listen(addr string) (err error) { + listener, err := net.Listen("tcp", addr) + if err != nil { + return + } + de.conmx.Lock() + pre, ok := de.listens[addr] + if ok { + _ = pre.Close() + } + de.listens[addr] = listener + de.conmx.Unlock() + go de.KeepListening(addr) + return +} + +func (de *TCPDepot) StopListening(addr string) error { + de.conmx.Lock() + listener, ok := de.listens[addr] + delete(de.listens, addr) + de.conmx.Unlock() + if !ok { + return ErrAddressUnknown + } + return listener.Close() +} + +func (de *TCPDepot) KeepListening(addr string) { + for { + de.conmx.Lock() + listener, ok := de.listens[addr] + de.conmx.Unlock() + + if !ok { + break + } + conn, err := listener.Accept() + if err != nil { + break + } + addr := conn.RemoteAddr().String() + peer := TCPConn{ + addr: addr, + inout: de.jack(conn), + } + peer.wake.L = &peer.outmx + peer.conn.Store(&conn) + + de.conmx.Lock() + de.conns[addr] = &peer + de.conmx.Unlock() + + go peer.doWrite() + go peer.doRead() + } +} + +func ReadBuf(buf []byte, rdr io.Reader) ([]byte, error) { + avail := cap(buf) - len(buf) + if avail < 512 { + l := 4096 + if len(buf) > 2048 { + l = len(buf) * 2 + } + newbuf := make([]byte, l) + copy(newbuf[:], buf) + buf = newbuf[:len(buf)] + } + idle := buf[len(buf):cap(buf)] + n, err := rdr.Read(idle) + if err != nil { + return buf, err + } + if n == 0 { + return buf, io.EOF + } + buf = buf[:len(buf)+n] + return buf, nil +} + +func RoundPage(l int) int { + if (l & 0xfff) != 0 { + l = (l & ^0xfff) + 0x1000 + } + return l +} + +// AppendRead reads data from io.Reader into the *spare space* of the provided buffer, +// i.e. those cap(buf)-len(buf) vacant bytes. If the spare space is smaller than +// lenHint, allocates (as reading less bytes might be unwise). +func AppendRead(buf []byte, rdr io.Reader, lenHint int) ([]byte, error) { + avail := cap(buf) - len(buf) + if avail < lenHint { + want := RoundPage(len(buf) + lenHint) + newbuf := make([]byte, want) + copy(newbuf[:], buf) + buf = newbuf[:len(buf)] + } + idle := buf[len(buf):cap(buf)] + n, err := rdr.Read(idle) + if err != nil { + return buf, err + } + if n == 0 { + return buf, io.EOF + } + buf = buf[:len(buf)+n] + return buf, nil +} diff --git a/toytlv/tcp_test.go b/toytlv/tcp_test.go new file mode 100644 index 0000000..35e9251 --- /dev/null +++ b/toytlv/tcp_test.go @@ -0,0 +1,94 @@ +package toytlv + +import ( + "net" + "sync" + "sync/atomic" + "testing" + + "github.com/drpcorg/chotki/toyqueue" + "github.com/stretchr/testify/assert" +) + +// 1. create a server, create a client, echo +// 2. create a server, client, connect, disconn, reconnect +// 3. 
create a server, client, conn, stop the serv, relaunch, reconnect + +type TestConsumer struct { + rcvd toyqueue.Records + mx sync.Mutex + co sync.Cond +} + +func (c *TestConsumer) Drain(recs toyqueue.Records) error { + c.mx.Lock() + c.rcvd = append(c.rcvd, recs...) + c.co.Signal() + c.mx.Unlock() + return nil +} + +func (c *TestConsumer) Feed() (recs toyqueue.Records, err error) { + c.mx.Lock() + if len(c.rcvd) == 0 { + c.co.Wait() + } + recs = c.rcvd + c.rcvd = c.rcvd[len(c.rcvd):] + c.mx.Unlock() + return +} + +func (c *TestConsumer) Close() error { + return nil +} + +func TestTCPDepot_Connect(t *testing.T) { + + loop := "127.0.0.1:12345" + + tc := TestConsumer{} + tc.co.L = &tc.mx + depot := TCPDepot{} + var addr atomic.Value + addr.Store("") + + depot.Open(func(conn net.Conn) toyqueue.FeedDrainCloser { + a := conn.RemoteAddr().String() + if a != loop { + addr.Store(a) + } + return &tc + }) + + err := depot.Listen(loop) + assert.Nil(t, err) + + err = depot.Connect(loop) + assert.Nil(t, err) + + // send a record + recsto := toyqueue.Records{Record('M', []byte("Hi there"))} + err = depot.DrainTo(recsto, loop) + assert.Nil(t, err) + rec, err := tc.Feed() + assert.Nil(t, err) + lit, body, rest := TakeAny(rec[0]) + assert.Equal(t, uint8('M'), lit) + assert.Equal(t, "Hi there", string(body)) + assert.Equal(t, 0, len(rest)) + + // respond to that + recsback := toyqueue.Records{Record('M', []byte("Re: Hi there"))} + err = depot.DrainTo(recsback, addr.Load().(string)) + assert.Nil(t, err) + rerec, err := tc.Feed() + assert.Nil(t, err) + relit, rebody, rerest := TakeAny(rerec[0]) + assert.Equal(t, uint8('M'), relit) + assert.Equal(t, "Re: Hi there", string(rebody)) + assert.Equal(t, 0, len(rerest)) + + depot.Close() + +} diff --git a/toytlv/tlv.go b/toytlv/tlv.go new file mode 100644 index 0000000..6122811 --- /dev/null +++ b/toytlv/tlv.go @@ -0,0 +1,322 @@ +package toytlv + +import ( + "encoding/binary" + "errors" + + "github.com/drpcorg/chotki/toyqueue" +) + +const CaseBit uint8 = 'a' - 'A' + +var ( + ErrIncomplete = errors.New("incomplete data") + ErrBadRecord = errors.New("bad TLV record format") + ErrAddressUnknown = errors.New("address unknown") + ErrDisconnected = errors.New("disconnected by user") +) + +// ProbeHeader probes a TLV record header. Return values: +// - 0 0 0 incomplete header +// - '-' 0 0 bad format +// - 'A' 2 123 success +func ProbeHeader(data []byte) (lit byte, hdrlen, bodylen int) { + if len(data) == 0 { + return 0, 0, 0 + } + dlit := data[0] + if dlit >= '0' && dlit <= '9' { // tiny + lit = '0' + bodylen = int(dlit - '0') + hdrlen = 1 + } else if dlit >= 'a' && dlit <= 'z' { // short + if len(data) < 2 { + return + } + lit = dlit - CaseBit + hdrlen = 2 + bodylen = int(data[1]) + } else if dlit >= 'A' && dlit <= 'Z' { // long + if len(data) < 5 { + return + } + bl := binary.LittleEndian.Uint32(data[1:5]) + if bl > 0x7fffffff { + lit = '-' + return + } + lit = dlit + bodylen = int(bl) + hdrlen = 5 + } else { + lit = '-' + } + return +} + +// Incomplete returns the number of supposedly yet-unread bytes. +// 0 for complete, -1 for bad format, +// >0 for least-necessary read to complete either header or record. 
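+// For example (sketch), Incomplete([]byte{'b', 3, 'x'}) returns 2: the short
+// header announces a 3-byte body, but only one body byte has arrived so far.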
+func Incomplete(data []byte) int { + if len(data) == 0 { + return 1 // get something + } + dlit := data[0] + var bodylen int + if dlit >= '0' && dlit <= '9' { // tiny + bodylen = int(dlit - '0') + } else if dlit >= 'a' && dlit <= 'z' { // short + if len(data) < 2 { + bodylen = 2 + } else { + bodylen = int(data[1]) + 2 + } + } else if dlit >= 'A' && dlit <= 'Z' { // long + if len(data) < 5 { + bodylen = 5 + } else { + bl := binary.LittleEndian.Uint32(data[1:5]) + if bl > 0x7fffffff { + return -1 + } + bodylen = int(bl) + 5 + } + } else { + return -1 + } + if bodylen > len(data) { + return bodylen - len(data) + } else { + return 0 + } +} + +func Split(data []byte) (recs toyqueue.Records, rest []byte, err error) { + rest = data + for len(rest) > 0 { + lit, hlen, blen := ProbeHeader(rest) + if lit == '-' { + if len(recs) == 0 { + err = ErrBadRecord + } + return + } + if lit == 0 { + return + } + if hlen+blen > len(rest) { + break + } + recs = append(recs, rest[:hlen+blen]) + rest = rest[hlen+blen:] + } + return +} + +func ProbeHeaders(lits string, data []byte) int { + rest := data + for i := 0; i < len(lits); i++ { + l, hl, bl := ProbeHeader(rest) + if l != lits[i] { + return -1 + } + rest = rest[hl+bl:] + } + return len(data) - len(rest) +} + +// Feeds the header into the buffer. +// Subtle: lower-case lit allows for defaulting, uppercase must be explicit. +func AppendHeader(into []byte, lit byte, bodylen int) (ret []byte) { + biglit := lit &^ CaseBit + if biglit < 'A' || biglit > 'Z' { + panic("ToyTLV record type is A..Z") + } + if bodylen < 10 && (lit&CaseBit) != 0 { + ret = append(into, byte('0'+bodylen)) + } else if bodylen > 0xff { + if bodylen > 0x7fffffff { + panic("oversized TLV record") + } + ret = append(into, biglit) + ret = binary.LittleEndian.AppendUint32(ret, uint32(bodylen)) + } else { + ret = append(into, lit|CaseBit, byte(bodylen)) + } + return ret +} + +// Take is used to read safe TLV inputs (e.g. from own storage) with +// record types known in advance. +func Take(lit byte, data []byte) (body, rest []byte) { + flit, hdrlen, bodylen := ProbeHeader(data) + if flit == 0 || hdrlen+bodylen > len(data) { + return nil, data // Incomplete + } + if flit != lit && flit != '0' { + return nil, nil // BadRecord + } + body = data[hdrlen : hdrlen+bodylen] + rest = data[hdrlen+bodylen:] + return +} + +// TakeAny is used for safe TLV inputs when record types can vary. +func TakeAny(data []byte) (lit byte, body, rest []byte) { + if len(data) == 0 { + return 0, nil, nil + } + lit = data[0] & ^CaseBit + body, rest = Take(lit, data) + return +} + +// TakeWary reads TLV records of known type from unsafe input. +func TakeWary(lit byte, data []byte) (body, rest []byte, err error) { + flit, hdrlen, bodylen := ProbeHeader(data) + if flit == 0 || hdrlen+bodylen > len(data) { + return nil, data, ErrIncomplete + } + if flit != lit && flit != '0' { + return nil, nil, ErrBadRecord + } + body = data[hdrlen : hdrlen+bodylen] + rest = data[hdrlen+bodylen:] + return +} + +// TakeWary reads TLV records of arbitrary type from unsafe input. 
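+// For example (sketch), TakeAnyWary(Record('A', []byte("hi"))) returns
+// lit 'A', body "hi", an empty rest and a nil error.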
+func TakeAnyWary(data []byte) (lit byte, body, rest []byte, err error) { + if len(data) == 0 { + return 0, nil, nil, ErrIncomplete + } + lit = data[0] & ^CaseBit + body, rest = Take(lit, data) + return +} + +func TakeRecord(lit byte, data []byte) (rec, rest []byte) { + flit, hdrlen, bodylen := ProbeHeader(data) + if flit == 0 || hdrlen+bodylen > len(data) { + return nil, data // Incomplete + } + if flit != lit && flit != '0' { + return nil, nil // BadRecord + } + rec = data[0 : hdrlen+bodylen] + rest = data[hdrlen+bodylen:] + return +} + +func TakeAnyRecord(data []byte) (lit byte, rec, rest []byte) { + lit, hdrlen, bodylen := ProbeHeader(data) + if lit == 0 || hdrlen+bodylen > len(data) { + return 0, nil, data // Incomplete + } + if lit == '-' { + return '-', nil, nil // BadRecord + } + rec = data[0 : hdrlen+bodylen] + rest = data[hdrlen+bodylen:] + return +} + +func TotalLen(inputs [][]byte) (sum int) { + for _, input := range inputs { + sum += len(input) + } + return +} + +func Lit(rec []byte) byte { + b := rec[0] + if b >= 'a' && b <= 'z' { + return b - CaseBit + } else if b >= 'A' && b <= 'Z' { + return b + } else if b >= '0' && b <= '9' { + return '0' + } else { + return '-' + } +} + +// Append appends a record to the buffer; note that uppercase type +// is always explicit, lowercase can be defaulted. +func Append(into []byte, lit byte, body ...[]byte) (res []byte) { + total := TotalLen(body) + res = AppendHeader(into, lit, total) + for _, b := range body { + res = append(res, b...) + } + return res +} + +// Record composes a record of a given type +func Record(lit byte, body ...[]byte) []byte { + total := TotalLen(body) + ret := make([]byte, 0, total+5) + ret = AppendHeader(ret, lit, total) + for _, b := range body { + ret = append(ret, b...) + } + return ret +} + +func AppendTiny(into []byte, lit byte, body []byte) (res []byte) { + if len(body) > 9 { + return Append(into, lit, body) + } + res = append(into, '0'+byte(len(body))) + res = append(res, body...) + return +} + +func TinyRecord(lit byte, body []byte) (tiny []byte) { + var data [10]byte + return AppendTiny(data[:0], lit, body) +} + +func Join(records ...[]byte) (ret toyqueue.Records) { + for _, rec := range records { + ret = append(ret, rec) + } + return +} + +func Records(lit byte, bodies ...[]byte) (recs toyqueue.Records) { + for _, body := range bodies { + recs = append(recs, Record(lit, body)) + } + return +} + +func Concat(msg ...[]byte) []byte { + total := TotalLen(msg) + ret := make([]byte, 0, total) + for _, b := range msg { + ret = append(ret, b...) + } + return ret +} + +// OpenHeader opens a streamed TLV record; use append() to create the +// record body, then call CloseHeader(&buf, bookmark) +func OpenHeader(buf []byte, lit byte) (bookmark int, res []byte) { + lit &= ^CaseBit + if lit < 'A' || lit > 'Z' { + panic("TLV liters are uppercase A-Z") + } + res = append(buf, lit) + blanclen := []byte{0, 0, 0, 0} + res = append(res, blanclen...) 
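+	// the four zero bytes are a length placeholder; CloseHeader later
+	// patches in the real body length at buf[bookmark-4:bookmark]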
+ return len(res), res +} + +// CloseHeader closes a streamed TLV record +func CloseHeader(buf []byte, bookmark int) { + if bookmark < 5 || len(buf) < bookmark { + panic("check the API docs") + } + binary.LittleEndian.PutUint32(buf[bookmark-4:bookmark], uint32(len(buf)-bookmark)) +} diff --git a/toytlv/tlv_test.go b/toytlv/tlv_test.go new file mode 100644 index 0000000..c8f8c8b --- /dev/null +++ b/toytlv/tlv_test.go @@ -0,0 +1,52 @@ +package toytlv + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTLVAppend(t *testing.T) { + buf := []byte{} + buf = Append(buf, 'A', []byte{'A'}) + buf = Append(buf, 'b', []byte{'B', 'B'}) + correct2 := []byte{'a', 1, 'A', '2', 'B', 'B'} + assert.Equal(t, correct2, buf, "basic TLV fail") + + var c256 [256]byte + for n := range c256 { + c256[n] = 'c' + } + buf = Append(buf, 'C', c256[:]) + assert.Equal(t, len(correct2)+1+4+len(c256), len(buf)) + assert.Equal(t, uint8(67), buf[len(correct2)]) + assert.Equal(t, uint8(1), buf[len(correct2)+2]) + + lit, body, buf, err := TakeAnyWary(buf) + assert.Nil(t, err) + assert.Equal(t, uint8('A'), lit) + assert.Equal(t, []byte{'A'}, body) + + body2, _, err2 := TakeWary('B', buf) + assert.Nil(t, err2) + assert.Equal(t, []byte{'B', 'B'}, body2) +} + +func TestFeedHeader(t *testing.T) { + buf := []byte{} + l, buf := OpenHeader(buf, 'A') + text := "some text" + buf = append(buf, text...) + CloseHeader(buf, l) + lit, body, rest, err := TakeAnyWary(buf) + assert.Nil(t, err) + assert.Equal(t, uint8('A'), lit) + assert.Equal(t, text, string(body)) + assert.Equal(t, 0, len(rest)) +} + +func TestTinyRecord(t *testing.T) { + body := "12" + tiny := TinyRecord('X', []byte(body)) + assert.Equal(t, "212", string(tiny)) +}