00001
00002
00003
00004
00005
00006
00007
00008 #include "pch.h"
00009
00010 #include "zinflate.h"
00011 #include "secblock.h"
00012 #include "smartptr.h"
00013
00014 NAMESPACE_BEGIN(CryptoPP)
00015
00016 struct CodeLessThan
00017 {
00018 inline bool operator()(CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
00019 {return lhs < rhs.code;}
00020
00021 inline bool operator()(const CryptoPP::HuffmanDecoder::CodeInfo &lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
00022 {return lhs.code < rhs.code;}
00023 };
00024
00025 inline bool LowFirstBitReader::FillBuffer(unsigned int length)
00026 {
00027 while (m_bitsBuffered < length)
00028 {
00029 byte b;
00030 if (!m_store.Get(b))
00031 return false;
00032 m_buffer |= (unsigned long)b << m_bitsBuffered;
00033 m_bitsBuffered += 8;
00034 }
00035 assert(m_bitsBuffered <= sizeof(unsigned long)*8);
00036 return true;
00037 }
00038
00039 inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
00040 {
00041 bool result = FillBuffer(length);
00042 CRYPTOPP_UNUSED(result); assert(result);
00043 return m_buffer & (((unsigned long)1 << length) - 1);
00044 }
00045
00046 inline void LowFirstBitReader::SkipBits(unsigned int length)
00047 {
00048 assert(m_bitsBuffered >= length);
00049 m_buffer >>= length;
00050 m_bitsBuffered -= length;
00051 }
00052
00053 inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
00054 {
00055 unsigned long result = PeekBits(length);
00056 SkipBits(length);
00057 return result;
00058 }
00059
00060 inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
00061 {
00062 return code << (MAX_CODE_BITS - codeBits);
00063 }
00064
00065 void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
00066 {
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082 if (nCodes == 0)
00083 throw Err("null code");
00084
00085 m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);
00086
00087 if (m_maxCodeBits > MAX_CODE_BITS)
00088 throw Err("code length exceeds maximum");
00089
00090 if (m_maxCodeBits == 0)
00091 throw Err("null code");
00092
00093
00094 SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1);
00095 std::fill(blCount.begin(), blCount.end(), 0);
00096 unsigned int i;
00097 for (i=0; i<nCodes; i++)
00098 blCount[codeBits[i]]++;
00099
00100
00101 code_t code = 0;
00102 SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1);
00103 nextCode[1] = 0;
00104 for (i=2; i<=m_maxCodeBits; i++)
00105 {
00106
00107 if (code > code + blCount[i-1])
00108 throw Err("codes oversubscribed");
00109 code += blCount[i-1];
00110 if (code > (code << 1))
00111 throw Err("codes oversubscribed");
00112 code <<= 1;
00113 nextCode[i] = code;
00114 }
00115
00116
00117 const unsigned long long shiftedMaxCode = (1ULL << m_maxCodeBits);
00118 if (code > shiftedMaxCode - blCount[m_maxCodeBits])
00119 throw Err("codes oversubscribed");
00120 else if (m_maxCodeBits != 1 && code < shiftedMaxCode - blCount[m_maxCodeBits])
00121 throw Err("codes incomplete");
00122
00123
00124 m_codeToValue.resize(nCodes - blCount[0]);
00125 unsigned int j=0;
00126 for (i=0; i<nCodes; i++)
00127 {
00128 unsigned int len = codeBits[i];
00129 if (len != 0)
00130 {
00131 code = NormalizeCode(nextCode[len]++, len);
00132 m_codeToValue[j].code = code;
00133 m_codeToValue[j].len = len;
00134 m_codeToValue[j].value = i;
00135 j++;
00136 }
00137 }
00138 std::sort(m_codeToValue.begin(), m_codeToValue.end());
00139
00140
00141 m_cacheBits = STDMIN(9U, m_maxCodeBits);
00142 m_cacheMask = (1 << m_cacheBits) - 1;
00143 m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits);
00144 assert(m_normalizedCacheMask == BitReverse(m_cacheMask));
00145
00146 const unsigned long long shiftedCache = (1ULL << m_cacheBits);
00147 assert(shiftedCache <= SIZE_MAX);
00148 if (m_cache.size() != shiftedCache)
00149 m_cache.resize((size_t)shiftedCache);
00150
00151 for (i=0; i<m_cache.size(); i++)
00152 m_cache[i].type = 0;
00153 }
00154
00155 void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const
00156 {
00157 normalizedCode &= m_normalizedCacheMask;
00158 const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1);
00159 if (codeInfo.len <= m_cacheBits)
00160 {
00161 entry.type = 1;
00162 entry.value = codeInfo.value;
00163 entry.len = codeInfo.len;
00164 }
00165 else
00166 {
00167 entry.begin = &codeInfo;
00168 const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1);
00169 if (codeInfo.len == last->len)
00170 {
00171 entry.type = 2;
00172 entry.len = codeInfo.len;
00173 }
00174 else
00175 {
00176 entry.type = 3;
00177 entry.end = last+1;
00178 }
00179 }
00180 }
00181
00182 inline unsigned int HuffmanDecoder::Decode(code_t code, value_t &value) const
00183 {
00184 assert(m_codeToValue.size() > 0);
00185 LookupEntry &entry = m_cache[code & m_cacheMask];
00186
00187 code_t normalizedCode = 0;
00188 if (entry.type != 1)
00189 normalizedCode = BitReverse(code);
00190
00191 if (entry.type == 0)
00192 FillCacheEntry(entry, normalizedCode);
00193
00194 if (entry.type == 1)
00195 {
00196 value = entry.value;
00197 return entry.len;
00198 }
00199 else
00200 {
00201 const CodeInfo &codeInfo = (entry.type == 2)
00202 ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
00203 : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1);
00204 value = codeInfo.value;
00205 return codeInfo.len;
00206 }
00207 }
00208
00209 bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
00210 {
00211 bool result = reader.FillBuffer(m_maxCodeBits);
00212 if(!result) return false;
00213
00214 unsigned int codeBits = Decode(reader.PeekBuffer(), value);
00215 if (codeBits > reader.BitsBuffered())
00216 return false;
00217 reader.SkipBits(codeBits);
00218 return true;
00219 }
00220
00221
00222
00223 Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation)
00224 : AutoSignaling<Filter>(propagation)
00225 , m_state(PRE_STREAM), m_repeat(repeat), m_eof(0), m_wrappedAround(0)
00226 , m_blockType(0xff), m_storedLen(0xffff), m_nextDecode(), m_literal(0)
00227 , m_distance(0), m_reader(m_inQueue), m_current(0), m_lastFlush(0)
00228 {
00229 Detach(attachment);
00230 }
00231
00232 void Inflator::IsolatedInitialize(const NameValuePairs ¶meters)
00233 {
00234 m_state = PRE_STREAM;
00235 parameters.GetValue("Repeat", m_repeat);
00236 m_inQueue.Clear();
00237 m_reader.SkipBits(m_reader.BitsBuffered());
00238 }
00239
00240 void Inflator::OutputByte(byte b)
00241 {
00242 m_window[m_current++] = b;
00243 if (m_current == m_window.size())
00244 {
00245 ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
00246 m_lastFlush = 0;
00247 m_current = 0;
00248 m_wrappedAround = true;
00249 }
00250 }
00251
00252 void Inflator::OutputString(const byte *string, size_t length)
00253 {
00254 while (length)
00255 {
00256 size_t len = UnsignedMin(length, m_window.size() - m_current);
00257 memcpy(m_window + m_current, string, len);
00258 m_current += len;
00259 if (m_current == m_window.size())
00260 {
00261 ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
00262 m_lastFlush = 0;
00263 m_current = 0;
00264 m_wrappedAround = true;
00265 }
00266 string += len;
00267 length -= len;
00268 }
00269 }
00270
00271 void Inflator::OutputPast(unsigned int length, unsigned int distance)
00272 {
00273 size_t start;
00274 if (distance <= m_current)
00275 start = m_current - distance;
00276 else if (m_wrappedAround && distance <= m_window.size())
00277 start = m_current + m_window.size() - distance;
00278 else
00279 throw BadBlockErr();
00280
00281 if (start + length > m_window.size())
00282 {
00283 for (; start < m_window.size(); start++, length--)
00284 OutputByte(m_window[start]);
00285 start = 0;
00286 }
00287
00288 if (start + length > m_current || m_current + length >= m_window.size())
00289 {
00290 while (length--)
00291 OutputByte(m_window[start++]);
00292 }
00293 else
00294 {
00295 memcpy(m_window + m_current, m_window + start, length);
00296 m_current += length;
00297 }
00298 }
00299
00300 size_t Inflator::Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
00301 {
00302 if (!blocking)
00303 throw BlockingInputOnly("Inflator");
00304
00305 LazyPutter lp(m_inQueue, inString, length);
00306 ProcessInput(messageEnd != 0);
00307
00308 if (messageEnd)
00309 if (!(m_state == PRE_STREAM || m_state == AFTER_END))
00310 throw UnexpectedEndErr();
00311
00312 Output(0, NULL, 0, messageEnd, blocking);
00313 return 0;
00314 }
00315
00316 bool Inflator::IsolatedFlush(bool hardFlush, bool blocking)
00317 {
00318 if (!blocking)
00319 throw BlockingInputOnly("Inflator");
00320
00321 if (hardFlush)
00322 ProcessInput(true);
00323 FlushOutput();
00324
00325 return false;
00326 }
00327
00328 void Inflator::ProcessInput(bool flush)
00329 {
00330 while (true)
00331 {
00332 switch (m_state)
00333 {
00334 case PRE_STREAM:
00335 if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize())
00336 return;
00337 ProcessPrestreamHeader();
00338 m_state = WAIT_HEADER;
00339 m_wrappedAround = false;
00340 m_current = 0;
00341 m_lastFlush = 0;
00342 m_window.New(1 << GetLog2WindowSize());
00343 break;
00344 case WAIT_HEADER:
00345 {
00346
00347 const size_t MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15);
00348 if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE))
00349 return;
00350 DecodeHeader();
00351 break;
00352 }
00353 case DECODING_BODY:
00354 if (!DecodeBody())
00355 return;
00356 break;
00357 case POST_STREAM:
00358 if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize())
00359 return;
00360 ProcessPoststreamTail();
00361 m_state = m_repeat ? PRE_STREAM : AFTER_END;
00362 Output(0, NULL, 0, GetAutoSignalPropagation(), true);
00363 if (m_inQueue.IsEmpty())
00364 return;
00365 break;
00366 case AFTER_END:
00367 m_inQueue.TransferTo(*AttachedTransformation());
00368 return;
00369 }
00370 }
00371 }
00372
00373 void Inflator::DecodeHeader()
00374 {
00375 if (!m_reader.FillBuffer(3))
00376 throw UnexpectedEndErr();
00377 m_eof = m_reader.GetBits(1) != 0;
00378 m_blockType = (byte)m_reader.GetBits(2);
00379 switch (m_blockType)
00380 {
00381 case 0:
00382 {
00383 m_reader.SkipBits(m_reader.BitsBuffered() % 8);
00384 if (!m_reader.FillBuffer(32))
00385 throw UnexpectedEndErr();
00386 m_storedLen = (word16)m_reader.GetBits(16);
00387 word16 nlen = (word16)m_reader.GetBits(16);
00388 if (nlen != (word16)~m_storedLen)
00389 throw BadBlockErr();
00390 break;
00391 }
00392 case 1:
00393 m_nextDecode = LITERAL;
00394 break;
00395 case 2:
00396 {
00397 if (!m_reader.FillBuffer(5+5+4))
00398 throw UnexpectedEndErr();
00399 unsigned int hlit = m_reader.GetBits(5);
00400 unsigned int hdist = m_reader.GetBits(5);
00401 unsigned int hclen = m_reader.GetBits(4);
00402
00403 FixedSizeSecBlock<unsigned int, 286+32> codeLengths;
00404 unsigned int i;
00405 static const unsigned int border[] = {
00406 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
00407 std::fill(codeLengths.begin(), codeLengths+19, 0);
00408 for (i=0; i<hclen+4; i++)
00409 codeLengths[border[i]] = m_reader.GetBits(3);
00410
00411 try
00412 {
00413 HuffmanDecoder codeLengthDecoder(codeLengths, 19);
00414 for (i = 0; i < hlit+257+hdist+1; )
00415 {
00416 unsigned int k = 0, count = 0, repeater = 0;
00417 bool result = codeLengthDecoder.Decode(m_reader, k);
00418 if (!result)
00419 throw UnexpectedEndErr();
00420 if (k <= 15)
00421 {
00422 count = 1;
00423 repeater = k;
00424 }
00425 else switch (k)
00426 {
00427 case 16:
00428 if (!m_reader.FillBuffer(2))
00429 throw UnexpectedEndErr();
00430 count = 3 + m_reader.GetBits(2);
00431 if (i == 0)
00432 throw BadBlockErr();
00433 repeater = codeLengths[i-1];
00434 break;
00435 case 17:
00436 if (!m_reader.FillBuffer(3))
00437 throw UnexpectedEndErr();
00438 count = 3 + m_reader.GetBits(3);
00439 repeater = 0;
00440 break;
00441 case 18:
00442 if (!m_reader.FillBuffer(7))
00443 throw UnexpectedEndErr();
00444 count = 11 + m_reader.GetBits(7);
00445 repeater = 0;
00446 break;
00447 }
00448 if (i + count > hlit+257+hdist+1)
00449 throw BadBlockErr();
00450 std::fill(codeLengths + i, codeLengths + i + count, repeater);
00451 i += count;
00452 }
00453 m_dynamicLiteralDecoder.Initialize(codeLengths, hlit+257);
00454 if (hdist == 0 && codeLengths[hlit+257] == 0)
00455 {
00456 if (hlit != 0)
00457 throw BadBlockErr();
00458 }
00459 else
00460 m_dynamicDistanceDecoder.Initialize(codeLengths+hlit+257, hdist+1);
00461 m_nextDecode = LITERAL;
00462 }
00463 catch (HuffmanDecoder::Err &)
00464 {
00465 throw BadBlockErr();
00466 }
00467 break;
00468 }
00469 default:
00470 throw BadBlockErr();
00471 }
00472 m_state = DECODING_BODY;
00473 }
00474
00475 bool Inflator::DecodeBody()
00476 {
00477 bool blockEnd = false;
00478 switch (m_blockType)
00479 {
00480 case 0:
00481 assert(m_reader.BitsBuffered() == 0);
00482 while (!m_inQueue.IsEmpty() && !blockEnd)
00483 {
00484 size_t size;
00485 const byte *block = m_inQueue.Spy(size);
00486 size = UnsignedMin(m_storedLen, size);
00487 assert(size <= 0xffff);
00488
00489 OutputString(block, size);
00490 m_inQueue.Skip(size);
00491 m_storedLen = m_storedLen - (word16)size;
00492 if (m_storedLen == 0)
00493 blockEnd = true;
00494 }
00495 break;
00496 case 1:
00497 case 2:
00498 static const unsigned int lengthStarts[] = {
00499 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
00500 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
00501 static const unsigned int lengthExtraBits[] = {
00502 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
00503 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
00504 static const unsigned int distanceStarts[] = {
00505 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
00506 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
00507 8193, 12289, 16385, 24577};
00508 static const unsigned int distanceExtraBits[] = {
00509 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
00510 7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
00511 12, 12, 13, 13};
00512
00513 const HuffmanDecoder& literalDecoder = GetLiteralDecoder();
00514 const HuffmanDecoder& distanceDecoder = GetDistanceDecoder();
00515
00516 switch (m_nextDecode)
00517 {
00518 case LITERAL:
00519 while (true)
00520 {
00521 if (!literalDecoder.Decode(m_reader, m_literal))
00522 {
00523 m_nextDecode = LITERAL;
00524 break;
00525 }
00526 if (m_literal < 256)
00527 OutputByte((byte)m_literal);
00528 else if (m_literal == 256)
00529 {
00530 blockEnd = true;
00531 break;
00532 }
00533 else
00534 {
00535 if (m_literal > 285)
00536 throw BadBlockErr();
00537 unsigned int bits;
00538 case LENGTH_BITS:
00539 bits = lengthExtraBits[m_literal-257];
00540 if (!m_reader.FillBuffer(bits))
00541 {
00542 m_nextDecode = LENGTH_BITS;
00543 break;
00544 }
00545 m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257];
00546 case DISTANCE:
00547 if (!distanceDecoder.Decode(m_reader, m_distance))
00548 {
00549 m_nextDecode = DISTANCE;
00550 break;
00551 }
00552 case DISTANCE_BITS:
00553 bits = distanceExtraBits[m_distance];
00554 if (!m_reader.FillBuffer(bits))
00555 {
00556 m_nextDecode = DISTANCE_BITS;
00557 break;
00558 }
00559 m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance];
00560 OutputPast(m_literal, m_distance);
00561 }
00562 }
00563 break;
00564 default:
00565 assert(0);
00566 }
00567 }
00568 if (blockEnd)
00569 {
00570 if (m_eof)
00571 {
00572 FlushOutput();
00573 m_reader.SkipBits(m_reader.BitsBuffered()%8);
00574 if (m_reader.BitsBuffered())
00575 {
00576
00577 SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8);
00578 for (unsigned int i=0; i<buffer.size(); i++)
00579 buffer[i] = (byte)m_reader.GetBits(8);
00580 m_inQueue.Unget(buffer, buffer.size());
00581 }
00582 m_state = POST_STREAM;
00583 }
00584 else
00585 m_state = WAIT_HEADER;
00586 }
00587 return blockEnd;
00588 }
00589
00590 void Inflator::FlushOutput()
00591 {
00592 if (m_state != PRE_STREAM)
00593 {
00594 assert(m_current >= m_lastFlush);
00595 ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush);
00596 m_lastFlush = m_current;
00597 }
00598 }
00599
00600 struct NewFixedLiteralDecoder
00601 {
00602 HuffmanDecoder * operator()() const
00603 {
00604 unsigned int codeLengths[288];
00605 std::fill(codeLengths + 0, codeLengths + 144, 8);
00606 std::fill(codeLengths + 144, codeLengths + 256, 9);
00607 std::fill(codeLengths + 256, codeLengths + 280, 7);
00608 std::fill(codeLengths + 280, codeLengths + 288, 8);
00609 member_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder);
00610 pDecoder->Initialize(codeLengths, 288);
00611 return pDecoder.release();
00612 }
00613 };
00614
00615 struct NewFixedDistanceDecoder
00616 {
00617 HuffmanDecoder * operator()() const
00618 {
00619 unsigned int codeLengths[32];
00620 std::fill(codeLengths + 0, codeLengths + 32, 5);
00621 member_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder);
00622 pDecoder->Initialize(codeLengths, 32);
00623 return pDecoder.release();
00624 }
00625 };
00626
00627 const HuffmanDecoder& Inflator::GetLiteralDecoder() const
00628 {
00629 return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedLiteralDecoder>().Ref() : m_dynamicLiteralDecoder;
00630 }
00631
00632 const HuffmanDecoder& Inflator::GetDistanceDecoder() const
00633 {
00634 return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedDistanceDecoder>().Ref() : m_dynamicDistanceDecoder;
00635 }
00636
00637 NAMESPACE_END