556 lines
13 KiB
Go
556 lines
13 KiB
Go
/*
|
|
* Licensed to the Apache Software Foundation (ASF) under one
|
|
* or more contributor license agreements. See the NOTICE file
|
|
* distributed with this work for additional information
|
|
* regarding copyright ownership. The ASF licenses this file
|
|
* to you under the Apache License, Version 2.0 (the
|
|
* "License"); you may not use this file except in compliance
|
|
* with the License. You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing,
|
|
* software distributed under the License is distributed on an
|
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
* KIND, either express or implied. See the License for the
|
|
* specific language governing permissions and limitations
|
|
* under the License.
|
|
*/
|
|
|
|
package thrift
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math"
|
|
)
|
|
|
|
type TBinaryProtocol struct {
|
|
trans TRichTransport
|
|
origTransport TTransport
|
|
cfg *TConfiguration
|
|
buffer [64]byte
|
|
}
|
|
|
|
type TBinaryProtocolFactory struct {
|
|
cfg *TConfiguration
|
|
}
|
|
|
|
// Deprecated: Use NewTBinaryProtocolConf instead.
|
|
func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol {
|
|
return NewTBinaryProtocolConf(t, &TConfiguration{
|
|
noPropagation: true,
|
|
})
|
|
}
|
|
|
|
// Deprecated: Use NewTBinaryProtocolConf instead.
|
|
func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol {
|
|
return NewTBinaryProtocolConf(t, &TConfiguration{
|
|
TBinaryStrictRead: &strictRead,
|
|
TBinaryStrictWrite: &strictWrite,
|
|
|
|
noPropagation: true,
|
|
})
|
|
}
|
|
|
|
func NewTBinaryProtocolConf(t TTransport, conf *TConfiguration) *TBinaryProtocol {
|
|
PropagateTConfiguration(t, conf)
|
|
p := &TBinaryProtocol{
|
|
origTransport: t,
|
|
cfg: conf,
|
|
}
|
|
if et, ok := t.(TRichTransport); ok {
|
|
p.trans = et
|
|
} else {
|
|
p.trans = NewTRichTransport(t)
|
|
}
|
|
return p
|
|
}
|
|
|
|
// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
|
|
func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory {
|
|
return NewTBinaryProtocolFactoryConf(&TConfiguration{
|
|
noPropagation: true,
|
|
})
|
|
}
|
|
|
|
// Deprecated: Use NewTBinaryProtocolFactoryConf instead.
|
|
func NewTBinaryProtocolFactory(strictRead, strictWrite bool) *TBinaryProtocolFactory {
|
|
return NewTBinaryProtocolFactoryConf(&TConfiguration{
|
|
TBinaryStrictRead: &strictRead,
|
|
TBinaryStrictWrite: &strictWrite,
|
|
|
|
noPropagation: true,
|
|
})
|
|
}
|
|
|
|
func NewTBinaryProtocolFactoryConf(conf *TConfiguration) *TBinaryProtocolFactory {
|
|
return &TBinaryProtocolFactory{
|
|
cfg: conf,
|
|
}
|
|
}
|
|
|
|
func (p *TBinaryProtocolFactory) GetProtocol(t TTransport) TProtocol {
|
|
return NewTBinaryProtocolConf(t, p.cfg)
|
|
}
|
|
|
|
func (p *TBinaryProtocolFactory) SetTConfiguration(conf *TConfiguration) {
|
|
p.cfg = conf
|
|
}
|
|
|
|
/**
|
|
* Writing Methods
|
|
*/
|
|
|
|
func (p *TBinaryProtocol) WriteMessageBegin(ctx context.Context, name string, typeId TMessageType, seqId int32) error {
|
|
if p.cfg.GetTBinaryStrictWrite() {
|
|
version := uint32(VERSION_1) | uint32(typeId)
|
|
e := p.WriteI32(ctx, int32(version))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteString(ctx, name)
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteI32(ctx, seqId)
|
|
return e
|
|
} else {
|
|
e := p.WriteString(ctx, name)
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteByte(ctx, int8(typeId))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteI32(ctx, seqId)
|
|
return e
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteMessageEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteStructBegin(ctx context.Context, name string) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteStructEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteFieldBegin(ctx context.Context, name string, typeId TType, id int16) error {
|
|
e := p.WriteByte(ctx, int8(typeId))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteI16(ctx, id)
|
|
return e
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteFieldEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteFieldStop(ctx context.Context) error {
|
|
e := p.WriteByte(ctx, STOP)
|
|
return e
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
|
|
e := p.WriteByte(ctx, int8(keyType))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteByte(ctx, int8(valueType))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteI32(ctx, int32(size))
|
|
return e
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteMapEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
|
|
e := p.WriteByte(ctx, int8(elemType))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteI32(ctx, int32(size))
|
|
return e
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteListEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
|
|
e := p.WriteByte(ctx, int8(elemType))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
e = p.WriteI32(ctx, int32(size))
|
|
return e
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteSetEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteBool(ctx context.Context, value bool) error {
|
|
if value {
|
|
return p.WriteByte(ctx, 1)
|
|
}
|
|
return p.WriteByte(ctx, 0)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteByte(ctx context.Context, value int8) error {
|
|
e := p.trans.WriteByte(byte(value))
|
|
return NewTProtocolException(e)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteI16(ctx context.Context, value int16) error {
|
|
v := p.buffer[0:2]
|
|
binary.BigEndian.PutUint16(v, uint16(value))
|
|
_, e := p.trans.Write(v)
|
|
return NewTProtocolException(e)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteI32(ctx context.Context, value int32) error {
|
|
v := p.buffer[0:4]
|
|
binary.BigEndian.PutUint32(v, uint32(value))
|
|
_, e := p.trans.Write(v)
|
|
return NewTProtocolException(e)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteI64(ctx context.Context, value int64) error {
|
|
v := p.buffer[0:8]
|
|
binary.BigEndian.PutUint64(v, uint64(value))
|
|
_, err := p.trans.Write(v)
|
|
return NewTProtocolException(err)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteDouble(ctx context.Context, value float64) error {
|
|
return p.WriteI64(ctx, int64(math.Float64bits(value)))
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteString(ctx context.Context, value string) error {
|
|
e := p.WriteI32(ctx, int32(len(value)))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
_, err := p.trans.WriteString(value)
|
|
return NewTProtocolException(err)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) WriteBinary(ctx context.Context, value []byte) error {
|
|
e := p.WriteI32(ctx, int32(len(value)))
|
|
if e != nil {
|
|
return e
|
|
}
|
|
_, err := p.trans.Write(value)
|
|
return NewTProtocolException(err)
|
|
}
|
|
|
|
/**
|
|
* Reading methods
|
|
*/
|
|
|
|
func (p *TBinaryProtocol) ReadMessageBegin(ctx context.Context) (name string, typeId TMessageType, seqId int32, err error) {
|
|
size, e := p.ReadI32(ctx)
|
|
if e != nil {
|
|
return "", typeId, 0, NewTProtocolException(e)
|
|
}
|
|
if size < 0 {
|
|
typeId = TMessageType(size & 0x0ff)
|
|
version := int64(int64(size) & VERSION_MASK)
|
|
if version != VERSION_1 {
|
|
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Bad version in ReadMessageBegin"))
|
|
}
|
|
name, e = p.ReadString(ctx)
|
|
if e != nil {
|
|
return name, typeId, seqId, NewTProtocolException(e)
|
|
}
|
|
seqId, e = p.ReadI32(ctx)
|
|
if e != nil {
|
|
return name, typeId, seqId, NewTProtocolException(e)
|
|
}
|
|
return name, typeId, seqId, nil
|
|
}
|
|
if p.cfg.GetTBinaryStrictRead() {
|
|
return name, typeId, seqId, NewTProtocolExceptionWithType(BAD_VERSION, fmt.Errorf("Missing version in ReadMessageBegin"))
|
|
}
|
|
name, e2 := p.readStringBody(size)
|
|
if e2 != nil {
|
|
return name, typeId, seqId, e2
|
|
}
|
|
b, e3 := p.ReadByte(ctx)
|
|
if e3 != nil {
|
|
return name, typeId, seqId, e3
|
|
}
|
|
typeId = TMessageType(b)
|
|
seqId, e4 := p.ReadI32(ctx)
|
|
if e4 != nil {
|
|
return name, typeId, seqId, e4
|
|
}
|
|
return name, typeId, seqId, nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadMessageEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
|
|
return
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadStructEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadFieldBegin(ctx context.Context) (name string, typeId TType, seqId int16, err error) {
|
|
t, err := p.ReadByte(ctx)
|
|
typeId = TType(t)
|
|
if err != nil {
|
|
return name, typeId, seqId, err
|
|
}
|
|
if t != STOP {
|
|
seqId, err = p.ReadI16(ctx)
|
|
}
|
|
return name, typeId, seqId, err
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadFieldEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
// invalidDataLength is the shared INVALID_DATA protocol exception returned
// when a container element count or string length read off the wire is
// negative.
var invalidDataLength = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Invalid data length"))
|
|
|
|
func (p *TBinaryProtocol) ReadMapBegin(ctx context.Context) (kType, vType TType, size int, err error) {
|
|
k, e := p.ReadByte(ctx)
|
|
if e != nil {
|
|
err = NewTProtocolException(e)
|
|
return
|
|
}
|
|
kType = TType(k)
|
|
v, e := p.ReadByte(ctx)
|
|
if e != nil {
|
|
err = NewTProtocolException(e)
|
|
return
|
|
}
|
|
vType = TType(v)
|
|
size32, e := p.ReadI32(ctx)
|
|
if e != nil {
|
|
err = NewTProtocolException(e)
|
|
return
|
|
}
|
|
if size32 < 0 {
|
|
err = invalidDataLength
|
|
return
|
|
}
|
|
size = int(size32)
|
|
return kType, vType, size, nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadMapEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
|
|
b, e := p.ReadByte(ctx)
|
|
if e != nil {
|
|
err = NewTProtocolException(e)
|
|
return
|
|
}
|
|
elemType = TType(b)
|
|
size32, e := p.ReadI32(ctx)
|
|
if e != nil {
|
|
err = NewTProtocolException(e)
|
|
return
|
|
}
|
|
if size32 < 0 {
|
|
err = invalidDataLength
|
|
return
|
|
}
|
|
size = int(size32)
|
|
|
|
return
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadListEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
|
|
b, e := p.ReadByte(ctx)
|
|
if e != nil {
|
|
err = NewTProtocolException(e)
|
|
return
|
|
}
|
|
elemType = TType(b)
|
|
size32, e := p.ReadI32(ctx)
|
|
if e != nil {
|
|
err = NewTProtocolException(e)
|
|
return
|
|
}
|
|
if size32 < 0 {
|
|
err = invalidDataLength
|
|
return
|
|
}
|
|
size = int(size32)
|
|
return elemType, size, nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadSetEnd(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadBool(ctx context.Context) (bool, error) {
|
|
b, e := p.ReadByte(ctx)
|
|
v := true
|
|
if b != 1 {
|
|
v = false
|
|
}
|
|
return v, e
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadByte(ctx context.Context) (int8, error) {
|
|
v, err := p.trans.ReadByte()
|
|
return int8(v), err
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadI16(ctx context.Context) (value int16, err error) {
|
|
buf := p.buffer[0:2]
|
|
err = p.readAll(ctx, buf)
|
|
value = int16(binary.BigEndian.Uint16(buf))
|
|
return value, err
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadI32(ctx context.Context) (value int32, err error) {
|
|
buf := p.buffer[0:4]
|
|
err = p.readAll(ctx, buf)
|
|
value = int32(binary.BigEndian.Uint32(buf))
|
|
return value, err
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadI64(ctx context.Context) (value int64, err error) {
|
|
buf := p.buffer[0:8]
|
|
err = p.readAll(ctx, buf)
|
|
value = int64(binary.BigEndian.Uint64(buf))
|
|
return value, err
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
|
|
buf := p.buffer[0:8]
|
|
err = p.readAll(ctx, buf)
|
|
value = math.Float64frombits(binary.BigEndian.Uint64(buf))
|
|
return value, err
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err error) {
|
|
size, e := p.ReadI32(ctx)
|
|
if e != nil {
|
|
return "", e
|
|
}
|
|
err = checkSizeForProtocol(size, p.cfg)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if size < 0 {
|
|
err = invalidDataLength
|
|
return
|
|
}
|
|
if size == 0 {
|
|
return "", nil
|
|
}
|
|
if size < int32(len(p.buffer)) {
|
|
// Avoid allocation on small reads
|
|
buf := p.buffer[:size]
|
|
read, e := io.ReadFull(p.trans, buf)
|
|
return string(buf[:read]), NewTProtocolException(e)
|
|
}
|
|
|
|
return p.readStringBody(size)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
|
|
size, e := p.ReadI32(ctx)
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
if err := checkSizeForProtocol(size, p.cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
buf, err := safeReadBytes(size, p.trans)
|
|
return buf, NewTProtocolException(err)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) Flush(ctx context.Context) (err error) {
|
|
return NewTProtocolException(p.trans.Flush(ctx))
|
|
}
|
|
|
|
func (p *TBinaryProtocol) Skip(ctx context.Context, fieldType TType) (err error) {
|
|
return SkipDefaultDepth(ctx, p, fieldType)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) Transport() TTransport {
|
|
return p.origTransport
|
|
}
|
|
|
|
func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) {
|
|
var read int
|
|
_, deadlineSet := ctx.Deadline()
|
|
for {
|
|
read, err = io.ReadFull(p.trans, buf)
|
|
if deadlineSet && read == 0 && isTimeoutError(err) && ctx.Err() == nil {
|
|
// This is I/O timeout without anything read,
|
|
// and we still have time left, keep retrying.
|
|
continue
|
|
}
|
|
// For anything else, don't retry
|
|
break
|
|
}
|
|
return NewTProtocolException(err)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
|
|
buf, err := safeReadBytes(size, p.trans)
|
|
return string(buf), NewTProtocolException(err)
|
|
}
|
|
|
|
func (p *TBinaryProtocol) SetTConfiguration(conf *TConfiguration) {
|
|
PropagateTConfiguration(p.trans, conf)
|
|
PropagateTConfiguration(p.origTransport, conf)
|
|
p.cfg = conf
|
|
}
|
|
|
|
var (
|
|
_ TConfigurationSetter = (*TBinaryProtocolFactory)(nil)
|
|
_ TConfigurationSetter = (*TBinaryProtocol)(nil)
|
|
)
|
|
|
|
// safeReadBytes reads up to size bytes from trans. It is shared between
// TBinaryProtocol and TCompactProtocol.
//
// The destination is a bytes.Buffer filled through io.CopyN, so memory grows
// only as data actually arrives. An absurd size from a malformed message
// therefore cannot trigger a huge up-front allocation.
//
// A negative size returns (nil, nil) without reading. If trans ends early,
// the bytes received so far are returned together with the error from
// io.CopyN (io.EOF when the source simply ran dry).
func safeReadBytes(size int32, trans io.Reader) ([]byte, error) {
	if size < 0 {
		return nil, nil
	}

	var out bytes.Buffer
	_, copyErr := io.CopyN(&out, trans, int64(size))
	return out.Bytes(), copyErr
}
|