00001 #ifndef __J2K__LZH__LZHLDecompressor_CPP__
00002 #define __J2K__LZH__LZHLDecompressor_CPP__
00003
00004 LZHLDecompressor::LZHLDecompressor() {
00005 nBits = 0;
00006 bits = 0;
00007 }
00008
00009 LZHLDecompressor::~LZHLDecompressor() { }
00010
00011 inline int LZHLDecompressor::_get( const BYTE*& src, const BYTE* srcEnd, int n )
00012 {
00013 assert( n <= 8 );
00014 if ( nBits < n ) {
00015 if ( src >= srcEnd ) {
00016 nBits = 0;
00017 return -1;
00018 }
00019
00020 bits |= ( *src++ << ( 24 - nBits ) );
00021 nBits += 8;
00022 }
00023
00024 int ret = bits >> ( 32 - n );
00025 bits <<= n;
00026 nBits -= n;
00027 return ret;
00028 }
00029
00030 BOOL LZHLDecompressor::decompress( BYTE* dst, size_t* dstSz, const BYTE* src, size_t* srcSz )
00031 {
00032 BYTE* startDst = dst;
00033 const BYTE* startSrc = src;
00034 const BYTE* endSrc = src + *srcSz;
00035 const BYTE* endDst = dst + *dstSz;
00036 nBits = 0;
00037
00038 for (;;) {
00039 int grp = _get( src, endSrc, 4 );
00040 if ( grp < 0 ) {
00041 return FALSE;
00042 }
00043
00044 Group& group = groupTable[ grp ];
00045
00046 int symbol;
00047 int nBits = group.nBits;
00048
00049 if ( nBits == 0 ) {
00050 symbol = symbolTable[ group.pos ];
00051 } else {
00052 assert( nBits <= 8 );
00053 int got = _get( src, endSrc, nBits );
00054
00055 if ( got < 0 ) {
00056 return FALSE;
00057 }
00058
00059 int pos = group.pos + got;
00060
00061 if ( pos >= NHUFFSYMBOLS ) {
00062 return FALSE;
00063 }
00064
00065 symbol = symbolTable[ pos ];
00066 }
00067
00068 assert( symbol < NHUFFSYMBOLS );
00069 ++stat[ symbol ];
00070
00071 int matchOver;
00072 BOOL shift = FALSE;
00073
00074 if ( symbol < 256 ) {
00075 if ( dst >= endDst ) {
00076 return FALSE;
00077 }
00078
00079 *dst++ = (BYTE)symbol;
00080 _toBuf( symbol );
00081 continue;
00082
00083 } else if ( symbol == NHUFFSYMBOLS - 2 ) {
00084 HuffStatTmpStruct s[ NHUFFSYMBOLS ];
00085 makeSortedTmp( s );
00086
00087 for ( int i=0; i < NHUFFSYMBOLS ; ++i )
00088 symbolTable[ i ] = s[ i ].i;
00089
00090 int lastNBits = 0;
00091 int pos = 0;
00092
00093 for ( i=0; i < 16 ; ++i ) {
00094
00095 for ( int n=0 ;; ++n )
00096 if ( _get( src, endSrc, 1 ) )
00097 break;
00098
00099 lastNBits += n;
00100 groupTable[ i ].nBits = lastNBits;
00101 groupTable[ i ].pos = pos;
00102
00103 pos += 1 << lastNBits;
00104 }
00105
00106 assert( pos < NHUFFSYMBOLS + 255 );
00107 continue;
00108
00109
00110 } else if ( symbol == NHUFFSYMBOLS - 1 )
00111 break;
00112
00113 static struct MatchOverItem {
00114 int nExtraBits;
00115 int base;
00116 } _matchOverTable[] = {
00117 { 1, 8 },
00118 { 2, 10 },
00119 { 3, 14 },
00120 { 4, 22 },
00121 { 5, 38 },
00122 { 6, 70 },
00123 { 7, 134 },
00124 { 8, 262 }
00125 };
00126
00127 if ( symbol < 256 + 8 ) {
00128 matchOver = symbol - 256;
00129 } else {
00130 MatchOverItem* item = &_matchOverTable[ symbol - 256 - 8 ];
00131 int extra = _get( src, endSrc, item->nExtraBits );
00132
00133 if ( extra < 0 ) {
00134 return FALSE;
00135 }
00136
00137 matchOver = item->base + extra;
00138 }
00139
00140 int dispPrefix = _get( src, endSrc, 3 );
00141
00142 if ( dispPrefix < 0 ) {
00143 return FALSE;
00144 }
00145
00146 static struct DispItem {
00147 int nBits;
00148 int disp;
00149 } _dispTable[] = {
00150 { 0, 0 },
00151 { 0, 1 },
00152 { 1, 2 },
00153 { 2, 4 },
00154 { 3, 8 },
00155 { 4, 16 },
00156 { 5, 32 },
00157 { 6, 64 }
00158 };
00159
00160 DispItem* item = &_dispTable[ dispPrefix ];
00161 nBits = item->nBits + LZBUFBITS - 7;
00162
00163 int disp = 0;
00164 assert( nBits <= 16 );
00165
00166 if ( nBits > 8 ) {
00167 nBits -= 8;
00168 disp |= _get( src, endSrc, 8 ) << nBits;
00169 }
00170
00171 assert( nBits <= 8 );
00172 int got = _get( src, endSrc, nBits );
00173
00174 if ( got < 0 ) {
00175 return FALSE;
00176 }
00177
00178 disp |= got;
00179 disp += item->disp << (LZBUFBITS - 7);
00180 assert( disp >=0 && disp < LZBUFSIZE );
00181
00182 int matchLen = matchOver + LZMIN;
00183
00184 if ( dst + matchLen > endDst ) {
00185 return FALSE;
00186 }
00187
00188 int pos = bufPos - disp;
00189 if ( matchLen < disp ) {
00190 _bufCpy( dst, pos, matchLen );
00191 } else {
00192 _bufCpy( dst, pos, disp );
00193
00194 for ( int i=0; i < matchLen - disp; ++i ) {
00195 dst[ i + disp ] = dst[ i ];
00196 }
00197 }
00198
00199 _toBuf( dst, matchLen );
00200 dst += matchLen;
00201
00202 }
00203
00204 if ( dstSz )
00205 *dstSz -= dst - startDst;
00206
00207 if ( srcSz )
00208 *srcSz -= src - startSrc;
00209
00210 return TRUE;
00211 }
00212
00213 #endif