|
- package codec
-
- import (
- "errors"
- "fmt"
- "io"
- "math"
-
- "github.com/golang/protobuf/proto"
- "github.com/golang/protobuf/protoc-gen-go/descriptor"
-
- "github.com/jhump/protoreflect/desc"
- )
-
// ErrWireTypeEndGroup is returned from DecodeFieldValue if the tag and wire
// type it reads indicate an end-group marker. In that case DecodeFieldValue
// also returns the tag number found in the marker as its value.
var ErrWireTypeEndGroup = errors.New("unexpected wire type: end group")
-
- // MessageFactory is used to instantiate messages when DecodeFieldValue needs to
- // decode a message value.
- //
- // Also see MessageFactory in "github.com/jhump/protoreflect/dynamic", which
- // implements this interface.
- type MessageFactory interface {
- NewMessage(md *desc.MessageDescriptor) proto.Message
- }
-
// UnknownField represents a field that was parsed from the binary wire
// format for a message, but whose field number was not recognized. Enough
// information is preserved so that re-serializing the message won't lose
// any of the unrecognized data.
type UnknownField struct {
	// Tag is the field (tag) number of the unrecognized field.
	Tag int32

	// Encoding is the wire type with which the unknown field was encoded.
	// If it is proto.WireBytes or proto.WireStartGroup, Contents holds the
	// raw bytes. If it is proto.WireFixed32, the data is in the least
	// significant 32 bits of Value. Otherwise (proto.WireVarint or
	// proto.WireFixed64), the data is in all 64 bits of Value.
	// DecodeFieldValue leaves whichever of Contents/Value is unused at its
	// zero value.
	Encoding int8
	Contents []byte
	Value    uint64
}
-
- // DecodeFieldValue will read a field value from the buffer and return its
- // value and the corresponding field descriptor. The given function is used
- // to lookup a field descriptor by tag number. The given factory is used to
- // instantiate a message if the field value is (or contains) a message value.
- //
- // On error, the field descriptor and value are typically nil. However, if the
- // error returned is ErrWireTypeEndGroup, the returned value will indicate any
- // tag number encoded in the end-group marker.
- //
- // If the field descriptor returned is nil, that means that the given function
- // returned nil. This is expected to happen for unrecognized tag numbers. In
- // that case, no error is returned, and the value will be an UnknownField.
- func (cb *Buffer) DecodeFieldValue(fieldFinder func(int32) *desc.FieldDescriptor, fact MessageFactory) (*desc.FieldDescriptor, interface{}, error) {
- if cb.EOF() {
- return nil, nil, io.EOF
- }
- tagNumber, wireType, err := cb.DecodeTagAndWireType()
- if err != nil {
- return nil, nil, err
- }
- if wireType == proto.WireEndGroup {
- return nil, tagNumber, ErrWireTypeEndGroup
- }
- fd := fieldFinder(tagNumber)
- if fd == nil {
- val, err := cb.decodeUnknownField(tagNumber, wireType)
- return nil, val, err
- }
- val, err := cb.decodeKnownField(fd, wireType, fact)
- return fd, val, err
- }
-
- // DecodeScalarField extracts a properly-typed value from v. The returned value's
- // type depends on the given field descriptor type. It will be the same type as
- // generated structs use for the field descriptor's type. Enum types will return
- // an int32. If the given field type uses length-delimited encoding (nested
- // messages, bytes, and strings), an error is returned.
- func DecodeScalarField(fd *desc.FieldDescriptor, v uint64) (interface{}, error) {
- switch fd.GetType() {
- case descriptor.FieldDescriptorProto_TYPE_BOOL:
- return v != 0, nil
- case descriptor.FieldDescriptorProto_TYPE_UINT32,
- descriptor.FieldDescriptorProto_TYPE_FIXED32:
- if v > math.MaxUint32 {
- return nil, ErrOverflow
- }
- return uint32(v), nil
-
- case descriptor.FieldDescriptorProto_TYPE_INT32,
- descriptor.FieldDescriptorProto_TYPE_ENUM:
- s := int64(v)
- if s > math.MaxInt32 || s < math.MinInt32 {
- return nil, ErrOverflow
- }
- return int32(s), nil
-
- case descriptor.FieldDescriptorProto_TYPE_SFIXED32:
- if v > math.MaxUint32 {
- return nil, ErrOverflow
- }
- return int32(v), nil
-
- case descriptor.FieldDescriptorProto_TYPE_SINT32:
- if v > math.MaxUint32 {
- return nil, ErrOverflow
- }
- return DecodeZigZag32(v), nil
-
- case descriptor.FieldDescriptorProto_TYPE_UINT64,
- descriptor.FieldDescriptorProto_TYPE_FIXED64:
- return v, nil
-
- case descriptor.FieldDescriptorProto_TYPE_INT64,
- descriptor.FieldDescriptorProto_TYPE_SFIXED64:
- return int64(v), nil
-
- case descriptor.FieldDescriptorProto_TYPE_SINT64:
- return DecodeZigZag64(v), nil
-
- case descriptor.FieldDescriptorProto_TYPE_FLOAT:
- if v > math.MaxUint32 {
- return nil, ErrOverflow
- }
- return math.Float32frombits(uint32(v)), nil
-
- case descriptor.FieldDescriptorProto_TYPE_DOUBLE:
- return math.Float64frombits(v), nil
-
- default:
- // bytes, string, message, and group cannot be represented as a simple numeric value
- return nil, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName())
- }
- }
-
- // DecodeLengthDelimitedField extracts a properly-typed value from bytes. The
- // returned value's type will usually be []byte, string, or, for nested messages,
- // the type returned from the given message factory. However, since repeated
- // scalar fields can be length-delimited, when they used packed encoding, it can
- // also return an []interface{}, where each element is a scalar value. Furthermore,
- // it could return a scalar type, not in a slice, if the given field descriptor is
- // not repeated. This is to support cases where a field is changed from optional
- // to repeated. New code may emit a packed repeated representation, but old code
- // still expects a single scalar value. In this case, if the actual data in bytes
- // contains multiple values, only the last value is returned.
- func DecodeLengthDelimitedField(fd *desc.FieldDescriptor, bytes []byte, mf MessageFactory) (interface{}, error) {
- switch {
- case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES:
- return bytes, nil
-
- case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_STRING:
- return string(bytes), nil
-
- case fd.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE ||
- fd.GetType() == descriptor.FieldDescriptorProto_TYPE_GROUP:
- msg := mf.NewMessage(fd.GetMessageType())
- err := proto.Unmarshal(bytes, msg)
- if err != nil {
- return nil, err
- } else {
- return msg, nil
- }
-
- default:
- // even if the field is not repeated or not packed, we still parse it as such for
- // backwards compatibility (e.g. message we are de-serializing could have been both
- // repeated and packed at the time of serialization)
- packedBuf := NewBuffer(bytes)
- var slice []interface{}
- var val interface{}
- for !packedBuf.EOF() {
- var v uint64
- var err error
- if varintTypes[fd.GetType()] {
- v, err = packedBuf.DecodeVarint()
- } else if fixed32Types[fd.GetType()] {
- v, err = packedBuf.DecodeFixed32()
- } else if fixed64Types[fd.GetType()] {
- v, err = packedBuf.DecodeFixed64()
- } else {
- return nil, fmt.Errorf("bad input; cannot parse length-delimited wire type for field %s", fd.GetFullyQualifiedName())
- }
- if err != nil {
- return nil, err
- }
- val, err = DecodeScalarField(fd, v)
- if err != nil {
- return nil, err
- }
- if fd.IsRepeated() {
- slice = append(slice, val)
- }
- }
- if fd.IsRepeated() {
- return slice, nil
- } else {
- // if not a repeated field, last value wins
- return val, nil
- }
- }
- }
-
- func (b *Buffer) decodeKnownField(fd *desc.FieldDescriptor, encoding int8, fact MessageFactory) (interface{}, error) {
- var val interface{}
- var err error
- switch encoding {
- case proto.WireFixed32:
- var num uint64
- num, err = b.DecodeFixed32()
- if err == nil {
- val, err = DecodeScalarField(fd, num)
- }
- case proto.WireFixed64:
- var num uint64
- num, err = b.DecodeFixed64()
- if err == nil {
- val, err = DecodeScalarField(fd, num)
- }
- case proto.WireVarint:
- var num uint64
- num, err = b.DecodeVarint()
- if err == nil {
- val, err = DecodeScalarField(fd, num)
- }
-
- case proto.WireBytes:
- alloc := fd.GetType() == descriptor.FieldDescriptorProto_TYPE_BYTES
- var raw []byte
- raw, err = b.DecodeRawBytes(alloc)
- if err == nil {
- val, err = DecodeLengthDelimitedField(fd, raw, fact)
- }
-
- case proto.WireStartGroup:
- if fd.GetMessageType() == nil {
- return nil, fmt.Errorf("cannot parse field %s from group-encoded wire type", fd.GetFullyQualifiedName())
- }
- msg := fact.NewMessage(fd.GetMessageType())
- var data []byte
- data, err = b.ReadGroup(false)
- if err == nil {
- err = proto.Unmarshal(data, msg)
- if err == nil {
- val = msg
- }
- }
-
- default:
- return nil, ErrBadWireType
- }
- if err != nil {
- return nil, err
- }
-
- return val, nil
- }
-
- func (b *Buffer) decodeUnknownField(tagNumber int32, encoding int8) (interface{}, error) {
- u := UnknownField{Tag: tagNumber, Encoding: encoding}
- var err error
- switch encoding {
- case proto.WireFixed32:
- u.Value, err = b.DecodeFixed32()
- case proto.WireFixed64:
- u.Value, err = b.DecodeFixed64()
- case proto.WireVarint:
- u.Value, err = b.DecodeVarint()
- case proto.WireBytes:
- u.Contents, err = b.DecodeRawBytes(true)
- case proto.WireStartGroup:
- u.Contents, err = b.ReadGroup(true)
- default:
- err = ErrBadWireType
- }
- if err != nil {
- return nil, err
- }
- return u, nil
- }
|