//-------------------------------------
//zsocket class
//
//Author: Mark Gordon (msg555)
//Date: 9/2/05
//Libraries: Ws2_32.lib
//-------------------------------------
/* Events
virtual void error(int num, std::string description) {}
virtual void statechange(SocketState oldstate) {}
virtual void connectionrequest(int requestid) {}
virtual void dataarrival(int bytessent) {}
*/
#ifndef zsocket_hpp
#define zsocket_hpp
//#include <ws2tcpip.h>
#include <windows.h>
#include <iostream>
#include <sstream>
#include <list>
#include <queue>
using namespace std;
// Default connect timeout; used to initialise m_timeout and compared against
// GetTickCount() differences (milliseconds).
#define DEFAULT_TIMEOUT 20000 //20 seconds
// Lifecycle states reported through state() and the statechange() event.
// NOTE(review): sckConnectionPending is not assigned anywhere in this file.
enum SocketState
{
sckClosed, sckListening, sckConnecting, sckConnectionPending, sckConnected
};
// Window message that WSAAsyncSelect posts to the hidden message window for
// socket events; handled in zsocket::msgproc.
#define SOCKET_MESSAGE (WM_USER + 0x401)
// Sent (wparam = SOCKET) to the message window so that msgproc calls
// WSAAsyncSelect for that socket on the message-loop thread.
#define SOCKET_ADDSELECT (WM_USER + 0x402)
// FIFO queue of x backed by std::list; used for the parallel send queues
// (q_str, q_file, q_pos).
#define zqueue(x) std::queue<x, std::list<x, std::allocator<x> > >
// zsocket: a TCP/IPv4 socket wrapper driven by WSAAsyncSelect notifications
// that are delivered to a hidden message window (see msgloop/msgproc).
// Derive from it and override the virtual event hooks (error, statechange,
// connectionrequest, dataarrival) to receive events.
// NOTE(review): instances should not be copied -- a copy would share the same
// socket handle and list links, and the implicit copy constructor does not
// update the live-socket count.
class zsocket
{
public:
    zsocket();
    // Virtual because the class is designed to be subclassed (the event hooks
    // are virtual): deleting a derived object through a zsocket* must also run
    // the derived destructor.
    virtual ~zsocket();
    //Connect functions ------------------------
    //Synchronous: block until connected, failed, or timed out (milliseconds)
    bool syncconnect(std::string zhost, int zport,int timeout);
    bool syncconnect(int timeout);
    bool syncconnect(std::string zhost, int zport);
    bool syncconnect();
    //Asynchronous: return immediately; the connect runs on a worker thread
    void sckconnect(std::string zhost, int zport, int timeout);
    void sckconnect(int timeout);
    void sckconnect(std::string zhost, int zport);
    void sckconnect();
    //------------------------------------------
    //Send functions (data is queued and sent as the socket becomes writable)
    void senddata(void * buf, int length);
    void sendstring(std::string str);
    void sendfile(std::string file); //Queues the contents of the named file
    void flush(); //Wait for all queued data to be sent for this socket
    //Receive functions
    std::string getdata();  //Returns the receive buffer and empties it
    std::string peekdata(); //Returns the receive buffer without emptying it
    //Listen
    bool scklisten(int zport); //socket begins to listen on zport. note that sck.port() is only used for remote ports
    bool scklisten();
    //Miscellaneous
    void close(); //Releases the current socket and creates a new one. Sets the socket's state to sckClosed
    void sckaccept(int requestid); //Accept a connection request to your listening socket. The requestid will be passed to your socket in the connectionrequest event.
    bool waitforconnection(int); //waits for a connection request and connects to it synchronously (0 = wait forever)
    //Properties (get)
    SocketState state() {return m_state;}
    int getlasttime() {return m_lasttime;}
    int timeout() {return m_timeout;}
    int maxsendbuffersize() {return m_sendbuffersize;}
    //Properties (set)
    void sethost(std::string zhost) {m_remotehost = zhost;}
    void setport(int zport) {m_port = zport;}
    void settimeout(int ztimeout) {m_timeout = ztimeout;}
    std::string localip();
    std::string localhost();
    int port();
    std::string remoteip();
    std::string remotehost();
    //Events (override in a derived class; the defaults do nothing)
    virtual void error(int num, std::string description) {}
    virtual void statechange(SocketState oldstate) {}
    virtual void connectionrequest(int requestid) {}
    virtual void dataarrival(int bytessent) {}
private:
    SOCKET sockid;              //Current Winsock handle
    SocketState m_state;
    zsocket * nextsck;          //Links in the static list of live sockets
    zsocket * prevsck;
    int m_lasttime;             //GetTickCount() of the last activity
    int m_timeout;              //Connect timeout (milliseconds)
    int m_recievebuffersize;    //SO_RCVBUF, read when connected
    int m_sendbuffersize;       //SO_SNDBUF, read when connected
    bool m_write;               //True once FD_WRITE says the socket is writable
    int sending;                //Re-entrancy guard for sendqueueddata()
    //Parallel send queues: payload (or file name), is-file flag, bytes already sent
    zqueue(std::string) q_str;
    zqueue(bool) q_file;
    zqueue(long long) q_pos;
    HANDLE q_hfile; //Only one file will ever be open at a time so we only need one handle
    std::string recvbuffer;
    int m_port;
    std::string m_remotehost;
    std::string m_remoteip;
    void createsocket();
    void destroysocket();
    void changestate(SocketState);
    void getbuffersizes();
    void getremoteinfo();
    void sendqueueddata();
    static HWND sckwnd;
    static int count;
    static zsocket * firstsck;
    static zsocket * lastsck;
    static DWORD msgthreadid;
    static DWORD WINAPI msgloop(LPVOID);
    static LRESULT CALLBACK msgproc (HWND, UINT, WPARAM, LPARAM);
    static DWORD WINAPI asyncconnect(LPVOID);
    static std::string errortostring(int error);
};
// Constructor: the first live zsocket initialises Winsock and starts the
// message-loop thread that owns the hidden notification window. Every
// instance links itself into the static list and creates its socket.
// The initializer list follows the member declaration order.
zsocket::zsocket() : sockid(INVALID_SOCKET), m_state(sckClosed), nextsck(NULL), prevsck(NULL), m_lasttime(0), m_timeout(DEFAULT_TIMEOUT), m_recievebuffersize(0), m_sendbuffersize(0), m_write(false), sending(0), q_hfile(INVALID_HANDLE_VALUE), m_port(0)
{
    count++;
    if (count == 1)
    {
        //Initiate Socket Message Window and loop
        WSADATA empty;
        WSAStartup(MAKEWORD(2, 2), &empty);
        firstsck = this;
        lastsck = this;
        HANDLE loopthread = CreateThread(NULL, 0, zsocket::msgloop, NULL, 0, NULL);
        if (loopthread != NULL)
        {
            // The thread runs independently; we never wait on it, so release
            // the handle now instead of leaking it.
            CloseHandle(loopthread);
            // Wait until msgloop has created the window (createsocket needs it).
            while (sckwnd == NULL)
                Sleep(1);
        }
        // If the thread could not be created, do not spin forever waiting for
        // a window that will never appear.
    }
    else
    {
        //Append this socket to the end of the live-socket list
        prevsck = lastsck;
        prevsck->nextsck = this;
        lastsck = this;
    }
    createsocket();
}
// Destructor: closes this object's socket, then either shuts Winsock down
// (when this was the last live zsocket) or unlinks this object from the
// list of live sockets. The message loop re-checks the count only when it
// next receives a message.
zsocket::~zsocket()
{
    count--;
    destroysocket();
    if (count != 0)
    {
        // Splice this node out of the doubly linked list.
        if (prevsck != NULL)
            prevsck->nextsck = nextsck;
        if (nextsck != NULL)
            nextsck->prevsck = prevsck;
        if (firstsck == this)
            firstsck = nextsck;
        if (lastsck == this)
            lastsck = prevsck;
        return;
    }
    // Last socket gone.
    WSACleanup();
    firstsck = NULL;
    lastsck = NULL;
}
void zsocket::createsocket()
{
sockid = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
SendMessage(sckwnd, SOCKET_ADDSELECT, (WPARAM)sockid, 0);
changestate(sckClosed);
}
void zsocket::destroysocket()
{
closesocket(sockid);
changestate(sckClosed);
}
// Switches to newstate and fires statechange(oldstate).
// Every state change discards queued outgoing data and waits for a fresh
// FD_WRITE. Entering sckConnected refreshes the buffer sizes and peer info;
// leaving sckConnected forgets the remote host/ip.
void zsocket::changestate(SocketState newstate)
{
    m_write = false;
    while (q_str.size())
    {
        q_str.pop();
        q_file.pop();
        q_pos.pop();
    }
    // Release a partially sent file as well. Previously its handle leaked and,
    // because sendqueueddata() only opens a file when q_hfile is
    // INVALID_HANDLE_VALUE, the next sendfile() would have kept streaming the
    // old file from its old position.
    if (q_hfile != INVALID_HANDLE_VALUE)
    {
        CloseHandle(q_hfile);
        q_hfile = INVALID_HANDLE_VALUE;
    }
    if (newstate == sckConnected)
    {
        m_lasttime = GetTickCount();
        getbuffersizes();
        getremoteinfo();
    }
    else if (m_state == sckConnected)
    {
        m_remotehost = "";
        m_remoteip = "";
    }
    SocketState oldstate = m_state;
    m_state = newstate;
    statechange(oldstate);
}
void zsocket::getbuffersizes()
{
int result;
int size = sizeof(int);
result = getsockopt(sockid, SOL_SOCKET, SO_RCVBUF, (char *) &m_recievebuffersize, &size);
if (result == SOCKET_ERROR)
{
close();
int errorcode = GetLastError();
error(errorcode, errortostring(errorcode));
return;
}
result = getsockopt(sockid, SOL_SOCKET, SO_SNDBUF, (char *) &m_sendbuffersize, &size);
if (result == SOCKET_ERROR)
{
close();
int errorcode = GetLastError();
error(errorcode, errortostring(errorcode));
return;
}
}
void zsocket::getremoteinfo()
{
sockaddr_in ri;
int size = sizeof(ri);
if(getpeername(sockid, (sockaddr *)&ri, &size))
{
return;
}
else
{
m_port = ntohs(ri.sin_port);
m_remoteip = (char *)inet_ntoa(ri.sin_addr);
if (m_remotehost.length() == 0)
m_remotehost = m_remoteip;
/* Works, but takes far too long
hostent * rh = gethostbyaddr((char *)(&(ri.sin_addr.S_un.S_addr)), 4, AF_INET);
if (rh)
m_remotehost = rh->h_name; */
}
}
void zsocket::close()
{
destroysocket();
createsocket();
}
bool zsocket::scklisten(int port)
{
m_port = port;
return scklisten();
}
bool zsocket::scklisten()
{
SOCKADDR_IN serverInfo;
serverInfo.sin_family = AF_INET;
serverInfo.sin_addr.s_addr = INADDR_ANY;
serverInfo.sin_port = htons(m_port);
if (bind(sockid, (LPSOCKADDR)&serverInfo, sizeof(struct sockaddr)) == SOCKET_ERROR)
return false;
if (listen(sockid, SOMAXCONN) == SOCKET_ERROR)
return false;
changestate(sckListening);
return true;
}
bool zsocket::waitforconnection(int timeout)
{
int starttime = GetTickCount();
while (GetTickCount() - starttime < timeout || timeout == 0)
{
SOCKET sck = accept(sockid, NULL, NULL);
if (sck != INVALID_SOCKET)
{
destroysocket();
sockid = sck;
changestate(sckConnected);
return true;
}
}
return false;
}
// Hands the caller everything received so far and leaves the buffer empty.
std::string zsocket::getdata()
{
    std::string drained;
    drained.swap(recvbuffer);
    return drained;
}
// Returns a copy of the receive buffer; unlike getdata() it is left intact.
std::string zsocket::peekdata()
{
    std::string snapshot(recvbuffer);
    return snapshot;
}
//***************************************************************************
//Connect functions. They are overloaded for ease of use; syncconnect() is
// eventually called no matter which connect function is originally called.
//syncconnect(...) is synchronous.
//sckconnect(...) is asynchronous.
//Note that the socket's state is changed to sckConnecting when one of the
// asynchronous functions is called. If the connection is successful the
// state is changed to sckConnected; if it is unsuccessful the state is
// changed to sckClosed.
bool zsocket::syncconnect(std::string zhost, int zport, int ztimeout)
{
m_remotehost = zhost;
m_port = zport;
m_timeout = ztimeout;
return syncconnect();
}
bool zsocket::syncconnect(int ztimeout)
{
m_timeout = ztimeout;
return syncconnect();
}
bool zsocket::syncconnect(std::string zhost, int zport)
{
m_remotehost = zhost;
m_port = zport;
return syncconnect();
}
bool zsocket::syncconnect()
{
//Close the socket so it is ready for connecting
close();
hostent * host = gethostbyname(m_remotehost.c_str());
if (host == NULL)
{
unsigned int addr = inet_addr(m_remotehost.c_str());
host = gethostbyaddr((char *)&addr, 4, AF_INET);
if (host == NULL)
return false;
}
sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = htons(m_port);
addr.sin_addr = *((in_addr *)host->h_addr);
memset(&(addr.sin_zero), 0, 8);
if (connect(sockid, (const sockaddr *)&addr, sizeof(addr)) == SOCKET_ERROR)
{
int result;
result = WSAGetLastError();
if (result != WSAEWOULDBLOCK)
{
close();
error(result, errortostring(result));
return false;
}
}
else
{
changestate(sckConnected);
return true;
}
int stime = GetTickCount();
while (m_state != sckConnected)
{
if (msgthreadid == GetCurrentThreadId())
{//Run message loop so this, and other sockets can continue to function
MSG sckmessages;
if (PeekMessage(&sckmessages, NULL, 0, 0, 0))
{
GetMessage(&sckmessages, NULL, 0, 0);
TranslateMessage(&sckmessages);
DispatchMessage(&sckmessages);
}
}
if (GetTickCount() - stime >= m_timeout)
{
close();
error(WSAETIMEDOUT, errortostring(WSAETIMEDOUT));
return false;
}
Sleep(1);
}
return true;
}
//Asynchronous
void zsocket::sckconnect(std::string zhost, int zport, int ztimeout)
{
m_remotehost = zhost;
m_port = zport;
m_timeout = ztimeout;
sckconnect();
}
void zsocket::sckconnect(int ztimeout)
{
m_timeout = ztimeout;
sckconnect();
}
void zsocket::sckconnect(std::string zhost, int zport)
{
m_remotehost = zhost;
m_port = zport;
sckconnect();
}
void zsocket::sckconnect()
{
changestate(sckConnecting);
CreateThread(NULL, 0, zsocket::asyncconnect, this, 0, NULL);
}
//Worker-thread entry point started by sckconnect(). It is static; param is
//the zsocket that started the thread, on which syncconnect() is run.
DWORD WINAPI zsocket::asyncconnect(LPVOID param)
{
    static_cast<zsocket *>(param)->syncconnect();
    return 0;
}
//***************************************************************************
void zsocket::senddata(void * buf, int len)
{
q_str.push(std::string((char *) buf, len));
q_file.push(false);
q_pos.push(0);
sendqueueddata();
}
void zsocket::sendstring(std::string str)
{
q_str.push(str);
q_file.push(false);
q_pos.push(0);
sendqueueddata();
}
void zsocket::sendfile(std::string file)
{
q_str.push(file);
q_file.push(true);
q_pos.push(0);
sendqueueddata();
}
// Blocks until every queued item has been sent, or until the socket leaves
// sckConnected (e.g. because an error closed it). The stray debug print
// ("Flushing" to std::cout) that library callers could not suppress is gone.
void zsocket::flush()
{
    while (q_str.size() && m_state == sckConnected)
    {
        if (msgthreadid == GetCurrentThreadId())
        {// On the message-loop thread nobody else will dispatch FD_WRITE, so
         // pump messages here to let this and other sockets keep functioning.
            MSG sckmessages;
            if (PeekMessage(&sckmessages, NULL, 0, 0, 0))
            {
                GetMessage(&sckmessages, NULL, 0, 0);
                TranslateMessage(&sckmessages);
                DispatchMessage(&sckmessages);
            }
        }
        Sleep(1);
    }
}
void zsocket::sendqueueddata()
{
if (m_write == false)
return;
if (sending++)
{
sending--;
return;
}
m_lasttime = GetTickCount();
while (q_str.size())
{
int res = 0;
long long q_len = 0;
//Open file
if (q_file.front() && q_hfile == INVALID_HANDLE_VALUE)
{
q_hfile = CreateFile(q_str.front().c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
}
//Get total enqued data length
if (q_file.front() && q_hfile != INVALID_HANDLE_VALUE)
{
long * cpy = new long[2];
cpy[0] = GetFileSize(q_hfile, (LPDWORD) &(cpy[1]));
memcpy(&q_len, cpy, sizeof(long long));
delete [] cpy;
}
else if (!q_file.front())
{
q_len = q_str.front().length();
}
//While the current sending position of the enqueued data is not the full length
while (q_pos.front() < q_len)
{
m_lasttime = GetTickCount();
int chunksize;
if ((q_len - q_pos.front()) > m_sendbuffersize) //Send chunk
chunksize = m_sendbuffersize;
else
chunksize = q_len - q_pos.front();
if (q_file.front())
{
char * buf = new char[chunksize];
ReadFile(q_hfile, buf, chunksize, (LPDWORD)&chunksize, NULL);
res = send(sockid, buf, chunksize, 0);
delete [] buf;
}
else
{
res = send(sockid, q_str.front().c_str() + q_pos.front(), chunksize, 0);
}
if (res == SOCKET_ERROR)
{
int err = WSAGetLastError();
if (err == WSAEWOULDBLOCK)
{
m_write = false;
}
else
{
error(err, errortostring(err));
if (q_file.front())
{
CloseHandle(q_hfile);
q_hfile = INVALID_HANDLE_VALUE;
}
while (q_str.size())
{
q_str.pop();
q_file.pop();
q_pos.pop();
}
m_write = false;
}
}
else
{
q_pos.front() += res;
}
if (q_file.front())
{ //Adjust file pointer back to the bytes that weren't sent
if (res == SOCKET_ERROR)
SetFilePointer(q_hfile, -chunksize, NULL, FILE_CURRENT);
else
SetFilePointer(q_hfile, res - chunksize, NULL, FILE_CURRENT);
}
if (res == SOCKET_ERROR)
{
sending = 0;
return;
}
}
//Enqued data has finished sending, move on to next queued data of finish if no more data is queued
if (q_file.front())
{
CloseHandle(q_hfile);
q_hfile = INVALID_HANDLE_VALUE;
}
if (q_str.size()) q_str.pop();
if (q_file.size()) q_file.pop();
if (q_pos.size()) q_pos.pop();
}
sending = 0;
}
void zsocket::sckaccept(int requestid)
{
destroysocket();
sockid = requestid;
SendMessage(sckwnd, SOCKET_ADDSELECT, (WPARAM)sockid, 0);
changestate(sckConnected);
}
// Returns the first IPv4 address of the local host in dotted-quad form, or an
// empty string when it cannot be determined.
std::string zsocket::localip()
{
    std::string host = localhost();
    hostent * hostinfo = gethostbyname(host.c_str());
    // gethostbyname() returns NULL on failure; the old code dereferenced the
    // result unconditionally and crashed.
    if (hostinfo == NULL || hostinfo->h_addr_list == NULL || hostinfo->h_addr_list[0] == NULL || hostinfo->h_length < 4)
        return "";
    const unsigned char * octets = (const unsigned char *)hostinfo->h_addr_list[0];
    std::ostringstream oss;
    for (int i = 0; i < 4; i++)
    {
        if (i != 0)
            oss << ".";
        oss << (int)octets[i];
    }
    return oss.str();
}
// Returns the local machine's host name, or "LocalHost" if gethostname fails.
// Uses a stack buffer; the old heap buffer was never freed (leaked every call).
std::string zsocket::localhost()
{
    char buf[256];
    if (gethostname(buf, 256))
        return "LocalHost";
    buf[255] = '\0'; //Defensive: guarantee termination
    return buf;
}
// Peer IP cached by getremoteinfo(); empty once the socket leaves sckConnected.
std::string zsocket::remoteip()
{
    const std::string & cached = m_remoteip;
    return cached;
}
// Remote host name as given by the caller (or the peer IP when none was given).
std::string zsocket::remotehost()
{
    const std::string & cached = m_remotehost;
    return cached;
}
int zsocket::port()
{
return m_port;
}
//************************************************
//Static members
//************************************************
// State shared by all instances: the hidden message window, the number of
// live sockets, the ends of the live-socket list, and the message-loop
// thread id.
HWND zsocket::sckwnd(NULL);
int zsocket::count(0);
zsocket * zsocket::firstsck(NULL);
zsocket * zsocket::lastsck(NULL);
DWORD zsocket::msgthreadid(0);
// Message-loop thread: creates the hidden "zsocket" window that receives
// socket notifications, then pumps messages until GetMessage returns 0 or the
// live-socket count has dropped to 0 (checked after each message).
// The ANSI ("A") APIs are named explicitly: the generic names map to the
// wide-character versions when UNICODE is defined, and the narrow string
// literals below would then fail to compile.
DWORD WINAPI zsocket::msgloop(LPVOID)
{
    msgthreadid = GetCurrentThreadId();
    WNDCLASSEXA wincl;
    wincl.cbSize = sizeof (WNDCLASSEXA);
    wincl.style = 0;
    wincl.lpfnWndProc = zsocket::msgproc;
    wincl.cbClsExtra = 0;
    wincl.cbWndExtra = 0;
    wincl.hInstance = 0;
    wincl.hIcon = 0;
    wincl.hCursor = 0;
    wincl.hbrBackground = 0;
    wincl.lpszMenuName = NULL;
    wincl.lpszClassName = "zsocket";
    wincl.hIconSm = 0;
    RegisterClassExA (&wincl);
    sckwnd = CreateWindowExA(0, "zsocket", "zsocket message window", 0, 0, 0, 0, 0, 0, 0, 0, NULL);
    MSG sckmessages;
    while (GetMessage(&sckmessages, NULL, 0, 0) && count)
    {
        TranslateMessage(&sckmessages);
        DispatchMessage(&sckmessages);
    }
    DestroyWindow(sckwnd);
    return 0;
}
// Window procedure of the hidden message window.
// SOCKET_MESSAGE: a WSAAsyncSelect notification (wparam = socket, lparam =
//   event + error). It is routed to the owning zsocket's handlers/events.
// SOCKET_ADDSELECT: registers the socket in wparam for notifications.
LRESULT CALLBACK zsocket::msgproc (HWND hwnd, UINT umsg, WPARAM wparam, LPARAM lparam)
{
    if (umsg == SOCKET_MESSAGE)
    {
        zsocket * sck = NULL;
        for (zsocket * fsck = firstsck; fsck != NULL; fsck = fsck->nextsck)
        {
            if (fsck->sockid == wparam)
            {
                sck = fsck;
                break;
            }
        }
        if (sck == NULL)
            return 0; //Notification for a socket no live zsocket owns
        int errcode = WSAGETSELECTERROR(lparam);
        int event = WSAGETSELECTEVENT(lparam);
        if (errcode != 0)
        {
            if (sck->state() == sckClosed) //Ignore error
                return 0;
            if (errcode != WSAEWOULDBLOCK)
                sck->close();
            sck->error(errcode, errortostring(errcode));
            return 0;
        }
        sck->m_lasttime = GetTickCount();
        switch (event)
        {
        case FD_READ:
            {
                // Use a fallback size if the buffer size is unknown (0): a
                // zero-length recv() reads nothing while FD_READ keeps being
                // re-enabled for the unread data.
                int capacity = (sck->m_recievebuffersize > 0) ? sck->m_recievebuffersize : 8192;
                std::string chunk(capacity, '\0');
                int result = recv(sck->sockid, &chunk[0], capacity, 0);
                if (result == SOCKET_ERROR)
                {
                    errcode = WSAGetLastError();
                    if (errcode != WSAEWOULDBLOCK)
                    {
                        sck->close();
                        sck->error(errcode, errortostring(errcode));
                    }
                    break;
                }
                // result == 0 means the peer closed gracefully (FD_CLOSE
                // follows); there is no new data to announce.
                if (result > 0)
                {
                    sck->recvbuffer.append(chunk.data(), result);
                    sck->dataarrival(sck->recvbuffer.length());
                }
            }
            break;
        case FD_WRITE:
            sck->m_write = true;
            sck->sendqueueddata();
            break;
        case FD_ACCEPT:
            // The accepted SOCKET is passed as the int request id expected by
            // connectionrequest()/sckaccept().
            sck->connectionrequest((int)accept(sck->sockid, NULL, 0));
            break;
        case FD_CONNECT:
            sck->changestate(sckConnected);
            break;
        case FD_CLOSE:
            sck->close();
            break;
        }
        return 0;
    }
    else if (umsg == SOCKET_ADDSELECT)
    {
        WSAAsyncSelect((SOCKET)wparam, sckwnd, SOCKET_MESSAGE, FD_READ | FD_WRITE | FD_ACCEPT | FD_CONNECT | FD_CLOSE);
    }
    return DefWindowProc(hwnd, umsg, wparam, lparam);
}
// Maps a Winsock error code to a short human-readable description;
// unrecognised codes yield "Unknown error.".
std::string zsocket::errortostring(int error)
{
    struct errordescription
    {
        int code;
        const char * text;
    };
    static const errordescription table[] =
    {
        { WSAEACCES, "Permission denied." },
        { WSAEADDRINUSE, "Address already in use." },
        { WSAEADDRNOTAVAIL, "Cannot assign requested address." },
        { WSAEAFNOSUPPORT, "Address family not supported by protocol family." },
        { WSAEALREADY, "Operation already in progress." },
        { WSAECONNABORTED, "Software caused connection abort." },
        { WSAECONNREFUSED, "Connection refused." },
        { WSAECONNRESET, "Connection reset by peer." },
        { WSAEDESTADDRREQ, "Destination address required." },
        { WSAEFAULT, "Bad address." },
        { WSAEHOSTUNREACH, "No route to host." },
        { WSAEINPROGRESS, "Operation now in progress." },
        { WSAEINTR, "Interrupted function call." },
        { WSAEINVAL, "Invalid argument." },
        { WSAEISCONN, "Socket is already connected." },
        { WSAEMFILE, "Too many open files." },
        { WSAEMSGSIZE, "Message too long." },
        { WSAENETDOWN, "Network is down." },
        { WSAENETRESET, "Network dropped connection on reset." },
        { WSAENETUNREACH, "Network is unreachable." },
        { WSAENOBUFS, "No buffer space available." },
        { WSAENOPROTOOPT, "Bad protocol option." },
        { WSAENOTCONN, "Socket is not connected." },
        { WSAENOTSOCK, "Socket operation on nonsocket." },
        { WSAEOPNOTSUPP, "Operation not supported." },
        { WSAEPFNOSUPPORT, "Protocol family not supported." },
        { WSAEPROCLIM, "Too many processes." },
        { WSAEPROTONOSUPPORT, "Protocol not supported." },
        { WSAEPROTOTYPE, "Protocol wrong type for socket." },
        { WSAESHUTDOWN, "Cannot send after socket shutdown." },
        { WSAESOCKTNOSUPPORT, "Socket type not supported." },
        { WSAETIMEDOUT, "Connection timed out." },
        { WSAEWOULDBLOCK, "Resource temporarily unavailable." },
        { WSAHOST_NOT_FOUND, "Host not found." },
        { WSANOTINITIALISED, "Successful WSAStartup not yet performed." },
        { WSANO_DATA, "Valid name, no data record of requested type." },
        { WSANO_RECOVERY, "This is a nonrecoverable error." },
        { WSASYSNOTREADY, "Network subsystem is unavailable." },
        { WSATRY_AGAIN, "Nonauthoritative host not found." },
        { WSAVERNOTSUPPORTED, "Winsock.dll version out of range." }
    };
    const int entries = (int)(sizeof(table) / sizeof(table[0]));
    for (int i = 0; i < entries; ++i)
    {
        if (table[i].code == error)
            return table[i].text;
    }
    return "Unknown error.";
}
#endif