astc: Make IntegerEncodedValue a trivial structure

This commit is contained in:
ReinUsesLisp 2020-03-13 22:49:28 -03:00
parent 70a31eda62
commit e183820956
1 changed files with 181 additions and 216 deletions

View File

@ -160,232 +160,198 @@ private:
enum class IntegerEncoding { JustBits, Qus32, Trit }; enum class IntegerEncoding { JustBits, Qus32, Trit };
class IntegerEncodedValue { struct IntegerEncodedValue {
private: constexpr IntegerEncodedValue(IntegerEncoding encoding_, u32 num_bits_)
IntegerEncoding m_Encoding{}; : encoding{encoding_}, num_bits{num_bits_} {}
u32 m_NumBits = 0;
u32 m_BitValue = 0;
union {
u32 m_Qus32Value = 0;
u32 m_TritValue;
};
public: constexpr bool MatchesEncoding(const IntegerEncodedValue& other) const {
constexpr IntegerEncodedValue() = default; return encoding == other.encoding && num_bits == other.num_bits;
constexpr IntegerEncodedValue(IntegerEncoding encoding, u32 numBits)
: m_Encoding(encoding), m_NumBits(numBits) {}
IntegerEncoding GetEncoding() const {
return m_Encoding;
}
u32 BaseBitLength() const {
return m_NumBits;
}
u32 GetBitValue() const {
return m_BitValue;
}
void SetBitValue(u32 val) {
m_BitValue = val;
}
u32 GetTritValue() const {
return m_TritValue;
}
void SetTritValue(u32 val) {
m_TritValue = val;
}
u32 GetQus32Value() const {
return m_Qus32Value;
}
void SetQus32Value(u32 val) {
m_Qus32Value = val;
}
bool MatchesEncoding(const IntegerEncodedValue& other) const {
return m_Encoding == other.m_Encoding && m_NumBits == other.m_NumBits;
} }
// Returns the number of bits required to encode nVals values. // Returns the number of bits required to encode nVals values.
u32 GetBitLength(u32 nVals) const { u32 GetBitLength(u32 nVals) const {
u32 totalBits = m_NumBits * nVals; u32 totalBits = num_bits * nVals;
if (m_Encoding == IntegerEncoding::Trit) { if (encoding == IntegerEncoding::Trit) {
totalBits += (nVals * 8 + 4) / 5; totalBits += (nVals * 8 + 4) / 5;
} else if (m_Encoding == IntegerEncoding::Qus32) { } else if (encoding == IntegerEncoding::Qus32) {
totalBits += (nVals * 7 + 2) / 3; totalBits += (nVals * 7 + 2) / 3;
} }
return totalBits; return totalBits;
} }
// Returns a new instance of this struct that corresponds to the IntegerEncoding encoding;
// can take no more than maxval values u32 num_bits;
static IntegerEncodedValue CreateEncoding(u32 maxVal) { u32 bit_value = 0;
while (maxVal > 0) { union {
u32 check = maxVal + 1; u32 qus32_value = 0;
u32 trit_value;
};
};
// Is maxVal a power of two? static void DecodeTritBlock(InputBitStream& bits, std::vector<IntegerEncodedValue>& result,
if (!(check & (check - 1))) { u32 nBitsPerValue) {
return IntegerEncodedValue(IntegerEncoding::JustBits, Popcnt(maxVal)); // Implement the algorithm in section C.2.12
} u32 m[5];
u32 t[5];
u32 T;
// Is maxVal of the type 3*2^n - 1? // Read the trit encoded block according to
if ((check % 3 == 0) && !((check / 3) & ((check / 3) - 1))) { // table C.2.14
return IntegerEncodedValue(IntegerEncoding::Trit, Popcnt(check / 3 - 1)); m[0] = bits.ReadBits(nBitsPerValue);
} T = bits.ReadBits(2);
m[1] = bits.ReadBits(nBitsPerValue);
T |= bits.ReadBits(2) << 2;
m[2] = bits.ReadBits(nBitsPerValue);
T |= bits.ReadBit() << 4;
m[3] = bits.ReadBits(nBitsPerValue);
T |= bits.ReadBits(2) << 5;
m[4] = bits.ReadBits(nBitsPerValue);
T |= bits.ReadBit() << 7;
// Is maxVal of the type 5*2^n - 1? u32 C = 0;
if ((check % 5 == 0) && !((check / 5) & ((check / 5) - 1))) {
return IntegerEncodedValue(IntegerEncoding::Qus32, Popcnt(check / 5 - 1));
}
// Apparently it can't be represented with a bounded s32eger sequence... Bits<u32> Tb(T);
// just iterate. if (Tb(2, 4) == 7) {
maxVal--; C = (Tb(5, 7) << 2) | Tb(0, 1);
} t[4] = t[3] = 2;
return IntegerEncodedValue(IntegerEncoding::JustBits, 0); } else {
} C = Tb(0, 4);
if (Tb(5, 6) == 3) {
// Fills result with the values that are encoded in the given t[4] = 2;
// bitstream. We must know beforehand what the maximum possible t[3] = Tb[7];
// value is, and how many values we're decoding.
static void DecodeIntegerSequence(std::vector<IntegerEncodedValue>& result,
InputBitStream& bits, u32 maxRange, u32 nValues) {
// Determine encoding parameters
IntegerEncodedValue val = IntegerEncodedValue::CreateEncoding(maxRange);
// Start decoding
u32 nValsDecoded = 0;
while (nValsDecoded < nValues) {
switch (val.GetEncoding()) {
case IntegerEncoding::Qus32:
DecodeQus32Block(bits, result, val.BaseBitLength());
nValsDecoded += 3;
break;
case IntegerEncoding::Trit:
DecodeTritBlock(bits, result, val.BaseBitLength());
nValsDecoded += 5;
break;
case IntegerEncoding::JustBits:
val.SetBitValue(bits.ReadBits(val.BaseBitLength()));
result.push_back(val);
nValsDecoded++;
break;
}
}
}
private:
static void DecodeTritBlock(InputBitStream& bits, std::vector<IntegerEncodedValue>& result,
u32 nBitsPerValue) {
// Implement the algorithm in section C.2.12
u32 m[5];
u32 t[5];
u32 T;
// Read the trit encoded block according to
// table C.2.14
m[0] = bits.ReadBits(nBitsPerValue);
T = bits.ReadBits(2);
m[1] = bits.ReadBits(nBitsPerValue);
T |= bits.ReadBits(2) << 2;
m[2] = bits.ReadBits(nBitsPerValue);
T |= bits.ReadBit() << 4;
m[3] = bits.ReadBits(nBitsPerValue);
T |= bits.ReadBits(2) << 5;
m[4] = bits.ReadBits(nBitsPerValue);
T |= bits.ReadBit() << 7;
u32 C = 0;
Bits<u32> Tb(T);
if (Tb(2, 4) == 7) {
C = (Tb(5, 7) << 2) | Tb(0, 1);
t[4] = t[3] = 2;
} else { } else {
C = Tb(0, 4); t[4] = Tb[7];
if (Tb(5, 6) == 3) { t[3] = Tb(5, 6);
t[4] = 2; }
t[3] = Tb[7]; }
} else {
t[4] = Tb[7]; Bits<u32> Cb(C);
t[3] = Tb(5, 6); if (Cb(0, 1) == 3) {
} t[2] = 2;
t[1] = Cb[4];
t[0] = (Cb[3] << 1) | (Cb[2] & ~Cb[3]);
} else if (Cb(2, 3) == 3) {
t[2] = 2;
t[1] = 2;
t[0] = Cb(0, 1);
} else {
t[2] = Cb[4];
t[1] = Cb(2, 3);
t[0] = (Cb[1] << 1) | (Cb[0] & ~Cb[1]);
}
for (std::size_t i = 0; i < 5; ++i) {
IntegerEncodedValue& val = result.emplace_back(IntegerEncoding::Trit, nBitsPerValue);
val.bit_value = m[i];
val.trit_value = t[i];
}
}
static void DecodeQus32Block(InputBitStream& bits, std::vector<IntegerEncodedValue>& result,
u32 nBitsPerValue) {
// Implement the algorithm in section C.2.12
u32 m[3];
u32 q[3];
u32 Q;
// Read the trit encoded block according to
// table C.2.15
m[0] = bits.ReadBits(nBitsPerValue);
Q = bits.ReadBits(3);
m[1] = bits.ReadBits(nBitsPerValue);
Q |= bits.ReadBits(2) << 3;
m[2] = bits.ReadBits(nBitsPerValue);
Q |= bits.ReadBits(2) << 5;
Bits<u32> Qb(Q);
if (Qb(1, 2) == 3 && Qb(5, 6) == 0) {
q[0] = q[1] = 4;
q[2] = (Qb[0] << 2) | ((Qb[4] & ~Qb[0]) << 1) | (Qb[3] & ~Qb[0]);
} else {
u32 C = 0;
if (Qb(1, 2) == 3) {
q[2] = 4;
C = (Qb(3, 4) << 3) | ((~Qb(5, 6) & 3) << 1) | Qb[0];
} else {
q[2] = Qb(5, 6);
C = Qb(0, 4);
} }
Bits<u32> Cb(C); Bits<u32> Cb(C);
if (Cb(0, 1) == 3) { if (Cb(0, 2) == 5) {
t[2] = 2; q[1] = 4;
t[1] = Cb[4]; q[0] = Cb(3, 4);
t[0] = (Cb[3] << 1) | (Cb[2] & ~Cb[3]);
} else if (Cb(2, 3) == 3) {
t[2] = 2;
t[1] = 2;
t[0] = Cb(0, 1);
} else { } else {
t[2] = Cb[4]; q[1] = Cb(3, 4);
t[1] = Cb(2, 3); q[0] = Cb(0, 2);
t[0] = (Cb[1] << 1) | (Cb[0] & ~Cb[1]);
}
for (u32 i = 0; i < 5; i++) {
IntegerEncodedValue val(IntegerEncoding::Trit, nBitsPerValue);
val.SetBitValue(m[i]);
val.SetTritValue(t[i]);
result.push_back(val);
} }
} }
static void DecodeQus32Block(InputBitStream& bits, std::vector<IntegerEncodedValue>& result, for (std::size_t i = 0; i < 3; ++i) {
u32 nBitsPerValue) { IntegerEncodedValue& val = result.emplace_back(IntegerEncoding::Qus32, nBitsPerValue);
// Implement the algorithm in section C.2.12 val.bit_value = m[i];
u32 m[3]; val.qus32_value = q[i];
u32 q[3]; }
u32 Q; }
// Read the trit encoded block according to // Returns a new instance of this struct that corresponds to the
// table C.2.15 // can take no more than maxval values
m[0] = bits.ReadBits(nBitsPerValue); static IntegerEncodedValue CreateEncoding(u32 maxVal) {
Q = bits.ReadBits(3); while (maxVal > 0) {
m[1] = bits.ReadBits(nBitsPerValue); u32 check = maxVal + 1;
Q |= bits.ReadBits(2) << 3;
m[2] = bits.ReadBits(nBitsPerValue);
Q |= bits.ReadBits(2) << 5;
Bits<u32> Qb(Q); // Is maxVal a power of two?
if (Qb(1, 2) == 3 && Qb(5, 6) == 0) { if (!(check & (check - 1))) {
q[0] = q[1] = 4; return IntegerEncodedValue(IntegerEncoding::JustBits, Popcnt(maxVal));
q[2] = (Qb[0] << 2) | ((Qb[4] & ~Qb[0]) << 1) | (Qb[3] & ~Qb[0]);
} else {
u32 C = 0;
if (Qb(1, 2) == 3) {
q[2] = 4;
C = (Qb(3, 4) << 3) | ((~Qb(5, 6) & 3) << 1) | Qb[0];
} else {
q[2] = Qb(5, 6);
C = Qb(0, 4);
}
Bits<u32> Cb(C);
if (Cb(0, 2) == 5) {
q[1] = 4;
q[0] = Cb(3, 4);
} else {
q[1] = Cb(3, 4);
q[0] = Cb(0, 2);
}
} }
for (u32 i = 0; i < 3; i++) { // Is maxVal of the type 3*2^n - 1?
IntegerEncodedValue val(IntegerEncoding::Qus32, nBitsPerValue); if ((check % 3 == 0) && !((check / 3) & ((check / 3) - 1))) {
val.m_BitValue = m[i]; return IntegerEncodedValue(IntegerEncoding::Trit, Popcnt(check / 3 - 1));
val.m_Qus32Value = q[i]; }
// Is maxVal of the type 5*2^n - 1?
if ((check % 5 == 0) && !((check / 5) & ((check / 5) - 1))) {
return IntegerEncodedValue(IntegerEncoding::Qus32, Popcnt(check / 5 - 1));
}
// Apparently it can't be represented with a bounded s32eger sequence...
// just iterate.
maxVal--;
}
return IntegerEncodedValue(IntegerEncoding::JustBits, 0);
}
// Fills result with the values that are encoded in the given
// bitstream. We must know beforehand what the maximum possible
// value is, and how many values we're decoding.
static void DecodeIntegerSequence(std::vector<IntegerEncodedValue>& result, InputBitStream& bits,
u32 maxRange, u32 nValues) {
// Determine encoding parameters
IntegerEncodedValue val = CreateEncoding(maxRange);
// Start decoding
u32 nValsDecoded = 0;
while (nValsDecoded < nValues) {
switch (val.encoding) {
case IntegerEncoding::Qus32:
DecodeQus32Block(bits, result, val.num_bits);
nValsDecoded += 3;
break;
case IntegerEncoding::Trit:
DecodeTritBlock(bits, result, val.num_bits);
nValsDecoded += 5;
break;
case IntegerEncoding::JustBits:
val.bit_value = bits.ReadBits(val.num_bits);
result.push_back(val); result.push_back(val);
nValsDecoded++;
break;
} }
} }
}; }
namespace ASTCC { namespace ASTCC {
@ -405,7 +371,7 @@ struct TexelWeightParams {
nIdxs *= 2; nIdxs *= 2;
} }
return IntegerEncodedValue::CreateEncoding(m_MaxWeight).GetBitLength(nIdxs); return CreateEncoding(m_MaxWeight).GetBitLength(nIdxs);
} }
u32 GetNumWeightValues() const { u32 GetNumWeightValues() const {
@ -814,12 +780,12 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
// figure out the max value for each of them... // figure out the max value for each of them...
u32 range = 256; u32 range = 256;
while (--range > 0) { while (--range > 0) {
IntegerEncodedValue val = IntegerEncodedValue::CreateEncoding(range); IntegerEncodedValue val = CreateEncoding(range);
u32 bitLength = val.GetBitLength(nValues); u32 bitLength = val.GetBitLength(nValues);
if (bitLength <= nBitsForColorData) { if (bitLength <= nBitsForColorData) {
// Find the smallest possible range that matches the given encoding // Find the smallest possible range that matches the given encoding
while (--range > 0) { while (--range > 0) {
IntegerEncodedValue newval = IntegerEncodedValue::CreateEncoding(range); IntegerEncodedValue newval = CreateEncoding(range);
if (!newval.MatchesEncoding(val)) { if (!newval.MatchesEncoding(val)) {
break; break;
} }
@ -834,7 +800,7 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
// We now have enough to decode our s32eger sequence. // We now have enough to decode our s32eger sequence.
std::vector<IntegerEncodedValue> decodedColorValues; std::vector<IntegerEncodedValue> decodedColorValues;
InputBitStream colorStream(data); InputBitStream colorStream(data);
IntegerEncodedValue::DecodeIntegerSequence(decodedColorValues, colorStream, range, nValues); DecodeIntegerSequence(decodedColorValues, colorStream, range, nValues);
// Once we have the decoded values, we need to dequantize them to the 0-255 range // Once we have the decoded values, we need to dequantize them to the 0-255 range
// This procedure is outlined in ASTC spec C.2.13 // This procedure is outlined in ASTC spec C.2.13
@ -846,8 +812,8 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
} }
const IntegerEncodedValue& val = *itr; const IntegerEncodedValue& val = *itr;
u32 bitlen = val.BaseBitLength(); u32 bitlen = val.num_bits;
u32 bitval = val.GetBitValue(); u32 bitval = val.bit_value;
assert(bitlen >= 1); assert(bitlen >= 1);
@ -855,7 +821,7 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
// A is just the lsb replicated 9 times. // A is just the lsb replicated 9 times.
A = Replicate(bitval & 1, 1, 9); A = Replicate(bitval & 1, 1, 9);
switch (val.GetEncoding()) { switch (val.encoding) {
// Replicate bits // Replicate bits
case IntegerEncoding::JustBits: case IntegerEncoding::JustBits:
out[outIdx++] = Replicate(bitval, bitlen, 8); out[outIdx++] = Replicate(bitval, bitlen, 8);
@ -864,7 +830,7 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
// Use algorithm in C.2.13 // Use algorithm in C.2.13
case IntegerEncoding::Trit: { case IntegerEncoding::Trit: {
D = val.GetTritValue(); D = val.trit_value;
switch (bitlen) { switch (bitlen) {
case 1: { case 1: {
@ -915,7 +881,7 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
case IntegerEncoding::Qus32: { case IntegerEncoding::Qus32: {
D = val.GetQus32Value(); D = val.qus32_value;
switch (bitlen) { switch (bitlen) {
case 1: { case 1: {
@ -956,9 +922,9 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
} // switch(bitlen) } // switch(bitlen)
} // case IntegerEncoding::Qus32 } // case IntegerEncoding::Qus32
break; break;
} // switch(val.GetEncoding()) } // switch(val.encoding)
if (val.GetEncoding() != IntegerEncoding::JustBits) { if (val.encoding != IntegerEncoding::JustBits) {
u32 T = D * C + B; u32 T = D * C + B;
T ^= A; T ^= A;
T = (A & 0x80) | (T >> 2); T = (A & 0x80) | (T >> 2);
@ -973,20 +939,20 @@ static void DecodeColorValues(u32* out, u8* data, const u32* modes, const u32 nP
} }
static u32 UnquantizeTexelWeight(const IntegerEncodedValue& val) { static u32 UnquantizeTexelWeight(const IntegerEncodedValue& val) {
u32 bitval = val.GetBitValue(); u32 bitval = val.bit_value;
u32 bitlen = val.BaseBitLength(); u32 bitlen = val.num_bits;
u32 A = Replicate(bitval & 1, 1, 7); u32 A = Replicate(bitval & 1, 1, 7);
u32 B = 0, C = 0, D = 0; u32 B = 0, C = 0, D = 0;
u32 result = 0; u32 result = 0;
switch (val.GetEncoding()) { switch (val.encoding) {
case IntegerEncoding::JustBits: case IntegerEncoding::JustBits:
result = Replicate(bitval, bitlen, 6); result = Replicate(bitval, bitlen, 6);
break; break;
case IntegerEncoding::Trit: { case IntegerEncoding::Trit: {
D = val.GetTritValue(); D = val.trit_value;
assert(D < 3); assert(D < 3);
switch (bitlen) { switch (bitlen) {
@ -1018,7 +984,7 @@ static u32 UnquantizeTexelWeight(const IntegerEncodedValue& val) {
} break; } break;
case IntegerEncoding::Qus32: { case IntegerEncoding::Qus32: {
D = val.GetQus32Value(); D = val.qus32_value;
assert(D < 5); assert(D < 5);
switch (bitlen) { switch (bitlen) {
@ -1044,7 +1010,7 @@ static u32 UnquantizeTexelWeight(const IntegerEncodedValue& val) {
} break; } break;
} }
if (val.GetEncoding() != IntegerEncoding::JustBits && bitlen > 0) { if (val.encoding != IntegerEncoding::JustBits && bitlen > 0) {
// Decode the value... // Decode the value...
result = D * C + B; result = D * C + B;
result ^= A; result ^= A;
@ -1562,9 +1528,8 @@ static void DecompressBlock(const u8 inBuf[16], const u32 blockWidth, const u32
std::vector<IntegerEncodedValue> texelWeightValues; std::vector<IntegerEncodedValue> texelWeightValues;
InputBitStream weightStream(texelWeightData); InputBitStream weightStream(texelWeightData);
IntegerEncodedValue::DecodeIntegerSequence(texelWeightValues, weightStream, DecodeIntegerSequence(texelWeightValues, weightStream, weightParams.m_MaxWeight,
weightParams.m_MaxWeight, weightParams.GetNumWeightValues());
weightParams.GetNumWeightValues());
// Blocks can be at most 12x12, so we can have as many as 144 weights // Blocks can be at most 12x12, so we can have as many as 144 weights
u32 weights[2][144]; u32 weights[2][144];