Marshaler and Unmarshaler interfaces for raw bencode handling

Decode can make short error-less reads without blocking.
Null strings parsed properly.
Encode can handle nil values.
This commit is contained in:
Emery Hemingway 2014-02-08 19:20:26 -05:00
parent aa92f3ed80
commit aeb108691f
6 changed files with 373 additions and 204 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes"
"encoding"
"fmt"
"io"
"net"
"reflect"
"testing"
@ -31,13 +32,15 @@ var tests = []test{
{in: `i2e`, ptr: new(interface{}), out: int64(2)},
{in: "i0e", ptr: new(interface{}), out: int64(0)},
{in: "i0e", ptr: new(int), out: 0},
{in: "0:", ptr: new(string), out: ""},
{in: "1:a", ptr: new(string), out: "a"},
{in: "2:a\"", ptr: new(string), out: "a\""},
{in: "3:abc", ptr: new([]byte), out: []byte("abc")},
{in: "11:0123456789a", ptr: new(interface{}), out: []byte("0123456789a")},
{in: "le", ptr: new([]int64), out: []int64{}},
{in: "li1ei2ee", ptr: new([]int), out: []int{1, 2}},
{in: "l3:abc3:defe", ptr: new([]string), out: []string{"abc", "def"}},
{in: "l3:abc3:def0:e", ptr: new([]string), out: []string{"abc", "def", ""}},
//{in: "li42e3:abce", ptr: new([]interface{}), out: []interface{}{42, []byte("abc")}},
{in: "de", ptr: new(map[string]interface{}), out: make(map[string]interface{})},
{in: "d3:cati1e3:dogi2ee", ptr: new(map[string]int), out: map[string]int{"cat": 1, "dog": 2}},
@ -45,25 +48,6 @@ var tests = []test{
{in: "d1:i3:1231:m9:Arith.Adde", ptr: new(request), out: request{"Arith.Add", nil, "123"}},
}
type request struct {
Method string `bencode:"m"`
Params interface{} `bencode:"p,omitempty"`
Id string `bencode:"i"`
}
func TestRequest(t *testing.T) {
var buf bytes.Buffer
var req request
dec := NewDecoder(&buf)
for i := 0; i < 8; i++ {
fmt.Fprintf(&buf, "d1:ii%de1:m9:Arith.Add1:pd1:Ai%de1:Bi%deee",
i, i, i+1)
if err := dec.Decode(&req); err != nil {
t.Fatal(err)
}
}
}
var afs = []byte("d18:availableFunctionsd18:AdminLog_subscribed4:filed8:requiredi0e4:type6:Stringe5:leveld8:requiredi0e4:type6:Stringe4:lined8:requiredi0e4:type3:Intee20:AdminLog_unsubscribed8:streamIdd8:requiredi1e4:type6:Stringee18:Admin_asyncEnabledde24:Admin_availableFunctionsd4:paged8:requiredi0e4:type3:Intee34:InterfaceController_disconnectPeerd6:pubkeyd8:requiredi1e4:type6:Stringee29:InterfaceController_peerStatsd4:paged8:requiredi0e4:type3:Intee17:SwitchPinger_pingd4:datad8:requiredi0e4:type6:Stringe4:pathd8:requiredi1e4:type6:Stringe7:timeoutd8:requiredi0e4:type3:Intee16:UDPInterface_newd11:bindAddressd8:requiredi0e4:type6:Stringeee4:morei1e4:txid8:c37b0faae")
type outer struct {
@ -71,17 +55,6 @@ type outer struct {
Txid string `bencode:"txid"`
}
func TestSkip(t *testing.T) {
o := new(outer)
err := Unmarshal(afs, o)
if err != nil {
t.Fatal("error unmarshaling nested struct,", err)
}
if o.Txid != "c37b0faa" {
t.Errorf("got txid %q", o.Txid)
}
}
func TestMarshal(t *testing.T) {
buf := new(bytes.Buffer)
enc := NewEncoder(buf)
@ -132,16 +105,44 @@ func TestUnmarshal(t *testing.T) {
}
}
func TestSkip(t *testing.T) {
o := new(outer)
err := Unmarshal(afs, o)
if err != nil {
t.Fatal("error unmarshaling nested struct,", err)
}
if o.Txid != "c37b0faa" {
t.Errorf("got txid %q", o.Txid)
}
}
type request struct {
Method string `bencode:"m"`
Params interface{} `bencode:"p,omitempty"`
Id string `bencode:"i"`
}
func TestRequest(t *testing.T) {
var buf bytes.Buffer
var req request
dec := NewDecoder(&buf)
for i := 0; i < 8; i++ {
fmt.Fprintf(&buf, "d1:ii%de1:m9:Arith.Add1:pd1:Ai%de1:Bi%deee",
i, i, i+1)
if err := dec.Decode(&req); err != nil {
t.Fatal(err)
}
}
}
type nestA struct {
A int
B int
C *nestB
A, B int
C *nestB
}
type nestB struct {
D int
E int
F *nestA `bencode:",omitempty"`
D, E int
F *nestA `bencode:",omitempty"`
}
func TestNesting(t *testing.T) {
@ -233,6 +234,95 @@ func TestTextInterface(t *testing.T) {
}
}
type cooked struct {
A, B, C int
}
type raw struct {
A, B int
C *RawMessage
}
func TestRawMessage(t *testing.T) {
var buf bytes.Buffer
enc := NewEncoder(&buf)
dec := NewDecoder(&buf)
c := &cooked{1, 2, 3}
err := enc.Encode(c)
if err != nil {
t.Error("encoding cooked:", err)
}
r := new(raw)
err = dec.Decode(r)
if err != nil {
t.Error("decoding to raw:", err)
}
if len(*r.C) == 0 {
t.Fatal("nested RawMessage had zero length")
}
var x int
err = Unmarshal(*r.C, &x)
if err != nil {
t.Error("decoding RawMessage:", err)
}
if x != 3 {
t.Errorf("RawMessage: want %d, got %d", 3, x)
}
err = enc.Encode(r)
if err != nil {
t.Error("encoding RawMessage:", err)
}
err = dec.Decode(c)
if err != nil {
t.Error("decoding encoding RawMessage:", err)
}
if c.C != 3 {
t.Error("mismatch")
}
}
func TestPipe(t *testing.T) {
r, w := io.Pipe()
dec := NewDecoder(r)
enc := NewEncoder(w)
for i, tt := range tests {
if tt.ptr == nil {
continue
}
v := reflect.New(reflect.TypeOf(tt.ptr).Elem())
go enc.Encode(tt.out)
if err := dec.Decode(v.Interface()); err != nil {
t.Errorf("#%d: %q %v want %v", i, tt.in, err, tt.err)
}
if !reflect.DeepEqual(v.Elem().Interface(), tt.out) {
t.Errorf("#%d: %s mismatch\nhave: %#+v\nwant: %#+v", i, tt.in, v.Elem().Interface(), tt.out)
}
}
}
type serverResponse struct {
Error interface{} `bencode:"e"` //,omitempty"`
Id *RawMessage `bencode:"i"`
Result interface{} `bencode:"r"`
}
func TestEncodeNil(t *testing.T) {
var resp serverResponse
resp.Error = nil
resp.Result = nil
_, err := Marshal(resp)
if err != nil {
t.Error("EncodeNil:", err)
}
}
type benchmarkStruct struct {
Q string `bencode:"q"`
AQ string `bencode:"aq,omitempty"`

168
decode.go
View File

@ -56,31 +56,31 @@ func (dec *Decoder) Buffered() io.Reader {
func (dec *Decoder) readValue() (int, error) {
dec.scan.reset()
scanp := 0
var scanp, op, n int
var err error
Input:
for {
// Look in the buffer for a new value.
buf := dec.buf[scanp:]
for i := 0; i < len(buf); i++ {
op := dec.scan.step(&dec.scan, int(buf[i]))
if op > 0 {
for scanp < len(dec.buf) {
op = dec.scan.step(&dec.scan, int(dec.buf[scanp]))
scanp++
if op >= 0 {
dec.scan.bytes += int64(op)
i += op
continue
}
dec.scan.bytes++
scanp += op
if dec.scan.endTop {
break Input
}
} else {
dec.scan.bytes++
switch op {
case scanEnd:
break Input
if op == scanEnd {
scanp += i
break Input
}
if op == scanError {
dec.err = dec.scan.err
return 0, dec.scan.err
case scanError:
dec.err = dec.scan.err
return 0, dec.scan.err
}
}
}
scanp = len(dec.buf)
// Did the last read have an error?
// Delayed until now to allow buffer scan.
@ -102,7 +102,6 @@ Input:
}
// Read. Delay error for the next interation (after scan).
var n int
n, err = dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
dec.buf = dec.buf[0 : len(dec.buf)+n]
}
@ -127,7 +126,6 @@ func (e *InvalidUnmarshalError) Error() string {
}
func (d *decodeState) unmarshal(v interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
@ -184,7 +182,7 @@ func (d *decodeState) error(err error) {
panic(err)
}
// skip reads d.data until it hits the given op code
// skip reads d.data with a fresh scanner, skimming over the next value
func (d *decodeState) skip() {
var skipScan scanner
skipScan.reset()
@ -193,33 +191,41 @@ func (d *decodeState) skip() {
var op int
for {
if d.off > len(d.data) {
d.error(errors.New("reached end of data"))
}
op = skipScan.step(&skipScan, int(d.data[d.off]))
if op > 0 {
d.off += op
}
switch op {
case scanEnd:
return
case scanError:
d.error(skipScan.err)
}
d.off++
if op >= 0 {
d.off += op
if skipScan.endTop {
return
}
} else {
switch op {
case scanEnd:
return
case scanError:
d.error(skipScan.err)
}
}
}
}
// readValue decodes the next item from d.data[d.off:] into v,
// updating d.off.
var (
unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem()
textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
)
// value decodes the next item from d.data[d.off:] into v, updating d.off.
func (d *decodeState) value(v reflect.Value) {
if !v.IsValid() {
d.skip()
return
}
if v.Type().Implements(unmarshalerType) {
d.unmarshaler(v)
return
}
op := d.scan.step(&d.scan, int(d.data[d.off]))
d.off++
@ -288,14 +294,14 @@ Read:
c = int(d.data[d.off])
d.off++
switch d.scan.step(&d.scan, c) {
case scanParseInteger:
continue
case scanEndInteger, scanEnd:
break Read
case scanError:
d.error(d.scan.err)
default:
d.error(errPhase)
case scanParseInteger:
continue
case scanEndInteger:
break Read
}
}
return d.data[i : d.off-1]
@ -374,7 +380,11 @@ Read:
c = int(d.data[d.off])
d.off++
op = d.scan.step(&d.scan, c)
if op < 0 {
if op >= 0 {
i = d.off
d.off += op
break Read
} else {
switch op {
case scanParseStringLen, scanParseString:
continue
@ -383,17 +393,11 @@ Read:
default:
d.error(errPhase)
}
} else { // op was a string length
i = d.off
d.off += op
break Read
}
}
return d.data[i:d.off]
}
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
// string consumes a string from d.data[d.off:], decoding into the value v.
func (d *decodeState) string(v reflect.Value) {
for {
@ -473,7 +477,7 @@ Read:
op = d.scan.step(&d.scan, c)
switch op {
case scanEndList:
case scanEndList, scanEnd:
break Read
}
@ -626,18 +630,20 @@ Read:
c = int(d.data[d.off])
d.off++
op = d.scan.step(&d.scan, c)
if op < 0 {
if op > 0 {
p = d.off
d.off += op
break ReadKey
} else {
switch op {
case scanEndDict:
case scanEndDict, scanEnd:
break Read
case scanBeginKeyLen, scanParseKeyLen, scanParseKey:
case scanEndKeyLen:
p = d.off
default:
d.error(errPhase)
}
} else { //op was the key string length
p = d.off
d.off += op
break ReadKey
}
}
key = string(d.data[p:d.off])
@ -704,18 +710,20 @@ Read:
c = int(d.data[d.off])
d.off++
op = d.scan.step(&d.scan, c)
if op < 0 {
if op > 0 {
p = d.off
d.off += op
break ReadKey
} else {
switch op {
case scanEndDict:
break Read
case scanBeginKeyLen, scanParseKeyLen, scanParseKey:
case scanEndKeyLen:
p = d.off
default:
d.error(errPhase)
}
} else { // op was a string length
p = d.off
d.off += op
break ReadKey
}
}
key = string(d.data[p:d.off])
@ -724,6 +732,46 @@ Read:
return m
}
// unmarshaler reads raw bencode into an Unmarshaler.
func (d *decodeState) unmarshaler(v reflect.Value) {
if v.Kind() == reflect.Ptr && v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
u := v.Interface().(Unmarshaler)
var tmpScan scanner
tmpScan.reset()
tmpScan.step = d.scan.step
d.scan.step = stateEndValue
start := d.off
var op int
ReadRaw:
for {
if d.off > len(d.data) {
d.error(errors.New("readed end of data"))
}
op = tmpScan.step(&tmpScan, int(d.data[d.off]))
if op > 0 {
d.off += op + 1
} else {
d.off++
switch op {
case scanEnd:
break ReadRaw
case scanError:
d.error(tmpScan.err)
}
}
}
if err := u.UnmarshalBencode(d.data[start:d.off]); err != nil {
d.error(err)
}
}
// Unmarshal parses the bencode-encoded data and stores the
// result in the value pointed to by v.
//

View File

@ -191,13 +191,34 @@ func isEmptyValue(v reflect.Value) bool {
return false
}
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
var (
marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
)
// reflectValue writes the value in v to the output.
func (e *encodeState) reflectValue(v reflect.Value) {
if !v.IsValid() {
e.Write([]byte{'0', ':'})
return
}
if v.Type().Implements(marshalerType) {
m := v.Interface().(Marshaler)
b, err := m.MarshalBencode()
if err == nil {
_, err = e.Write(b)
}
if err != nil {
e.error(err)
}
return
}
if v.Type().Implements(textMarshalerType) {
u := v.Interface().(encoding.TextMarshaler)
b, err := u.MarshalText()
m := v.Interface().(encoding.TextMarshaler)
b, err := m.MarshalText()
if err != nil {
e.error(err)
}
@ -262,7 +283,6 @@ func (e *encodeState) reflectValue(v reflect.Value) {
e.reflectValue(v.Index(i))
}
e.WriteByte('e')
return
case reflect.Interface, reflect.Ptr:
e.reflectValue(v.Elem())

41
interface.go Normal file
View File

@ -0,0 +1,41 @@
package bencode
import (
"errors"
)
// Marshaler is the interface implemented by objects that
// can marshal themselves into valid bencode.
type Marshaler interface {
MarshalBencode() ([]byte, error)
}
// Unmarshaler is the interface implemented by objects that
// can unmarshal themselves from bencode.
type Unmarshaler interface {
UnmarshalBencode([]byte) error
}
// RawMessage is a raw encoded bencode object.
// It is intedended to delay decoding or precomute an encoding.
type RawMessage []byte
// MarshalText returns *m as the bencode encoding of m.
func (m *RawMessage) MarshalBencode() ([]byte, error) {
if m == nil {
return []byte{'0', ':'}, nil
}
return *m, nil
}
// UnmarshalText sets *m to a copy of data.
func (m *RawMessage) UnmarshalBencode(text []byte) error {
if m == nil {
return errors.New("bencode.RawMessage: UnmarshalText on nil pointer")
}
*m = append((*m)[0:0], text...)
return nil
}
var _ Marshaler = (*RawMessage)(nil)
var _ Unmarshaler = (*RawMessage)(nil)

27
raw.go
View File

@ -1,27 +0,0 @@
package bencode
import (
"encoding"
"errors"
)
// RawMessage is a raw encoded bencode object.
// It is intedended to delay decoding or precomute an encoding.
type RawMessage []byte
// MarshalText returns *m as the bencode encoding of m.
func (m *RawMessage) MarshalText() ([]byte, error) {
return *m, nil
}
// UnmarshalText sets *m to a copy of data.
func (m *RawMessage) UnmarshalText(text []byte) error {
if m == nil {
return errors.New("bencode.RawMessage: UnmarshalText on nil pointer")
}
*m = append((*m)[0:0], text...)
return nil
}
var _ encoding.TextMarshaler = (*RawMessage)(nil)
var _ encoding.TextUnmarshaler = (*RawMessage)(nil)

View File

@ -88,13 +88,13 @@ type scanner struct {
// every subsequent call will retern scanError too.
const (
// Continue.
scanBeginInteger = 0 - iota
scanParseInteger = 0 - iota
scanEndInteger = 0 - iota
scanBeginStringLen = 0 - iota
scanParseStringLen = 0 - iota
scanEndStringLen = 0 - iota
scanParseString = 0 - iota
scanBeginInteger = 0 - iota
scanParseInteger = 0 - iota
scanEndInteger = 0 - iota
scanBeginList = 0 - iota
scanEndList = 0 - iota
scanEndValue = 0 - iota
@ -116,7 +116,7 @@ const (
// These values are stored in the parseState stack.
// They give the current state of a composite value
// being scanned. If the parser is inside a nested value
// the parseSTate describes the nested state, outermost at entry 0.
// the parseState describes the nested state, outermost at entry 0.
const (
parseInteger = iota // parsing an integer
parseString // parsing a string
@ -176,11 +176,9 @@ func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
s.step = stateEndValue
}
// stateBeginValue is the state at the beginning of the input.
@ -191,12 +189,12 @@ func stateBeginValue(s *scanner, c int) int {
s.pushParseState(parseInteger)
return scanBeginInteger
case 'l':
s.step = stateBeginList
s.step = stateBeginListValue
s.pushParseState(parseListValue)
return scanBeginList
case 'd':
s.step = stateBeginDictKey
s.pushParseState(parseDictKey)
s.pushParseState(parseDictValue)
return scanBeginDict
}
@ -209,50 +207,13 @@ func stateBeginValue(s *scanner, c int) int {
return s.error(c, "looking for beginning of value")
}
// stateEndValue is the state after completing a value,
// such as after reading 'e' or finishing a string.
func stateEndValue(s *scanner, c int) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
ps := s.parseState[n-1]
switch ps {
case parseDictKey:
s.parseState[n-1] = parseDictValue
s.step = stateBeginValue
return scanDictValue
case parseDictValue:
s.popParseState()
s.step = stateBeginDictKey
return stateBeginDictKey(s, c)
case parseListValue:
if c == 'e' {
s.popParseState()
return scanEndList
}
s.step = stateBeginValue
//s.pushParseState(parseListValue)
return stateBeginValue(s, c)
}
return s.error(c, "")
}
// stateEndTop is the state after finishing the top-level value,
// such as after finishing a dictionary or list.
func stateEndTop(s *scanner, c int) int {
return scanEnd
}
// stateParseInteger is the state after reading an `i`.
func stateParseInteger(s *scanner, c int) int {
if c == 'e' {
s.popParseState()
if s.endTop {
return scanEnd
}
return scanEndInteger
}
if (c >= '0' && c <= '9') || c == '-' {
@ -263,15 +224,18 @@ func stateParseInteger(s *scanner, c int) int {
func stateParseStringLen(s *scanner, c int) int {
if c == ':' {
var err error
var l int
if l, err = strconv.Atoi(string(s.strLenB)); err != nil {
l, err := strconv.Atoi(string(s.strLenB))
if err != nil {
s.err = err
return scanError
}
// decoder should read this string as a slice
s.popParseState()
s.step = stateEndValue
// BUG(emery): undefined behavior with top level strings
// this is a problem, if this string is a top-level object,
// the fact that this scanner has reached the end isn't communicated.
// I guess I could shift the string length and set scanEnd bits
return l
}
if c >= '0' && c <= '9' {
@ -281,48 +245,81 @@ func stateParseStringLen(s *scanner, c int) int {
return s.error(c, "in string length")
}
func stateBeginList(s *scanner, c int) int {
func stateBeginListValue(s *scanner, c int) int {
if c == 'e' {
s.popParseState()
if s.endTop {
return scanEnd
}
return scanEndList
}
return stateBeginValue(s, c)
}
func stateBeginDictKey(s *scanner, c int) int {
if c == 'e' {
s.popParseState() // pop parseDictValue
if s.endTop {
return scanEnd
}
return scanEndDict
}
if c >= '0' && c <= '9' {
s.strLenB = append(s.strLenB[0:0], byte(c))
s.step = stateParseKeyLen
return scanBeginKeyLen
}
return s.error(c, "in start of dictionary key length")
}
func stateParseKeyLen(s *scanner, c int) int {
if c == ':' {
var err error
var l int
if l, err = strconv.Atoi(string(s.strLenB)); err != nil {
l, err := strconv.Atoi(string(s.strLenB))
if err != nil {
s.err = err
return scanError
}
// decoder should read this chunk at once
s.popParseState()
s.step = stateBeginValue
s.pushParseState(parseDictValue)
return l
}
if c >= '0' && c <= '9' {
s.strLenB = append(s.strLenB, byte(c))
return scanParseKeyLen
}
return s.error(c, "in dicionary key length")
}
func stateBeginDictKey(s *scanner, c int) int {
if c == 'e' {
if len(s.parseState) == 0 {
return scanEnd
}
s.popParseState()
return scanEndDict
}
if c >= '0' && c <= '9' {
s.strLenB = append(s.strLenB[0:0], byte(c))
s.step = stateParseKeyLen
s.pushParseState(parseDictKey)
return scanBeginKeyLen
}
return s.error(c, "in dictionary key length")
}
// stateEndValue is the state after completing a value,
// such as after reading 'e' or finishing a string.
func stateEndValue(s *scanner, c int) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateBeginValue
s.endTop = true
return scanEnd
}
ps := s.parseState[n-1]
switch ps {
case parseDictKey:
s.step = stateBeginValue
return scanDictValue
case parseDictValue:
s.step = stateBeginDictKey
return stateBeginDictKey(s, c)
case parseListValue:
if c == 'e' {
s.popParseState()
if s.endTop {
return scanEnd
}
return scanEndList
}
s.step = stateBeginValue
return stateBeginValue(s, c)
}
return s.error(c, "")
}