You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
519 lines
10 KiB
519 lines
10 KiB
package protocol
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"hash/crc32"
|
|
"io"
|
|
"io/ioutil"
|
|
"reflect"
|
|
"sync"
|
|
"sync/atomic"
|
|
)
|
|
|
|
// discarder is implemented by readers that can skip input without copying it
// (for example *bufio.Reader). Discard skips up to n bytes and returns how
// many were actually skipped.
type discarder interface {
	Discard(n int) (discarded int, err error)
}
|
|
|
|
// decoder reads protocol primitives from a length-delimited frame of an
// underlying reader.
type decoder struct {
	// reader is the source of the frame's bytes.
	reader io.Reader
	// remain is the number of frame bytes not yet consumed; Read never reads
	// past it and reports io.EOF once it reaches zero.
	remain int
	// buffer is scratch space for fixed-size integer reads (up to 8 bytes).
	buffer [8]byte
	// err is the first error recorded by setError; once set, Read returns it.
	err error
	// table, when non-nil, makes Read fold every byte it returns into crc32.
	table *crc32.Table
	// crc32 is the running checksum maintained by Read while table is set.
	crc32 uint32
}
|
|
|
|
func (d *decoder) Reset(r io.Reader, n int) {
|
|
d.reader = r
|
|
d.remain = n
|
|
d.buffer = [8]byte{}
|
|
d.err = nil
|
|
d.table = nil
|
|
d.crc32 = 0
|
|
}
|
|
|
|
func (d *decoder) Read(b []byte) (int, error) {
|
|
if d.err != nil {
|
|
return 0, d.err
|
|
}
|
|
if d.remain == 0 {
|
|
return 0, io.EOF
|
|
}
|
|
if len(b) > d.remain {
|
|
b = b[:d.remain]
|
|
}
|
|
n, err := d.reader.Read(b)
|
|
if n > 0 && d.table != nil {
|
|
d.crc32 = crc32.Update(d.crc32, d.table, b[:n])
|
|
}
|
|
d.remain -= n
|
|
return n, err
|
|
}
|
|
|
|
func (d *decoder) ReadByte() (byte, error) {
|
|
c := d.readByte()
|
|
return c, d.err
|
|
}
|
|
|
|
func (d *decoder) done() bool {
|
|
return d.remain == 0 || d.err != nil
|
|
}
|
|
|
|
func (d *decoder) setCRC(table *crc32.Table) {
|
|
d.table, d.crc32 = table, 0
|
|
}
|
|
|
|
func (d *decoder) decodeBool(v value) {
|
|
v.setBool(d.readBool())
|
|
}
|
|
|
|
func (d *decoder) decodeInt8(v value) {
|
|
v.setInt8(d.readInt8())
|
|
}
|
|
|
|
func (d *decoder) decodeInt16(v value) {
|
|
v.setInt16(d.readInt16())
|
|
}
|
|
|
|
func (d *decoder) decodeInt32(v value) {
|
|
v.setInt32(d.readInt32())
|
|
}
|
|
|
|
func (d *decoder) decodeInt64(v value) {
|
|
v.setInt64(d.readInt64())
|
|
}
|
|
|
|
func (d *decoder) decodeString(v value) {
|
|
v.setString(d.readString())
|
|
}
|
|
|
|
func (d *decoder) decodeCompactString(v value) {
|
|
v.setString(d.readCompactString())
|
|
}
|
|
|
|
func (d *decoder) decodeBytes(v value) {
|
|
v.setBytes(d.readBytes())
|
|
}
|
|
|
|
func (d *decoder) decodeCompactBytes(v value) {
|
|
v.setBytes(d.readCompactBytes())
|
|
}
|
|
|
|
func (d *decoder) decodeArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
|
|
if n := d.readInt32(); n < 0 {
|
|
v.setArray(array{})
|
|
} else {
|
|
a := makeArray(elemType, int(n))
|
|
for i := 0; i < int(n) && d.remain > 0; i++ {
|
|
decodeElem(d, a.index(i))
|
|
}
|
|
v.setArray(a)
|
|
}
|
|
}
|
|
|
|
func (d *decoder) decodeCompactArray(v value, elemType reflect.Type, decodeElem decodeFunc) {
|
|
if n := d.readUnsignedVarInt(); n < 1 {
|
|
v.setArray(array{})
|
|
} else {
|
|
a := makeArray(elemType, int(n-1))
|
|
for i := 0; i < int(n-1) && d.remain > 0; i++ {
|
|
decodeElem(d, a.index(i))
|
|
}
|
|
v.setArray(a)
|
|
}
|
|
}
|
|
|
|
func (d *decoder) discardAll() {
|
|
d.discard(d.remain)
|
|
}
|
|
|
|
func (d *decoder) discard(n int) {
|
|
if n > d.remain {
|
|
n = d.remain
|
|
}
|
|
var err error
|
|
if r, _ := d.reader.(discarder); r != nil {
|
|
n, err = r.Discard(n)
|
|
d.remain -= n
|
|
} else {
|
|
_, err = io.Copy(ioutil.Discard, d)
|
|
}
|
|
d.setError(err)
|
|
}
|
|
|
|
func (d *decoder) read(n int) []byte {
|
|
b := make([]byte, n)
|
|
n, err := io.ReadFull(d, b)
|
|
b = b[:n]
|
|
d.setError(err)
|
|
return b
|
|
}
|
|
|
|
func (d *decoder) writeTo(w io.Writer, n int) {
|
|
limit := d.remain
|
|
if n < limit {
|
|
d.remain = n
|
|
}
|
|
c, err := io.Copy(w, d)
|
|
if int(c) < n && err == nil {
|
|
err = io.ErrUnexpectedEOF
|
|
}
|
|
d.remain = limit - int(c)
|
|
d.setError(err)
|
|
}
|
|
|
|
func (d *decoder) setError(err error) {
|
|
if d.err == nil && err != nil {
|
|
d.err = err
|
|
d.discardAll()
|
|
}
|
|
}
|
|
|
|
func (d *decoder) readFull(b []byte) bool {
|
|
n, err := io.ReadFull(d, b)
|
|
d.setError(err)
|
|
return n == len(b)
|
|
}
|
|
|
|
func (d *decoder) readByte() byte {
|
|
if d.readFull(d.buffer[:1]) {
|
|
return d.buffer[0]
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (d *decoder) readBool() bool {
|
|
return d.readByte() != 0
|
|
}
|
|
|
|
func (d *decoder) readInt8() int8 {
|
|
if d.readFull(d.buffer[:1]) {
|
|
return readInt8(d.buffer[:1])
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (d *decoder) readInt16() int16 {
|
|
if d.readFull(d.buffer[:2]) {
|
|
return readInt16(d.buffer[:2])
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (d *decoder) readInt32() int32 {
|
|
if d.readFull(d.buffer[:4]) {
|
|
return readInt32(d.buffer[:4])
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (d *decoder) readInt64() int64 {
|
|
if d.readFull(d.buffer[:8]) {
|
|
return readInt64(d.buffer[:8])
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (d *decoder) readString() string {
|
|
if n := d.readInt16(); n < 0 {
|
|
return ""
|
|
} else {
|
|
return bytesToString(d.read(int(n)))
|
|
}
|
|
}
|
|
|
|
func (d *decoder) readVarString() string {
|
|
if n := d.readVarInt(); n < 0 {
|
|
return ""
|
|
} else {
|
|
return bytesToString(d.read(int(n)))
|
|
}
|
|
}
|
|
|
|
func (d *decoder) readCompactString() string {
|
|
if n := d.readUnsignedVarInt(); n < 1 {
|
|
return ""
|
|
} else {
|
|
return bytesToString(d.read(int(n - 1)))
|
|
}
|
|
}
|
|
|
|
func (d *decoder) readBytes() []byte {
|
|
if n := d.readInt32(); n < 0 {
|
|
return nil
|
|
} else {
|
|
return d.read(int(n))
|
|
}
|
|
}
|
|
|
|
func (d *decoder) readVarBytes() []byte {
|
|
if n := d.readVarInt(); n < 0 {
|
|
return nil
|
|
} else {
|
|
return d.read(int(n))
|
|
}
|
|
}
|
|
|
|
func (d *decoder) readCompactBytes() []byte {
|
|
if n := d.readUnsignedVarInt(); n < 1 {
|
|
return nil
|
|
} else {
|
|
return d.read(int(n - 1))
|
|
}
|
|
}
|
|
|
|
func (d *decoder) readVarInt() int64 {
|
|
n := 11 // varints are at most 11 bytes
|
|
|
|
if n > d.remain {
|
|
n = d.remain
|
|
}
|
|
|
|
x := uint64(0)
|
|
s := uint(0)
|
|
|
|
for n > 0 {
|
|
b := d.readByte()
|
|
|
|
if (b & 0x80) == 0 {
|
|
x |= uint64(b) << s
|
|
return int64(x>>1) ^ -(int64(x) & 1)
|
|
}
|
|
|
|
x |= uint64(b&0x7f) << s
|
|
s += 7
|
|
n--
|
|
}
|
|
|
|
d.setError(fmt.Errorf("cannot decode varint from input stream"))
|
|
return 0
|
|
}
|
|
|
|
func (d *decoder) readUnsignedVarInt() uint64 {
|
|
n := 11 // varints are at most 11 bytes
|
|
|
|
if n > d.remain {
|
|
n = d.remain
|
|
}
|
|
|
|
x := uint64(0)
|
|
s := uint(0)
|
|
|
|
for n > 0 {
|
|
b := d.readByte()
|
|
|
|
if (b & 0x80) == 0 {
|
|
x |= uint64(b) << s
|
|
return x
|
|
}
|
|
|
|
x |= uint64(b&0x7f) << s
|
|
s += 7
|
|
n--
|
|
}
|
|
|
|
d.setError(fmt.Errorf("cannot decode unsigned varint from input stream"))
|
|
return 0
|
|
}
|
|
|
|
// decodeFunc decodes one value of a particular type from d into v.
type decodeFunc func(d *decoder, v value)
|
|
|
|
var (
|
|
_ io.Reader = (*decoder)(nil)
|
|
_ io.ByteReader = (*decoder)(nil)
|
|
|
|
readerFrom = reflect.TypeOf((*io.ReaderFrom)(nil)).Elem()
|
|
)
|
|
|
|
func decodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
|
|
if reflect.PtrTo(typ).Implements(readerFrom) {
|
|
return readerDecodeFuncOf(typ)
|
|
}
|
|
switch typ.Kind() {
|
|
case reflect.Bool:
|
|
return (*decoder).decodeBool
|
|
case reflect.Int8:
|
|
return (*decoder).decodeInt8
|
|
case reflect.Int16:
|
|
return (*decoder).decodeInt16
|
|
case reflect.Int32:
|
|
return (*decoder).decodeInt32
|
|
case reflect.Int64:
|
|
return (*decoder).decodeInt64
|
|
case reflect.String:
|
|
return stringDecodeFuncOf(flexible, tag)
|
|
case reflect.Struct:
|
|
return structDecodeFuncOf(typ, version, flexible)
|
|
case reflect.Slice:
|
|
if typ.Elem().Kind() == reflect.Uint8 { // []byte
|
|
return bytesDecodeFuncOf(flexible, tag)
|
|
}
|
|
return arrayDecodeFuncOf(typ, version, flexible, tag)
|
|
default:
|
|
panic("unsupported type: " + typ.String())
|
|
}
|
|
}
|
|
|
|
func stringDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
|
|
if flexible {
|
|
// In flexible messages, all strings are compact
|
|
return (*decoder).decodeCompactString
|
|
}
|
|
return (*decoder).decodeString
|
|
}
|
|
|
|
func bytesDecodeFuncOf(flexible bool, tag structTag) decodeFunc {
|
|
if flexible {
|
|
// In flexible messages, all arrays are compact
|
|
return (*decoder).decodeCompactBytes
|
|
}
|
|
return (*decoder).decodeBytes
|
|
}
|
|
|
|
func structDecodeFuncOf(typ reflect.Type, version int16, flexible bool) decodeFunc {
|
|
type field struct {
|
|
decode decodeFunc
|
|
index index
|
|
tagID int
|
|
}
|
|
|
|
var fields []field
|
|
taggedFields := map[int]*field{}
|
|
|
|
forEachStructField(typ, func(typ reflect.Type, index index, tag string) {
|
|
forEachStructTag(tag, func(tag structTag) bool {
|
|
if tag.MinVersion <= version && version <= tag.MaxVersion {
|
|
f := field{
|
|
decode: decodeFuncOf(typ, version, flexible, tag),
|
|
index: index,
|
|
tagID: tag.TagID,
|
|
}
|
|
|
|
if tag.TagID < -1 {
|
|
// Normal required field
|
|
fields = append(fields, f)
|
|
} else {
|
|
// Optional tagged field (flexible messages only)
|
|
taggedFields[tag.TagID] = &f
|
|
}
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
})
|
|
|
|
return func(d *decoder, v value) {
|
|
for i := range fields {
|
|
f := &fields[i]
|
|
f.decode(d, v.fieldByIndex(f.index))
|
|
}
|
|
|
|
if flexible {
|
|
// See https://cwiki.apache.org/confluence/display/KAFKA/KIP-482%3A+The+Kafka+Protocol+should+Support+Optional+Tagged+Fields
|
|
// for details of tag buffers in "flexible" messages.
|
|
n := int(d.readUnsignedVarInt())
|
|
|
|
for i := 0; i < n; i++ {
|
|
tagID := int(d.readUnsignedVarInt())
|
|
size := int(d.readUnsignedVarInt())
|
|
|
|
f, ok := taggedFields[tagID]
|
|
if ok {
|
|
f.decode(d, v.fieldByIndex(f.index))
|
|
} else {
|
|
d.read(size)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func arrayDecodeFuncOf(typ reflect.Type, version int16, flexible bool, tag structTag) decodeFunc {
|
|
elemType := typ.Elem()
|
|
elemFunc := decodeFuncOf(elemType, version, flexible, tag)
|
|
if flexible {
|
|
// In flexible messages, all arrays are compact
|
|
return func(d *decoder, v value) { d.decodeCompactArray(v, elemType, elemFunc) }
|
|
}
|
|
|
|
return func(d *decoder, v value) { d.decodeArray(v, elemType, elemFunc) }
|
|
}
|
|
|
|
func readerDecodeFuncOf(typ reflect.Type) decodeFunc {
|
|
typ = reflect.PtrTo(typ)
|
|
return func(d *decoder, v value) {
|
|
if d.err == nil {
|
|
_, err := v.iface(typ).(io.ReaderFrom).ReadFrom(d)
|
|
if err != nil {
|
|
d.setError(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// readInt8 interprets b[0] as a signed byte. It panics if b is empty.
func readInt8(b []byte) int8 {
	u := b[0]
	return int8(u)
}
|
|
|
|
// readInt16 decodes the first two bytes of b as a big-endian int16. It panics
// if b is shorter than 2 bytes.
func readInt16(b []byte) int16 {
	u := binary.BigEndian.Uint16(b)
	return int16(u)
}
|
|
|
|
// readInt32 decodes the first four bytes of b as a big-endian int32. It
// panics if b is shorter than 4 bytes.
func readInt32(b []byte) int32 {
	u := binary.BigEndian.Uint32(b)
	return int32(u)
}
|
|
|
|
// readInt64 decodes the first eight bytes of b as a big-endian int64. It
// panics if b is shorter than 8 bytes.
func readInt64(b []byte) int64 {
	u := binary.BigEndian.Uint64(b)
	return int64(u)
}
|
|
|
|
func Unmarshal(data []byte, version int16, value interface{}) error {
|
|
typ := elemTypeOf(value)
|
|
cache, _ := unmarshalers.Load().(map[versionedType]decodeFunc)
|
|
key := versionedType{typ: typ, version: version}
|
|
decode := cache[key]
|
|
|
|
if decode == nil {
|
|
decode = decodeFuncOf(reflect.TypeOf(value).Elem(), version, false, structTag{
|
|
MinVersion: -1,
|
|
MaxVersion: -1,
|
|
TagID: -2,
|
|
Compact: true,
|
|
Nullable: true,
|
|
})
|
|
|
|
newCache := make(map[versionedType]decodeFunc, len(cache)+1)
|
|
newCache[key] = decode
|
|
|
|
for typ, fun := range cache {
|
|
newCache[typ] = fun
|
|
}
|
|
|
|
unmarshalers.Store(newCache)
|
|
}
|
|
|
|
d, _ := decoders.Get().(*decoder)
|
|
if d == nil {
|
|
d = &decoder{reader: bytes.NewReader(nil)}
|
|
}
|
|
|
|
d.remain = len(data)
|
|
r, _ := d.reader.(*bytes.Reader)
|
|
r.Reset(data)
|
|
|
|
defer func() {
|
|
r.Reset(nil)
|
|
d.Reset(r, 0)
|
|
decoders.Put(d)
|
|
}()
|
|
|
|
decode(d, valueOf(value))
|
|
return dontExpectEOF(d.err)
|
|
}
|
|
|
|
// decoders pools *decoder values for reuse across Unmarshal calls.
var decoders sync.Pool

// unmarshalers holds the map[versionedType]decodeFunc cache used by
// Unmarshal. The map is replaced wholesale on update and never mutated in
// place.
var unmarshalers atomic.Value
|
|
|