From 7d1e3b30a6abb49d692f8aed67c1a85d7ce00e0c Mon Sep 17 00:00:00 2001 From: Nikolay Govorov Date: Mon, 22 Apr 2024 23:59:20 +0100 Subject: [PATCH] refactor tcp transport --- chotki.go | 54 +-- repl/commands.go | 5 +- toytlv/peer.go | 194 +++++++++++ toytlv/tcp.go | 383 ---------------------- toytlv/tlv.go | 2 + toytlv/transport.go | 178 ++++++++++ toytlv/{tcp_test.go => transport_test.go} | 12 +- 7 files changed, 415 insertions(+), 413 deletions(-) create mode 100644 toytlv/peer.go delete mode 100644 toytlv/tcp.go create mode 100644 toytlv/transport.go rename toytlv/{tcp_test.go => transport_test.go} (86%) diff --git a/chotki.go b/chotki.go index 674a03e..5bf4d50 100644 --- a/chotki.go +++ b/chotki.go @@ -1,6 +1,7 @@ package chotki import ( + "context" "encoding/binary" "errors" "fmt" @@ -53,7 +54,7 @@ type Chotki struct { src uint64 db *pebble.DB - net *toytlv.TCPDepot + net *toytlv.Transport dir string opts Options @@ -188,14 +189,13 @@ func Open(dirname string, opts Options) (*Chotki, error) { src: opts.Src, dir: dirname, opts: opts, - net: &toytlv.TCPDepot{}, types: make(map[rdx.ID]Fields), hooks: make(map[rdx.ID][]Hook), syncs: make(map[rdx.ID]*pebble.Batch), outq: make(map[string]toyqueue.DrainCloser), } - cho.net.Open(func(conn net.Conn) toyqueue.FeedDrainCloser { + cho.net = toytlv.NewTransport(func(conn net.Conn) toyqueue.FeedDrainCloser { return &Syncer{ Host: &cho, Mode: SyncRWLive, @@ -203,21 +203,6 @@ func Open(dirname string, opts Options) (*Chotki, error) { } }) - if opts.RestoreNetwork { - i := cho.db.NewIter(&pebble.IterOptions{}) - defer i.Close() - - for i.SeekGE([]byte{'l'}); i.Valid() && i.Key()[0] == 'L'; i.Next() { - address := string(i.Key()[1:]) - _ = cho.net.Listen(address) - } - - for i.SeekGE([]byte{'c'}); i.Valid() && i.Key()[0] == 'C'; i.Next() { - address := string(i.Key()[1:]) - _ = cho.net.Connect(address) - } - } - if !exists { id0 := rdx.IDFromSrcSeqOff(opts.Src, 0, 0) @@ -243,12 +228,37 @@ func Open(dirname string, opts Options) (*Chotki, error) { return &cho, nil } -func (cho *Chotki) Listen(addr string) error { - return cho.net.Listen(addr) +func (cho *Chotki) RestoreNet(ctx context.Context) error { + i := cho.db.NewIter(&pebble.IterOptions{}) + defer i.Close() + + for i.SeekGE([]byte{'l'}); i.Valid() && i.Key()[0] == 'L'; i.Next() { + address := string(i.Key()[1:]) + _ = cho.net.Listen(ctx, address) + } + + for i.SeekGE([]byte{'c'}); i.Valid() && i.Key()[0] == 'C'; i.Next() { + address := string(i.Key()[1:]) + _ = cho.net.Connect(ctx, address) + } + + return nil +} + +func (cho *Chotki) Listen(ctx context.Context, addr string) error { + return cho.net.Listen(ctx, addr) +} + +func (cho *Chotki) Unlisten(addr string) error { + return cho.net.Unlisten(addr) +} + +func (cho *Chotki) Connect(ctx context.Context, addr string) error { + return cho.net.Connect(ctx, addr) } -func (cho *Chotki) Connect(addr string) error { - return cho.net.Connect(addr) +func (cho *Chotki) Disconnect(addr string) error { + return cho.net.Disconnect(addr) } func (cho *Chotki) AddPacketHose(name string) (feed toyqueue.FeedCloser) { diff --git a/repl/commands.go b/repl/commands.go index fd44e25..6c384d6 100644 --- a/repl/commands.go +++ b/repl/commands.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "os" @@ -316,7 +317,7 @@ func (repl *REPL) CommandListen(arg *rdx.RDX) (id rdx.ID, err error) { } addr := rdx.Snative(rdx.Sparse(string(arg.Text))) if err == nil { - err = repl.Host.Listen(addr) + err = repl.Host.Listen(context.Background(), addr) } return } @@ -329,7 +330,7 @@ func (repl *REPL) CommandConnect(arg *rdx.RDX) (id rdx.ID, err error) { } addr := rdx.Snative(rdx.Sparse(string(arg.Text))) if err == nil { - err = repl.Host.Connect(addr) + err = repl.Host.Connect(context.Background(), addr) } return } diff --git a/toytlv/peer.go b/toytlv/peer.go new file mode 100644 index 0000000..ccf1be1 --- /dev/null +++ b/toytlv/peer.go @@ -0,0 +1,194 @@ +package toytlv + +import ( + "io" + "log/slog" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/drpcorg/chotki/toyqueue" +) + +type ConnType = uint + +const ( + TCP ConnType = iota + 1 + TLS + QUIC +) + +const ( + TYPICAL_MTU = 1500 + MAX_OUT_QUEUE_LEN = 1 << 20 // 16MB of pointers is a lot + + MAX_RETRY_PERIOD = time.Minute + MIN_RETRY_PERIOD = time.Second / 2 +) + +type Jack func(conn net.Conn) toyqueue.FeedDrainCloser + +type Peer struct { + conn atomic.Pointer[net.Conn] + inout toyqueue.FeedDrainCloser + reconnect func() (net.Conn, error) + + mu sync.Mutex + co sync.Cond + + Protocol ConnType + KeepAlive bool +} + +func (tcp *Peer) doRead() { + err := tcp.read() + + if err != nil && err != ErrDisconnected { + _ = tcp.Close() + + // TODO: error handling + slog.Error("couldn't read from conn", "err", err) + } +} + +func (tcp *Peer) read() error { + var buf []byte + var err error + + for { + conn := tcp.conn.Load() + if conn == nil { + break + } + + if buf, err = appendRead(buf, *conn, TYPICAL_MTU); err != nil { + return err + } + + var recs toyqueue.Records + if recs, buf, err = Split(buf); err != nil { + return err + } else if len(recs) == 0 { + time.Sleep(time.Millisecond) + continue + } + + if err = tcp.inout.Drain(recs); err != nil { + return err + } + } + + return nil +} + +func (tcp *Peer) doWrite() { + var err error + var recs toyqueue.Records + for err == nil { + conn := tcp.conn.Load() + if conn == nil { + break + } + + recs, err = tcp.inout.Feed() + b := net.Buffers(recs) + for len(b) > 0 && err == nil { + _, err = b.WriteTo(*conn) + } + } + if err != nil { + tcp.Close() // TODO err + } +} + +func (tcp *Peer) Drain(recs toyqueue.Records) error { + return tcp.inout.Drain(recs) +} + +func (tcp *Peer) Feed() (toyqueue.Records, error) { + return tcp.inout.Feed() +} + +func (tcp *Peer) keepTalking() { + go tcp.doWrite() + go tcp.doRead() + + talkBackoff, connBackoff := MIN_RETRY_PERIOD, MIN_RETRY_PERIOD + + for tcp.reconnect == nil { + conntime := time.Now() + + atLeast5min := conntime.Add(time.Minute * 5) + if atLeast5min.After(time.Now()) { + talkBackoff *= 2 // connected, tried to talk, failed => wait more + if talkBackoff > MAX_RETRY_PERIOD { + talkBackoff = MAX_RETRY_PERIOD + } + } + + for { + if conn := tcp.conn.Load(); conn == nil { + break + } + + time.Sleep(connBackoff + talkBackoff) + conn, err := tcp.reconnect() + if err != nil { + connBackoff = connBackoff * 2 + if connBackoff > MAX_RETRY_PERIOD/2 { + connBackoff = MAX_RETRY_PERIOD + } + } else { + tcp.conn.Store(&conn) + connBackoff = MIN_RETRY_PERIOD + } + } + } +} + +func (tcp *Peer) Close() error { + tcp.mu.Lock() + defer tcp.mu.Unlock() + + // TODO writer closes on complete | 1 sec expired + if conn := tcp.conn.Swap(nil); conn != nil { + if err := (*conn).Close(); err != nil { + return err + } + + tcp.co.Broadcast() + } + + return nil +} + +func roundPage(l int) int { + if (l & 0xfff) != 0 { + l = (l & ^0xfff) + 0x1000 + } + return l +} + +// appendRead reads data from io.Reader into the *spare space* of the provided buffer, +// i.e. those cap(buf)-len(buf) vacant bytes. If the spare space is smaller than +// lenHint, allocates (as reading less bytes might be unwise). +func appendRead(buf []byte, rdr io.Reader, lenHint int) ([]byte, error) { + avail := cap(buf) - len(buf) + if avail < lenHint { + want := roundPage(len(buf) + lenHint) + newbuf := make([]byte, want) + copy(newbuf[:], buf) + buf = newbuf[:len(buf)] + } + idle := buf[len(buf):cap(buf)] + n, err := rdr.Read(idle) + if err != nil { + return buf, err + } + if n == 0 { + return buf, io.EOF + } + buf = buf[:len(buf)+n] + return buf, nil +} diff --git a/toytlv/tcp.go b/toytlv/tcp.go deleted file mode 100644 index a3455b1..0000000 --- a/toytlv/tcp.go +++ /dev/null @@ -1,383 +0,0 @@ -package toytlv - -import ( - "io" - "log/slog" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/drpcorg/chotki/toyqueue" -) - -const ( - TYPICAL_MTU = 1500 - MAX_OUT_QUEUE_LEN = 1 << 20 // 16MB of pointers is a lot - - MAX_RETRY_PERIOD = time.Minute - MIN_RETRY_PERIOD = time.Second / 2 -) - -type ConnType = uint - -const ( - TCP ConnType = iota + 1 - TLS ConnType = iota + 1 - QUIC ConnType = iota + 1 -) - -type TCPConn struct { - addr string - conn atomic.Pointer[net.Conn] - inout toyqueue.FeedDrainCloser - - wake sync.Cond - outmx sync.Mutex - - Reconnect bool - KeepAlive bool -} - -func (tcp *TCPConn) doRead() { - err := tcp.read() - if err != nil && err != ErrDisconnected { - // TODO: error handling - slog.Error("couldn't read from conn", "err", err) - } -} - -func (tcp *TCPConn) read() (err error) { - var buf []byte - for { - conn := tcp.conn.Load() - if conn == nil { - break - } - - buf, err = AppendRead(buf, *conn, TYPICAL_MTU) - if err != nil { - break - } - var recs toyqueue.Records - recs, buf, err = Split(buf) - if len(recs) == 0 { - time.Sleep(time.Millisecond) - continue - } - if err != nil { - break - } - - err = tcp.inout.Drain(recs) - if err != nil { - break - } - } - - if err != nil { - // TODO: error handling - slog.Error("couldn't read from conn", "err", err) - tcp.Close() - } - return -} - -func (tcp *TCPConn) doWrite() { - var err error - var recs toyqueue.Records - for err == nil { - conn := tcp.conn.Load() - if conn == nil { - break - } - - recs, err = tcp.inout.Feed() - b := net.Buffers(recs) - for len(b) > 0 && err == nil { - _, err = b.WriteTo(*conn) - } - } - if err != nil { - tcp.Close() // TODO err - } -} - -// Write what we believe is a valid ToyTLV frame. -// Provided for io.Writer compatibility -func (tcp *TCPConn) Write(data []byte) (n int, err error) { - err = tcp.Drain(toyqueue.Records{data}) - if err == nil { - n = len(data) - } - return -} - -func (tcp *TCPConn) Drain(recs toyqueue.Records) (err error) { - return tcp.inout.Drain(recs) -} - -func (tcp *TCPConn) Feed() (recs toyqueue.Records, err error) { - return tcp.inout.Feed() -} - -func (tcp *TCPConn) KeepTalking() { - talk_backoff := MIN_RETRY_PERIOD - conn_backoff := MIN_RETRY_PERIOD - - for { - conntime := time.Now() - go tcp.doWrite() - - // TODO: error handling - err := tcp.read() - slog.Error("couldn't read from conn", "err", err) - - if !tcp.Reconnect { - break - } - - atLeast5min := conntime.Add(time.Minute * 5) - if atLeast5min.After(time.Now()) { - talk_backoff *= 2 // connected, tried to talk, failed => wait more - if talk_backoff > MAX_RETRY_PERIOD { - talk_backoff = MAX_RETRY_PERIOD - } - } - - for { - if conn := tcp.conn.Load(); conn == nil { - break - } - - time.Sleep(conn_backoff + talk_backoff) - conn, err := net.Dial("tcp", tcp.addr) - if err != nil { - conn_backoff = conn_backoff * 2 - if conn_backoff > MAX_RETRY_PERIOD/2 { - conn_backoff = MAX_RETRY_PERIOD - } - } else { - tcp.conn.Store(&conn) - conn_backoff = MIN_RETRY_PERIOD - } - } - } -} - -func (tcp *TCPConn) Close() error { - tcp.outmx.Lock() - defer tcp.outmx.Unlock() - - // TODO writer closes on complete | 1 sec expired - if conn := tcp.conn.Swap(nil); conn != nil { - if err := (*conn).Close(); err != nil { - return err - } - - tcp.wake.Broadcast() - } - - return nil -} - -type Jack func(conn net.Conn) toyqueue.FeedDrainCloser - -// A TCP server/client for the use case of real-time async communication. -// Differently from the case of request-response (like HTTP), we do not -// wait for a request, then dedicating a thread to processing, then sending -// back the resulting response. Instead, we constantly fan sendQueue tons of -// tiny messages. That dictates different work patterns than your typical -// HTTP/RPC server as, for example, we cannot let one slow receiver delay -// event transmission to all the other receivers. -type TCPDepot struct { - conns map[string]*TCPConn - listens map[string]net.Listener - conmx sync.Mutex - jack Jack -} - -func (de *TCPDepot) Open(jack Jack) { - de.conmx.Lock() - de.conns = make(map[string]*TCPConn) - de.listens = make(map[string]net.Listener) - de.conmx.Unlock() - de.jack = jack -} - -func (de *TCPDepot) Close() error { - de.conmx.Lock() - defer de.conmx.Unlock() - - for _, lstn := range de.listens { - if err := lstn.Close(); err != nil { - return err - } - } - clear(de.listens) - - for _, con := range de.conns { - if err := con.Close(); err != nil { - return err - } - } - clear(de.conns) - - return nil -} - -// attrib?! -func (de *TCPDepot) Connect(addr string) (err error) { - conn, err := net.Dial("tcp", addr) - if err != nil { - return err - } - peer := TCPConn{ - addr: addr, - inout: de.jack(conn), - } - peer.wake.L = &peer.outmx - peer.conn.Store(&conn) - - de.conmx.Lock() - de.conns[addr] = &peer - de.conmx.Unlock() - - go peer.KeepTalking() - - return nil -} - -func (de *TCPDepot) DrainTo(recs toyqueue.Records, addr string) error { - de.conmx.Lock() - conn, ok := de.conns[addr] - de.conmx.Unlock() - if !ok { - return ErrAddressUnknown - } - return conn.Drain(recs) -} - -func (de *TCPDepot) Disconnect(addr string) (err error) { - de.conmx.Lock() - tcp, ok := de.conns[addr] - de.conmx.Unlock() - if !ok { - return ErrAddressUnknown - } - tcp.Close() - de.conmx.Lock() - delete(de.conns, addr) - de.conmx.Unlock() - return nil -} - -func (de *TCPDepot) Listen(addr string) (err error) { - listener, err := net.Listen("tcp", addr) - if err != nil { - return - } - de.conmx.Lock() - pre, ok := de.listens[addr] - if ok { - _ = pre.Close() - } - de.listens[addr] = listener - de.conmx.Unlock() - go de.KeepListening(addr) - return -} - -func (de *TCPDepot) StopListening(addr string) error { - de.conmx.Lock() - listener, ok := de.listens[addr] - delete(de.listens, addr) - de.conmx.Unlock() - if !ok { - return ErrAddressUnknown - } - return listener.Close() -} - -func (de *TCPDepot) KeepListening(addr string) { - for { - de.conmx.Lock() - listener, ok := de.listens[addr] - de.conmx.Unlock() - - if !ok { - break - } - conn, err := listener.Accept() - if err != nil { - break - } - addr := conn.RemoteAddr().String() - peer := TCPConn{ - addr: addr, - inout: de.jack(conn), - } - peer.wake.L = &peer.outmx - peer.conn.Store(&conn) - - de.conmx.Lock() - de.conns[addr] = &peer - de.conmx.Unlock() - - go peer.doWrite() - go peer.doRead() - } -} - -func ReadBuf(buf []byte, rdr io.Reader) ([]byte, error) { - avail := cap(buf) - len(buf) - if avail < 512 { - l := 4096 - if len(buf) > 2048 { - l = len(buf) * 2 - } - newbuf := make([]byte, l) - copy(newbuf[:], buf) - buf = newbuf[:len(buf)] - } - idle := buf[len(buf):cap(buf)] - n, err := rdr.Read(idle) - if err != nil { - return buf, err - } - if n == 0 { - return buf, io.EOF - } - buf = buf[:len(buf)+n] - return buf, nil -} - -func RoundPage(l int) int { - if (l & 0xfff) != 0 { - l = (l & ^0xfff) + 0x1000 - } - return l -} - -// AppendRead reads data from io.Reader into the *spare space* of the provided buffer, -// i.e. those cap(buf)-len(buf) vacant bytes. If the spare space is smaller than -// lenHint, allocates (as reading less bytes might be unwise). -func AppendRead(buf []byte, rdr io.Reader, lenHint int) ([]byte, error) { - avail := cap(buf) - len(buf) - if avail < lenHint { - want := RoundPage(len(buf) + lenHint) - newbuf := make([]byte, want) - copy(newbuf[:], buf) - buf = newbuf[:len(buf)] - } - idle := buf[len(buf):cap(buf)] - n, err := rdr.Read(idle) - if err != nil { - return buf, err - } - if n == 0 { - return buf, io.EOF - } - buf = buf[:len(buf)+n] - return buf, nil -} diff --git a/toytlv/tlv.go b/toytlv/tlv.go index 6122811..04c5f3a 100644 --- a/toytlv/tlv.go +++ b/toytlv/tlv.go @@ -10,6 +10,8 @@ import ( const CaseBit uint8 = 'a' - 'A' var ( + ErrAddressDuplicated = errors.New("the address already used") + ErrIncomplete = errors.New("incomplete data") ErrBadRecord = errors.New("bad TLV record format") ErrAddressUnknown = errors.New("address unknown") diff --git a/toytlv/transport.go b/toytlv/transport.go new file mode 100644 index 0000000..35a6a49 --- /dev/null +++ b/toytlv/transport.go @@ -0,0 +1,178 @@ +package toytlv + +import ( + "context" + "log/slog" + "net" + "sync" + + "github.com/drpcorg/chotki/toyqueue" +) + +// A TCP/TLS/QUIC server/client for the use case of real-time async communication. +// Differently from the case of request-response (like HTTP), we do not +// wait for a request, then dedicating a thread to processing, then sending +// back the resulting response. Instead, we constantly fan sendQueue tons of +// tiny messages. That dictates different work patterns than your typical +// HTTP/RPC server as, for example, we cannot let one slow receiver delay +// event transmission to all the other receivers. +type Transport struct { + closed bool + mu sync.Mutex + jack Jack + conns map[string]*Peer + listens map[string]net.Listener +} + +func NewTransport(jack Jack) *Transport { + return &Transport{ + jack: jack, + conns: make(map[string]*Peer), + listens: make(map[string]net.Listener), + } +} + +func (de *Transport) Close() error { + de.mu.Lock() + defer de.mu.Unlock() + + de.jack = nil + de.closed = true + + for _, lstn := range de.listens { + if err := lstn.Close(); err != nil { + return err + } + } + clear(de.listens) + + for _, con := range de.conns { + if err := con.Close(); err != nil { + return err + } + } + clear(de.conns) + + return nil +} + +func (n *Transport) Connect(ctx context.Context, addr string) (err error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return err + } + + peer := Peer{ + inout: n.jack(conn), + reconnect: func() (net.Conn, error) { + return net.Dial("tcp", addr) + }, + } + peer.co.L = &peer.mu + peer.conn.Store(&conn) + + n.mu.Lock() + n.conns[addr] = &peer + n.mu.Unlock() + + go peer.keepTalking() + + return nil +} + +func (de *Transport) Disconnect(addr string) (err error) { + de.mu.Lock() + defer de.mu.Unlock() + + conn, ok := de.conns[addr] + if !ok { + return ErrAddressUnknown + } + + conn.Close() + delete(de.conns, addr) + + return nil +} + +func (n *Transport) Listen(ctx context.Context, addr string) error { + listener, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + n.mu.Lock() + defer n.mu.Unlock() + + if _, ok := n.listens[addr]; ok { + return ErrAddressDuplicated + } + + n.listens[addr] = listener + go n.keepListening(ctx, addr) + + return nil +} + +func (de *Transport) Unlisten(addr string) error { + de.mu.Lock() + defer de.mu.Unlock() + + listener, ok := de.listens[addr] + if !ok { + return ErrAddressUnknown + } + + delete(de.listens, addr) + return listener.Close() +} + +func (de *Transport) drainTo(recs toyqueue.Records, addr string) error { + de.mu.Lock() + conn, ok := de.conns[addr] + de.mu.Unlock() + if !ok { + return ErrAddressUnknown + } + return conn.Drain(recs) +} + +func (de *Transport) keepListening(ctx context.Context, addr string) { + for { + de.mu.Lock() + closed := de.closed + listener, ok := de.listens[addr] + de.mu.Unlock() + + if closed || !ok { + break + } + + select { + case <-ctx.Done(): + de.Close() + break + + default: + // continue + } + + conn, err := listener.Accept() + if err != nil { + // reconnects are the client's responsibility, just skip + slog.Error("couldn't accept connect request", "err", err) + continue + } + + addr := conn.RemoteAddr().String() + peer := Peer{ inout: de.jack(conn) } + peer.co.L = &peer.mu + peer.conn.Store(&conn) + + de.mu.Lock() + de.conns[addr] = &peer + de.mu.Unlock() + + go peer.keepTalking() + } +} diff --git a/toytlv/tcp_test.go b/toytlv/transport_test.go similarity index 86% rename from toytlv/tcp_test.go rename to toytlv/transport_test.go index 35e9251..6f80ad9 100644 --- a/toytlv/tcp_test.go +++ b/toytlv/transport_test.go @@ -1,6 +1,7 @@ package toytlv import ( + "context" "net" "sync" "sync/atomic" @@ -49,11 +50,10 @@ func TestTCPDepot_Connect(t *testing.T) { tc := TestConsumer{} tc.co.L = &tc.mx - depot := TCPDepot{} var addr atomic.Value addr.Store("") - depot.Open(func(conn net.Conn) toyqueue.FeedDrainCloser { + depot := NewTransport(func(conn net.Conn) toyqueue.FeedDrainCloser { a := conn.RemoteAddr().String() if a != loop { addr.Store(a) @@ -61,15 +61,15 @@ func TestTCPDepot_Connect(t *testing.T) { return &tc }) - err := depot.Listen(loop) + err := depot.Listen(context.Background(), loop) assert.Nil(t, err) - err = depot.Connect(loop) + err = depot.Connect(context.Background(), loop) assert.Nil(t, err) // send a record recsto := toyqueue.Records{Record('M', []byte("Hi there"))} - err = depot.DrainTo(recsto, loop) + err = depot.drainTo(recsto, loop) assert.Nil(t, err) rec, err := tc.Feed() assert.Nil(t, err) @@ -80,7 +80,7 @@ func TestTCPDepot_Connect(t *testing.T) { // respond to that recsback := toyqueue.Records{Record('M', []byte("Re: Hi there"))} - err = depot.DrainTo(recsback, addr.Load().(string)) + err = depot.drainTo(recsback, addr.Load().(string)) assert.Nil(t, err) rerec, err := tc.Feed() assert.Nil(t, err)