diff --git a/internal/plain/plain.go b/internal/plain/plain.go index 72a0bf1..656ebe2 100644 --- a/internal/plain/plain.go +++ b/internal/plain/plain.go @@ -94,6 +94,35 @@ type Coder interface { Decoder } +func TestRoundTrip(c Coder, verbose bool) error { + var buf = new(bytes.Buffer) + if err := c.PlainEncode(buf); err != nil { + return fmt.Errorf("round trip, enc1: %w", err) + } + d := buf.Bytes() + s1 := string(d) + if verbose { + fmt.Printf("encoded\n```\n%s```\n", s1) + } + buf = bytes.NewBuffer(d) + if err := c.PlainDecode(buf); err != nil { + return fmt.Errorf("round trip, dec: %w", err) + } + buf = new(bytes.Buffer) + if err := c.PlainEncode(buf); err != nil { + return fmt.Errorf("round trip, enc2: %w", err) + } + d = buf.Bytes() + s2 := string(d) + if verbose { + fmt.Printf("encode2\n```\n%s```\n", s2) + } + if s1 != s2 { + return fmt.Errorf("\n%s\n!=\n%s\n", s1, s2) + } + return nil +} + func EncodeDecode(c Coder) error { var buf = new(bytes.Buffer) if err := c.PlainEncode(buf); err != nil { diff --git a/typeset/node.go b/typeset/node.go index d986827..0b5cf1f 100644 --- a/typeset/node.go +++ b/typeset/node.go @@ -15,7 +15,6 @@ package typeset import ( - "bufio" "fmt" "io" ) @@ -42,13 +41,16 @@ type named struct { } func (n named) PlainEncode(w io.Writer) error { - _, err := fmt.Fprintf(w, "%s:%08x", n.name, n.typ) + _, err := fmt.Fprintf(w, "%s %08x", n.name, n.typ) return err } func (n *named) PlainDecode(r io.Reader) error { - _, err := fmt.Fscanf(r, "%s:%08x", &n.name, &n.typ) - return err + m, err := fmt.Fscanf(r, "%s %08x", &n.name, &n.typ) + if err != nil { + return fmt.Errorf("decode named n=%d: %w", m, err) + } + return nil } func (n *node) PlainEncode(w io.Writer) error { @@ -83,7 +85,7 @@ func (n *node) PlainEncode(w io.Writer) error { if err != nil { return err } - err = wrapJoinEncode(w, "(", ", ", ")", n.params) + err = wrapJoinEncode(w, "(", ", ", ") ", n.params) if err != nil { return err } @@ -118,14 +120,14 @@ func (n *node) PlainDecode(r io.Reader) error { case Func: err = n.decodeFunc(r) } - return nil + return err } func (n *node) decodeFunc(r io.Reader) error { buf := make([]byte, 2) _, err := io.ReadFull(r, buf) if err != nil { - return err + return fmt.Errorf("fn hdr: %w", err) } n.key = NoType switch buf[0] { @@ -133,7 +135,7 @@ func (n *node) decodeFunc(r io.Reader) error { case 'm': _, err = fmt.Fscanf(r, "%08x", &n.key) if err != nil { - return err + return fmt.Errorf("key %w", err) } default: return fmt.Errorf("unexpected '%s'", string(buf)) @@ -146,11 +148,14 @@ func (n *node) decodeFunc(r io.Reader) error { default: return fmt.Errorf("unexpected '%s'", string(buf)) } - err = wrapJoinDecode(r, "(", ", ", ")", &n.params) + err = wrapJoinDecode(r, "(", ", ", ") ", &n.params) if err != nil { - return err + return fmt.Errorf("params %w", err) } err = wrapJoinDecode(r, "(", ", ", ")", &n.results) + if err != nil { + return fmt.Errorf("results %w", err) + } return nil } @@ -179,26 +184,61 @@ func wrapJoinEncode(w io.Writer, left, sep, right string, elts []named) error { return err } +type bb struct { + r io.Reader + b []byte // size 1 + buf bool +} + +func newbb(r io.Reader) *bb { + return &bb{r: r, b: make([]byte, 1), buf: false} +} + +func (bb *bb) Read(d []byte) (int, error) { + n := 0 + var nn int + var err error + if bb.buf && len(d) > 0 { + d[0] = bb.b[0] + bb.buf = false + d = d[1:] + n++ + nn, err = bb.r.Read(d) + } else { + nn, err = bb.r.Read(d) + } + return nn + n, err +} + +func (bb *bb) PeekByte() (byte, error) { + if bb.buf { + return 0, fmt.Errorf("already read %s", string(bb.b)) + } + + _, err := io.ReadFull(bb.r, bb.b) + if err != nil { + return 0, err + } + bb.buf = true + return bb.b[0], nil +} + func wrapJoinDecode(r io.Reader, left, sep, right string, elts *[]named) error { - br := bufio.NewReader(r) - _ = br + br := newbb(r) err := expect(r, left) if err != nil { return err } for { - b, err := br.ReadByte() + b, err := br.PeekByte() if err != nil { - return err - } - if err = br.UnreadByte(); err != nil { - return err + return fmt.Errorf("peek %w", err) } switch b { case sep[0]: - err = expect(r, sep) + err = expect(br, sep) case right[0]: - return expect(r, right) + return expect(br, right) default: n := len(*elts) *elts = append(*elts, named{}) @@ -206,7 +246,7 @@ func wrapJoinDecode(r io.Reader, left, sep, right string, elts *[]named) error { err = elt.PlainDecode(br) } if err != nil { - return err + return fmt.Errorf("sep or dec: %w\n", err) } } } diff --git a/typeset/plain.go b/typeset/plain.go index a529fcd..fe8acee 100644 --- a/typeset/plain.go +++ b/typeset/plain.go @@ -28,9 +28,12 @@ func (t *T) PlainEncode(w io.Writer) error { for i := int(_endType); i < N; i++ { node := &t.nodes[i] if err = node.PlainEncode(w); err != nil { - return err + return fmt.Errorf("typeset encode node %d: %w", i, err) } _, err = fmt.Fprintf(w, "\n") + if err != nil { + return fmt.Errorf("typeset encode eol %d: %w", i, err) + } } return nil } @@ -41,7 +44,7 @@ func (t *T) PlainDecode(r io.Reader) error { var H int _, err := fmt.Fscanf(r, "%d:%d\n", &N, &H) if err != nil { - return err + return fmt.Errorf("typeset decode hdr: %w", err) } tt.hash = make([]Type, H) copy(tt.hash[:_endType], t.hash[:_endType]) @@ -53,14 +56,14 @@ func (t *T) PlainDecode(r io.Reader) error { node := &tt.nodes[ty] node.zero() if err = node.PlainDecode(r); err != nil { - return err + return fmt.Errorf("typeset decode node %d: %w", ty, err) } - _, err = io.ReadFull(r, eol) + n, err := io.ReadFull(r, eol) if err != nil { - return err + return fmt.Errorf("typeset decode eol %d: %d %w", ty, n, err) } if eol[0] != byte('\n') { - return fmt.Errorf("expected eol") + return fmt.Errorf("expected eol got '%s'", string(eol)) } node.hash = tt.hashCode(ty) hi := node.hash % uint32(H) diff --git a/typeset/t_test.go b/typeset/t_test.go index e7127e1..c7ba35d 100644 --- a/typeset/t_test.go +++ b/typeset/t_test.go @@ -16,9 +16,11 @@ package typeset import ( "bytes" + "go/token" "go/types" - "os" "testing" + + "github.com/go-air/pal/internal/plain" ) func TestTypeSetGrow(t *testing.T) { @@ -28,7 +30,27 @@ func TestTypeSetGrow(t *testing.T) { for i := 0; i < 2025; i++ { palBase = ts.getPointer(palBase) } - ts.PlainEncode(os.Stdout) + //ts.PlainEncode(os.Stdout) + if err := plain.TestRoundTrip(ts, false); err != nil { + t.Error(err) + } +} + +func TestTypeSetFunc(t *testing.T) { + ts := New() + params := types.NewTuple( + types.NewVar(token.NoPos, nil, "p1", types.Typ[types.Int64]), + types.NewVar(token.NoPos, nil, "p2", types.Typ[types.Int64]), + types.NewVar(token.NoPos, nil, "p3", types.Typ[types.Float64])) + results := types.NewTuple( + types.NewVar(token.NoPos, nil, "r1", types.Typ[types.Int64]), + types.NewVar(token.NoPos, nil, "r2", types.Typ[types.Int64])) + sig := types.NewSignature(nil, params, results, false) + _ = ts.FromGoType(sig) + //ts.PlainEncode(os.Stdout) + if err := plain.TestRoundTrip(ts, false); err != nil { + t.Error(err) + } } func TestTypeSet(t *testing.T) {