encoding/decode.go
Emery Hemingway e74029f1c2 New Id type
2013-09-13 13:50:02 -04:00

418 lines
7.8 KiB
Go

// Copyright (c) 2013, Emery Hemingway.
// See the LICENSE file for terms and conditions.
package ebml
import (
"bytes"
"errors"
"io"
"reflect"
"runtime"
"strconv"
"sync"
)
// Unmarshal parses the EBML-encoded data and stores the result
// in the value pointed to by v.
func Unmarshal(data []byte, v interface{}) error {
d := decodeState{r: bytes.NewReader(data)}
return d.unmarshal(v)
}
// An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
// (The argument to Unmarshal must be a non-nil pointer.)
type InvalidUnmarshalError struct {
Type reflect.Type
}
func (e *InvalidUnmarshalError) Error() string {
if e.Type == nil {
return "ebml: Unmarshal(nil)"
}
if e.Type.Kind() != reflect.Ptr {
return "ebml: Unmarshal(non-pointer " + e.Type.String() + ")"
}
return "ebml: Unmarshal(nil " + e.Type.String() + ")"
}
func (d *decodeState) unmarshal(v interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
if _, ok := r.(runtime.Error); ok {
panic(r)
}
err = r.(error)
}
}()
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return &InvalidUnmarshalError{reflect.TypeOf(v)}
}
for {
id := d.readId()
if id == 0 {
break
}
size := d.readSize()
rv = reflect.Indirect(rv)
if rv.Type().NumMethod() > 0 {
if w, ok := rv.Interface().(io.Writer); ok {
_, err = io.CopyN(w, d.r, size)
if err != nil {
d.error(err)
}
}
} else {
d.readToStruct(id, size, rv)
}
}
return d.savedError
}
type decodeState struct {
r io.ReadSeeker
savedError error
}
// error aborts the decoding by panicking with err.
func (d *decodeState) error(err error) {
panic(err)
}
// saveError saves the first err it is called with,
// for reporting at the end of the unmarshal.
func (d *decodeState) saveError(err error) {
if d.savedError == nil {
d.savedError = err
}
}
func (d *decodeState) readId() uint32 {
b := make([]byte, 4)
_, err := d.r.Read(b[:1])
if err != nil {
if err == io.EOF {
return 0
} else {
d.error(err)
}
}
x := uint32(b[0])
switch {
case x >= 0x80:
return x
case x >= 0x40:
b = b[1:2]
case x >= 0x20:
b = b[1:3]
case x >= 0x10:
b = b[1:4]
default:
d.error(errors.New("positioned at an invalid Id or EBMLMaxIDLength > 4"))
return 0
}
_, err = d.r.Read(b)
if err != nil {
d.error(err)
}
for _, c := range b {
x <<= 8
x += uint32(c)
}
return x
}
func (d *decodeState) readSize() int64 {
b := make([]byte, 8)
_, err := d.r.Read(b[:1])
if err != nil {
d.error(err)
}
x := int64(b[0])
switch {
case x >= 0x80:
x -= 0x80
return x
case x >= 0x40:
x -= 0x40
b = b[1:2]
case x >= 0x20:
x -= 0x20
b = b[1:3]
case x >= 0x10:
x -= 0x10
b = b[1:4]
case x >= 0x08:
x -= 0x08
b = b[1:5]
case x >= 0x04:
x -= 0x04
b = b[1:6]
case x >= 0x02:
x -= 0x02
b = b[1:7]
case x >= 0x01:
x -= 0x01
b = b[1:]
}
_, err = d.r.Read(b)
if err != nil {
d.error(err)
}
for _, c := range b {
x <<= 8
x += int64(c)
}
return x
}
func (d *decodeState) readValue(size int64, v reflect.Value) {
if !v.IsValid() {
_, err := d.r.Seek(size, 1)
if err != nil {
d.error(err)
}
return
}
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
b := make([]byte, size)
_, err := d.r.Read(b)
if err != nil {
d.error(err)
}
v.SetInt(unmarshalInt(b))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
b := make([]byte, size)
_, err := d.r.Read(b)
if err != nil {
d.error(err)
}
v.SetUint(unmarshalUint(b))
case reflect.Struct:
if size == 0 {
return
}
d.readStruct(size, v)
case reflect.Interface:
if m, ok := v.Interface().(Unmarshaler); ok {
w := m.UnmarshalEBML(size)
_, err := io.CopyN(w, d.r, size)
if err != nil {
d.error(err)
}
} else {
d.error(&UnsupportedTypeError{v.Type()})
}
case reflect.Slice:
d.readToSlice(size, v)
case reflect.String:
b := make([]byte, size)
d.r.Read(b)
v.SetString(string(b))
case reflect.Ptr:
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
d.readValue(size, v.Elem())
default:
d.error(&UnsupportedTypeError{v.Type()})
}
}
func (d *decodeState) readStruct(size int64, v reflect.Value) {
t := v.Type()
fieldMap := cachedIdFields(t)
pos, err := d.r.Seek(0, 1)
if err != nil {
d.error(err)
}
end := pos + size
for pos < end {
subId := d.readId()
if subId == 0 {
return
}
subSize := d.readSize()
if subSize > 0 {
fi, ok := fieldMap[subId]
if !ok {
pos, err = d.r.Seek(subSize, 1)
if err != nil {
d.error(err)
}
continue
}
subv := v
if subv.Kind() == reflect.Ptr {
if subv.IsNil() {
subv.Set(reflect.New(subv.Type().Elem()))
}
subv = subv.Elem()
}
subv = subv.Field(fi)
switch subv.Kind() {
case reflect.Struct:
d.readStruct(subSize, subv)
case reflect.Slice:
d.readToSlice(subSize, subv)
default:
d.readValue(subSize, subv)
}
}
pos, err = d.r.Seek(0, 1)
if err != nil {
d.error(err)
}
}
}
// readToSlice decodes an EBML element from d.r and appends it to the
// slice at v.
func (d *decodeState) readToSlice(size int64, v reflect.Value) {
//var subv reflect.Value
i := v.Len()
if i >= v.Cap() {
newcap := v.Cap() + v.Cap()/2
if newcap < 4 {
newcap = 4
}
newv := reflect.MakeSlice(v.Type(), v.Len(), newcap)
reflect.Copy(newv, v)
v.Set(newv)
}
v.SetLen(i + 1)
d.readValue(size, v.Index(i))
}
// GET RID OF THIS
// readToStruct decodes an EBML element from d.r into the struct at v,
// if that element is tagged to a field of v.
func (d *decodeState) readToStruct(id uint32, size int64, v reflect.Value) {
// Calling this each time may be inefficient
fieldMap := cachedIdFields(v.Type())
fi, ok := fieldMap[id]
if !ok {
_, err := d.r.Seek(size, 1)
if err != nil {
d.error(err)
}
return
}
subv := v
if subv.Kind() == reflect.Ptr {
if subv.IsNil() {
subv.Set(reflect.New(subv.Type().Elem()))
}
subv = subv.Elem()
}
subv = subv.Field(fi)
d.readValue(size, subv)
}
// indirect walks down v allocating pointers as needed,
// until it gets to a non-pointer.
func indirect(v reflect.Value) reflect.Value {
// If v is a named type and is addressable,
// start with its addres, so that if the type has pointer methods,
// we find them.
if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
v = v.Addr()
}
for {
// Load value from interface, but only if the result will be
// usefully addressable
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Ptr && !e.IsNil() && e.Elem().Kind() == reflect.Ptr {
v = e
continue
}
}
if v.Kind() != reflect.Ptr {
break
}
if v.Elem().Kind() != reflect.Ptr && v.CanSet() {
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
v = v.Elem()
}
return v
}
func unmarshalInt(b []byte) int64 {
x := int64(b[0])
for _, c := range b[1:] {
x <<= 8
x += int64(c)
}
return x
}
func unmarshalUint(b []byte) uint64 {
x := uint64(b[0])
for _, c := range b[1:] {
x <<= 8
x += uint64(c)
}
return x
}
func parseFieldIds(t reflect.Type) (m map[uint32]int) {
m = make(map[uint32]int)
for i := 0; i < t.NumField(); i++ {
sf := t.Field(i)
tag := sf.Tag.Get("ebml")
if tag == "" {
continue
}
id, err := strconv.ParseUint(tag, 16, 32)
if err != nil {
panic(err.Error())
}
m[uint32(id)] = i
}
return
}
var fieldIdCache struct {
sync.RWMutex
m map[reflect.Type]map[uint32]int
}
func cachedIdFields(t reflect.Type) map[uint32]int {
fieldIdCache.RLock()
m, ok := fieldIdCache.m[t]
fieldIdCache.RUnlock()
if !ok {
if fieldIdCache.m == nil {
fieldIdCache.m = map[reflect.Type]map[uint32]int{}
}
m = parseFieldIds(t)
fieldIdCache.Lock()
fieldIdCache.m[t] = m
fieldIdCache.Unlock()
}
return m
}