00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifdef _WIN32
00013 #include <windows.h>
00014 #endif
00015
00016 #include <assert.h>
00017 #include <map>
00018
00019 using namespace std;
00020
00021 #include "lzhl.h"
00022 #include "lzhl_tcp.h"
00023
00024 typedef unsigned char BYTE;
00025
00026 struct LZHL_SOCKET {
00027 public:
00028 LZHL_CHANDLE ch;
00029 LZHL_DHANDLE dh;
00030
00031 BYTE* dBuf;
00032 int dSz;
00033 int dDisp;
00034
00035 public:
00036 LZHL_SOCKET() {
00037 ch = LZHL_CHANDLE_NULL;
00038 dh = LZHL_DHANDLE_NULL;
00039 dBuf = 0;
00040 }
00041 };
00042
00043 typedef map< SOCKET, LZHL_SOCKET > GlobalMapType;
00044 static GlobalMapType globalMap;
00045
00046 SOCKET lzhl_socket( int af, int type, int protocol ) {
00047 SOCKET sock = socket( af, type, protocol );
00048
00049 if ( sock >= 0 )
00050 globalMap.insert( GlobalMapType::value_type( sock, LZHL_SOCKET() ) );
00051
00052 return sock;
00053 }
00054
00055 SOCKET lzhl_accept( SOCKET s, struct sockaddr* addr, int* addrlen ) {
00056 SOCKET sock = accept( s, addr, addrlen );
00057
00058 if ( sock >= 0 )
00059 globalMap.insert( GlobalMapType::value_type( sock, LZHL_SOCKET() ) );
00060
00061 return sock;
00062 }
00063
00064 static void _putInt( BYTE*& p, unsigned int val ) {
00065 for(;;) {
00066 if ( val <= 127 ) {
00067 *p++ = (BYTE)val;
00068 break;
00069
00070 } else {
00071 *p++ = (BYTE)( 0x80 | ( val & 0x7F ) );
00072 val >>= 7;
00073 }
00074 }
00075 }
00076
00077
00078 static int _getInt( BYTE*& p, int sz, unsigned int* val ) {
00079 unsigned int bits = 0;
00080 int nBits = 0;
00081
00082 for(;;) {
00083 if ( sz == 0 )
00084 return 1;
00085
00086 BYTE c = *p++;
00087 --sz;
00088 bits |= ( ( c & 0x7F ) << nBits );
00089 nBits += 7;
00090
00091 if ( nBits > 35 )
00092 return 2;
00093
00094 if ( ( c & 0x80 ) == 0 ) {
00095 *val = bits;
00096 return 0;
00097 }
00098 }
00099 }
00100
00101 int lzhl_send( SOCKET sock, const char* data, int dataSz, int flags ) {
00102 GlobalMapType::iterator iter = globalMap.find( sock );
00103 if ( iter != globalMap.end() ) {
00104 LZHL_CHANDLE& ch = (*iter).second.ch;
00105
00106 if ( ch == LZHL_CHANDLE_NULL )
00107 ch = LZHLCreateCompressor();
00108
00109 size_t maxSz = 10 + LZHLCompressorCalcMaxBuf( dataSz );
00110 BYTE* buf = new BYTE[ maxSz ];
00111 size_t compSz = LZHLCompress( ch, buf + 10, data, dataSz );
00112
00113 int dSz;
00114
00115 {
00116 BYTE tmp[ 5 ];
00117 BYTE* p = tmp;
00118 _putInt( p, compSz );
00119 int szSz = p - tmp;
00120 dSz = szSz;
00121 assert( dSz <= 10 );
00122 memcpy( buf + 10 - dSz, tmp, szSz );
00123
00124 p = tmp;
00125 _putInt( p, dataSz );
00126 szSz = p - tmp;
00127 dSz += szSz;
00128 assert( dSz <= 10 );
00129 memcpy( buf + 10 - dSz, tmp, szSz );
00130 }
00131
00132 {
00133 BYTE* p = buf + 10 - dSz;
00134 int sz = dSz + compSz;
00135
00136 while( sz ) {
00137 int bytes = send( sock, (char*)( buf + 10 - dSz ), dSz + compSz, flags );
00138
00139 if ( bytes < 0 )
00140 return bytes;
00141
00142 assert( bytes <= sz );
00143 sz -= bytes;
00144 p += bytes;
00145 }
00146 }
00147
00148 delete [] buf;
00149 return dataSz;
00150 } else {
00151 return send( sock, data, dataSz, flags );
00152 }
00153 }
00154
00155 int lzhl_recv( SOCKET sock, char* buf, int bufSz, int flags )
00156 {
00157 GlobalMapType::iterator iter = globalMap.find( sock );
00158 if ( iter != globalMap.end() ) {
00159 LZHL_SOCKET& ls = (*iter).second;
00160 LZHL_DHANDLE& dh = ls.dh;
00161
00162 if ( dh == LZHL_DHANDLE_NULL )
00163 dh = LZHLCreateDecompressor();
00164
00165 if ( ls.dBuf == 0 ) {
00166 unsigned int dataSz, compSz;
00167 BYTE tmp[ 10 ];
00168 int bytesRead = 0;
00169 int hdrSz;
00170
00171 for (;;) {
00172 int bytes = recv( sock, (char*)( tmp + bytesRead ), 1, flags );
00173
00174 if ( bytes <= 0 )
00175 return bytes;
00176
00177 bytesRead += bytes;
00178 BYTE* p = tmp;
00179
00180 int err = _getInt( p, bytesRead, &dataSz );
00181
00182 if ( err == 1 )
00183 continue;
00184 else if ( err > 0 )
00185 return -1;
00186
00187 err = _getInt( p, bytesRead - ( p - tmp ), &compSz );
00188
00189 if ( err == 1 )
00190 continue;
00191 else if ( err > 0 )
00192 return -1;
00193
00194 hdrSz = p - tmp;
00195 break;
00196 }
00197
00198 BYTE* compBuf = new BYTE[ compSz ];
00199 bytesRead -= hdrSz;
00200 memcpy( compBuf, tmp + hdrSz, bytesRead );
00201 for (;;) {
00202 int bytes = recv( sock, (char*)( compBuf + bytesRead ), compSz - bytesRead, flags );
00203 if ( bytes <= 0 ) {
00204 delete [] compBuf;
00205 return bytes;
00206 }
00207
00208 bytesRead += bytes;
00209 if ( bytesRead == compSz )
00210 break;
00211 }
00212
00213 ls.dSz = dataSz;
00214 ls.dDisp = 0;
00215 ls.dBuf = new BYTE[ dataSz ];
00216
00217 int Ok = LZHLDecompress( dh, ls.dBuf, &dataSz, compBuf, &compSz );
00218 delete [] compBuf;
00219
00220 if ( !Ok ) {
00221 delete [] ls.dBuf;
00222 ls.dBuf = 0;
00223 return -1;
00224 }
00225
00226 }
00227
00228 assert( ls.dBuf );
00229 int sz = min( bufSz, ls.dSz - ls.dDisp );
00230 memcpy( buf, ls.dBuf + ls.dDisp, sz );
00231 ls.dDisp += sz;
00232
00233 if ( ls.dDisp == ls.dSz ) {
00234 delete [] ls.dBuf;
00235 ls.dBuf = 0;
00236 }
00237 return sz;
00238
00239 } else {
00240 return recv( sock, buf, bufSz, flags );
00241 }
00242 }
00243
00244 int lzhl_closesocket( SOCKET sock ) {
00245 GlobalMapType::iterator iter = globalMap.find( sock );
00246
00247 if ( iter != globalMap.end() ) {
00248 LZHL_CHANDLE ch = (*iter).second.ch;
00249
00250 if ( ch != LZHL_CHANDLE_NULL )
00251 LZHLDestroyCompressor( ch );
00252
00253 LZHL_DHANDLE dh = (*iter).second.dh;
00254
00255 if ( dh != LZHL_DHANDLE_NULL )
00256 LZHLDestroyDecompressor( dh );
00257
00258 delete [] (*iter).second.dBuf;
00259 }
00260 return closesocket( sock );
00261 }