package msgpack import ( "fmt" "math" "reflect" "github.com/vmihailenco/msgpack/v5/msgpcode" ) const ( minInternedStringLen = 3 maxDictLen = math.MaxUint16 ) var internedStringExtID = int8(math.MinInt8) func init() { extTypes[internedStringExtID] = &extInfo{ Type: stringType, Decoder: decodeInternedStringExt, } } func decodeInternedStringExt(d *Decoder, v reflect.Value, extLen int) error { idx, err := d.decodeInternedStringIndex(extLen) if err != nil { return err } s, err := d.internedStringAtIndex(idx) if err != nil { return err } v.SetString(s) return nil } //------------------------------------------------------------------------------ func encodeInternedInterfaceValue(e *Encoder, v reflect.Value) error { if v.IsNil() { return e.EncodeNil() } v = v.Elem() if v.Kind() == reflect.String { return e.encodeInternedString(v.String(), true) } return e.EncodeValue(v) } func encodeInternedStringValue(e *Encoder, v reflect.Value) error { return e.encodeInternedString(v.String(), true) } func (e *Encoder) encodeInternedString(s string, intern bool) error { // Interned string takes at least 3 bytes. Plain string 1 byte + string len. if idx, ok := e.dict[s]; ok { return e.encodeInternedStringIndex(idx) } if intern && len(s) >= minInternedStringLen && len(e.dict) < maxDictLen { if e.dict == nil { e.dict = make(map[string]int) } idx := len(e.dict) e.dict[s] = idx } return e.encodeNormalString(s) } func (e *Encoder) encodeInternedStringIndex(idx int) error { if idx <= math.MaxUint8 { if err := e.writeCode(msgpcode.FixExt1); err != nil { return err } return e.write1(byte(internedStringExtID), uint8(idx)) } if idx <= math.MaxUint16 { if err := e.writeCode(msgpcode.FixExt2); err != nil { return err } return e.write2(byte(internedStringExtID), uint16(idx)) } if uint64(idx) <= math.MaxUint32 { if err := e.writeCode(msgpcode.FixExt4); err != nil { return err } return e.write4(byte(internedStringExtID), uint32(idx)) } return fmt.Errorf("msgpack: interned string index=%d is too large", idx) } //------------------------------------------------------------------------------ func decodeInternedInterfaceValue(d *Decoder, v reflect.Value) error { s, err := d.decodeInternedString(true) if err == nil { v.Set(reflect.ValueOf(s)) return nil } if err != nil { if _, ok := err.(unexpectedCodeError); !ok { return err } } if err := d.s.UnreadByte(); err != nil { return err } return decodeInterfaceValue(d, v) } func decodeInternedStringValue(d *Decoder, v reflect.Value) error { s, err := d.decodeInternedString(true) if err != nil { return err } v.SetString(s) return nil } func (d *Decoder) decodeInternedString(intern bool) (string, error) { c, err := d.readCode() if err != nil { return "", err } if msgpcode.IsFixedString(c) { n := int(c & msgpcode.FixedStrMask) return d.decodeInternedStringWithLen(n, intern) } switch c { case msgpcode.Nil: return "", nil case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4: typeID, extLen, err := d.extHeader(c) if err != nil { return "", err } if typeID != internedStringExtID { err := fmt.Errorf("msgpack: got ext type=%d, wanted %d", typeID, internedStringExtID) return "", err } idx, err := d.decodeInternedStringIndex(extLen) if err != nil { return "", err } return d.internedStringAtIndex(idx) case msgpcode.Str8, msgpcode.Bin8: n, err := d.uint8() if err != nil { return "", err } return d.decodeInternedStringWithLen(int(n), intern) case msgpcode.Str16, msgpcode.Bin16: n, err := d.uint16() if err != nil { return "", err } return d.decodeInternedStringWithLen(int(n), intern) case msgpcode.Str32, msgpcode.Bin32: n, err := d.uint32() if err != nil { return "", err } return d.decodeInternedStringWithLen(int(n), intern) } return "", unexpectedCodeError{ code: c, hint: "interned string", } } func (d *Decoder) decodeInternedStringIndex(extLen int) (int, error) { switch extLen { case 1: n, err := d.uint8() if err != nil { return 0, err } return int(n), nil case 2: n, err := d.uint16() if err != nil { return 0, err } return int(n), nil case 4: n, err := d.uint32() if err != nil { return 0, err } return int(n), nil } err := fmt.Errorf("msgpack: unsupported ext len=%d decoding interned string", extLen) return 0, err } func (d *Decoder) internedStringAtIndex(idx int) (string, error) { if idx >= len(d.dict) { err := fmt.Errorf("msgpack: interned string at index=%d does not exist", idx) return "", err } return d.dict[idx], nil } func (d *Decoder) decodeInternedStringWithLen(n int, intern bool) (string, error) { if n <= 0 { return "", nil } s, err := d.stringWithLen(n) if err != nil { return "", err } if intern && len(s) >= minInternedStringLen && len(d.dict) < maxDictLen { d.dict = append(d.dict, s) } return s, nil }