//-------------------------------------
//zsocket class
//
//Author: Mark Gordon (msg555)
//Date: 9/2/05
//Libraries: Ws2_32.lib
//-------------------------------------

/* Events
virtual void error(int num, std::string description) {}
virtual void statechange(SocketState oldstate) {}
virtual void connectionrequest(int requestid) {}
virtual void dataarrival(int bytessent) {}    
*/

#ifndef zsocket_hpp
#define zsocket_hpp

//#include <ws2tcpip.h>
#include <windows.h>
#include <iostream>
#include <sstream>
#include <list>
#include <queue>

//NOTE(review): a using-directive in a header pulls namespace std into every file
//that includes this header; it is kept because including code may rely on it.
using namespace std;

//Default connect timeout in milliseconds; each zsocket starts with it (m_timeout)
//until settimeout() or a connect overload with a timeout changes it.
#define DEFAULT_TIMEOUT 20000 //20 seconds

//Connection state reported by zsocket::state(); statechange() receives the previous one.
//NOTE(review): sckConnectionPending is not set anywhere in this file.
enum SocketState
{
    sckClosed, sckListening, sckConnecting, sckConnectionPending, sckConnected
};

//Private window messages handled by zsocket::msgproc:
//  SOCKET_MESSAGE   - Winsock event notification registered through WSAAsyncSelect
//  SOCKET_ADDSELECT - request to register the socket passed in wparam with WSAAsyncSelect
#define SOCKET_MESSAGE (WM_USER + 0x401)
#define SOCKET_ADDSELECT (WM_USER + 0x402)
//FIFO queue of x backed by std::list (used for the per-socket send queues)
#define zqueue(x) std::queue<x, std::list<x, std::allocator<x> > >

//zsocket: an event-driven TCP (IPv4) socket built on WSAAsyncSelect.
//
//Every live zsocket is linked into a static doubly linked list (firstsck/lastsck).
//All of them share one window, created by a background thread running msgloop(),
//which receives the Winsock notifications and routes them to the owning object.
//Derive from zsocket and override the virtual event methods to react to errors,
//state changes, incoming connections and received data.
//NOTE(review): no locking is used; events can be raised on the message thread
//while other threads are using the same object.
class zsocket
{
public:
    zsocket();
    //Virtual because zsocket is designed to be derived from (see the events below);
    //deleting a derived object through a zsocket* would otherwise be undefined behaviour.
    virtual ~zsocket();
    
    //Connect functions ------------------------
    //Synchronous: block until connected, failed or timed out; return true on success
    bool syncconnect(std::string zhost, int zport,int timeout);
    bool syncconnect(int timeout);
    bool syncconnect(std::string zhost, int zport);
    bool syncconnect();
    //Asynchronous: return immediately; the outcome arrives through statechange()/error()
    void sckconnect(std::string zhost, int zport, int timeout); 
    void sckconnect(int timeout);
    void sckconnect(std::string zhost, int zport); 
    void sckconnect();
    //------------------------------------------

    //Send functions (data is queued and sent whenever the socket is writable)
    void senddata(void * buf, int length);
    void sendstring(std::string str);
    void sendfile(std::string file);   //file is a path; the file's contents are sent
    void flush(); //Wait until all data queued for this socket has been sent (or the connection ends)

    //Receive functions
    std::string getdata();    //returns the received data and clears the buffer
    std::string peekdata();   //returns the received data without clearing it

    //Listen
    bool scklisten(int zport);  //socket begins to listen on zport (stored as port(), which otherwise reports the remote port)
    bool scklisten();           //listen on the port last given to setport()/scklisten(int)

    //Miscellaneous
    void close(); //Closes the current socket and creates a fresh one; the state becomes sckClosed
    void sckaccept(int requestid); //Accept a connection request to your listening socket.  The requestid is passed to your socket in the connectionrequest event.
    bool waitforconnection(int);  //Waits up to the given milliseconds (0 = forever) for a connection request and accepts it synchronously

    //Properties (get)
    SocketState state() {return m_state;}
    int getlasttime() {return m_lasttime;}              //GetTickCount() value of the last socket activity
    int timeout() {return m_timeout;}                   //connect timeout in milliseconds
    int maxsendbuffersize() {return m_sendbuffersize;}  //send buffer size read when the connection was established

    //Properties (set)
    void sethost(std::string zhost) {m_remotehost = zhost;}
    void setport(int zport) {m_port = zport;}
    void settimeout(int ztimeout) {m_timeout = ztimeout;}
    
    std::string localip();
    std::string localhost();
    int port();
    std::string remoteip();
    std::string remotehost();

    //Events - override in a derived class; the defaults do nothing.
    virtual void error(int num, std::string description) {}  //Winsock error code and its text
    virtual void statechange(SocketState oldstate) {}         //fired after a state change; state() is the new state
    virtual void connectionrequest(int requestid) {}          //incoming connection on a listening socket; pass requestid to sckaccept()
    virtual void dataarrival(int bytessent) {}                //bytessent is the TOTAL number of bytes now buffered (see peekdata()), not only the new ones

private:
    //Copying would duplicate the socket handle and the list links and bypass the
    //live-socket count, so copying is disabled (declared, never defined).
    zsocket(const zsocket &);
    zsocket & operator=(const zsocket &);

    SOCKET sockid;
    SocketState m_state;
    zsocket * nextsck;       //live-socket list links
    zsocket * prevsck;
    int m_lasttime;
    int m_timeout;
    int m_recievebuffersize;
    int m_sendbuffersize;
    bool m_write;            //true after FD_WRITE until a send would block
    int sending;             //re-entrancy counter for sendqueueddata()

    //Send queue: three parallel queues (data or file path, is-file flag, bytes already sent)
    zqueue(std::string) q_str;
    zqueue(bool) q_file;
    zqueue(long long) q_pos;
    HANDLE q_hfile; //Only one file will ever be open at a time so we only need one handle

    std::string recvbuffer;

    int m_port;
    std::string m_remotehost;
    std::string m_remoteip;

    void createsocket();
    void destroysocket();
    void changestate(SocketState);
    void getbuffersizes();
    void getremoteinfo();
    void sendqueueddata();

    static HWND sckwnd;          //window receiving the Winsock notifications
    static int count;            //number of live zsocket objects
    static zsocket * firstsck;   //head/tail of the live-socket list
    static zsocket * lastsck;
    static DWORD msgthreadid;    //id of the thread running msgloop()

    static DWORD WINAPI msgloop(LPVOID);
    static LRESULT CALLBACK msgproc (HWND, UINT, WPARAM, LPARAM);
    static DWORD WINAPI asyncconnect(LPVOID);
    static std::string errortostring(int error);
};

//Constructs an unconnected socket (state sckClosed) and links it into the list
//of live sockets.  The first live zsocket also initialises Winsock, starts the
//shared message thread (msgloop) and waits until that thread's window exists.
zsocket::zsocket() : sockid(INVALID_SOCKET), m_state(sckClosed), nextsck(NULL), prevsck(NULL), m_lasttime(0), m_timeout(DEFAULT_TIMEOUT), m_recievebuffersize(0), m_sendbuffersize(0), m_write(false), sending(0), q_hfile(INVALID_HANDLE_VALUE), m_port(0)
{
    //(Initialisers are listed in declaration order, the order C++ actually runs them in.)
    count++;
    if (count == 1)
    {
        //Initiate Winsock and the socket message window/loop
        WSADATA empty;
        WSAStartup(MAKEWORD(2, 2), &empty);
        firstsck = this;
        lastsck = this;
        HANDLE msgthread = CreateThread(NULL, 0, zsocket::msgloop, NULL, 0, NULL);
        if (msgthread != NULL)
        {
            //The handle is never used; closing it does not stop the thread but
            //prevents the thread object from leaking.
            CloseHandle(msgthread);
            //msgloop publishes its window in sckwnd; wait for it
            while (sckwnd == NULL)
                Sleep(1);
        }
        //NOTE(review): if the thread could not be started, sckwnd may stay NULL and
        //the socket receives no events, but the constructor does not hang.
    }
    else
    {
        //Append this object to the end of the live-socket list
        prevsck = lastsck;
        prevsck->nextsck = this;
        lastsck = this;
    }
    createsocket();
}

//Closes this object's socket and removes it from the list of live sockets.
//Destroying the last live zsocket also calls WSACleanup().  The message thread
//only notices that no sockets remain the next time it receives a message
//(msgloop re-checks count after every GetMessage).
zsocket::~zsocket()
{
    --count;
    destroysocket();

    if (count != 0)
    {
        //Splice this node out of the doubly linked list
        if (prevsck != NULL)
            prevsck->nextsck = nextsck;
        if (nextsck != NULL)
            nextsck->prevsck = prevsck;
        if (lastsck == this)
            lastsck = prevsck;
        if (firstsck == this)
            firstsck = nextsck;
        return;
    }

    //Last socket gone: release Winsock and empty the list
    WSACleanup();
    firstsck = NULL;
    lastsck = NULL;
}

void zsocket::createsocket()
{
    sockid = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    SendMessage(sckwnd, SOCKET_ADDSELECT, (WPARAM)sockid, 0);
    changestate(sckClosed);
}

void zsocket::destroysocket()
{
    closesocket(sockid);
    changestate(sckClosed);
}

//Switches to newstate and fires statechange(oldstate).
//Every transition discards data still waiting in the send queue and clears
//m_write.  Entering sckConnected records the time, buffer sizes and peer
//address; leaving sckConnected forgets the remote host/IP.
void zsocket::changestate(SocketState newstate)
{
    m_write = false;
    while (q_str.size())
    {
        q_str.pop();
        q_file.pop();
        q_pos.pop();
    }
    //The queue entry that owned an open file is gone, so release the handle.
    //Otherwise the next sendfile() would read from this stale file.
    if (q_hfile != INVALID_HANDLE_VALUE)
    {
        CloseHandle(q_hfile);
        q_hfile = INVALID_HANDLE_VALUE;
    }
    if (newstate == sckConnected)
    {
        m_lasttime = GetTickCount();
        getbuffersizes();
        getremoteinfo();
    }
    else if (m_state == sckConnected)
    {
        m_remotehost = "";
        m_remoteip = "";
    }
    SocketState oldstate = m_state;
    m_state = newstate;
    statechange(oldstate);
}

void zsocket::getbuffersizes()
{
    int result;
    int size = sizeof(int);
    result = getsockopt(sockid, SOL_SOCKET, SO_RCVBUF, (char *) &m_recievebuffersize, &size);
    if (result == SOCKET_ERROR)
    {
        close();
        int errorcode = GetLastError();
        error(errorcode, errortostring(errorcode));
        return;
    }
    result = getsockopt(sockid, SOL_SOCKET, SO_SNDBUF, (char *) &m_sendbuffersize, &size);
    if (result == SOCKET_ERROR)
    {
        close();
        int errorcode = GetLastError();
        error(errorcode, errortostring(errorcode));
        return;
    }    
}

void zsocket::getremoteinfo()
{
    sockaddr_in ri;
    int size = sizeof(ri);
    if(getpeername(sockid, (sockaddr *)&ri, &size))
    {
        return;
    }
    else
    {
        m_port = ntohs(ri.sin_port);
        m_remoteip = (char *)inet_ntoa(ri.sin_addr);
        if (m_remotehost.length() == 0)
            m_remotehost = m_remoteip;
        /* Works, but takes far too long
        hostent * rh = gethostbyaddr((char *)(&(ri.sin_addr.S_un.S_addr)), 4, AF_INET);
        if (rh)
            m_remotehost = rh->h_name; */
    }
    
}

void zsocket::close()
{
    destroysocket();
    createsocket();
}

bool zsocket::scklisten(int port)
{
    m_port = port;
    return scklisten();
}

bool zsocket::scklisten()
{
    SOCKADDR_IN serverInfo;
	serverInfo.sin_family = AF_INET;
	serverInfo.sin_addr.s_addr = INADDR_ANY;
	serverInfo.sin_port = htons(m_port);
	if (bind(sockid, (LPSOCKADDR)&serverInfo, sizeof(struct sockaddr)) == SOCKET_ERROR)
		return false;
	if (listen(sockid, SOMAXCONN) == SOCKET_ERROR)
		return false;
    changestate(sckListening);
    return true;
}

bool zsocket::waitforconnection(int timeout)
{
    int starttime = GetTickCount();
    while (GetTickCount() - starttime < timeout || timeout == 0)
    {
        SOCKET sck = accept(sockid, NULL, NULL);
        if (sck != INVALID_SOCKET)
        {
            destroysocket();
            sockid = sck;
            changestate(sckConnected);
            return true;
        }
    }
    return false;
}

//Returns everything received so far and leaves the receive buffer empty.
std::string zsocket::getdata()
{
    std::string drained;
    drained.swap(recvbuffer);
    return drained;
}

//Returns a copy of the received data without removing it from the buffer.
std::string zsocket::peekdata()
{
    const std::string & pending = recvbuffer;
    return pending;
}

//***************************************************************************
//Connect functions.  They are overloaded for convenience; whichever one is
// called, the parameterless syncconnect() eventually does the work.
//syncconnect(...) is synchronous (it blocks until the attempt finishes)
//sckconnect(...) is asynchronous (it runs syncconnect() on a new thread)
//Note that the socket's state will be changed
//      to sckConnecting when one of the asynchronous functions is called;
//      if the connection is successful the state will be changed to sckConnected,
//      if the connection is unsuccessful the state will be changed to sckClosed.
//Synchronous connect to zhost:zport with a new timeout (milliseconds).
bool zsocket::syncconnect(std::string zhost, int zport, int ztimeout)
{
    m_timeout = ztimeout;
    return syncconnect(zhost, zport);
}
bool zsocket::syncconnect(int ztimeout)
{
    m_timeout = ztimeout;
    return syncconnect();
}
bool zsocket::syncconnect(std::string zhost, int zport)
{
    m_remotehost = zhost;
    m_port = zport;
    return syncconnect();
}
bool zsocket::syncconnect()
{
    //Close the socket so it is ready for connecting
    close();

    hostent * host = gethostbyname(m_remotehost.c_str());
    if (host == NULL)
    {
        unsigned int addr = inet_addr(m_remotehost.c_str());
        host = gethostbyaddr((char *)&addr, 4, AF_INET);
        if (host == NULL)
            return false;
    }

    sockaddr_in addr;
    addr.sin_family = AF_INET;
    addr.sin_port = htons(m_port);
    addr.sin_addr = *((in_addr *)host->h_addr);
    memset(&(addr.sin_zero), 0, 8); 

    if (connect(sockid, (const sockaddr *)&addr, sizeof(addr)) == SOCKET_ERROR)
    {
        int result;
        result = WSAGetLastError();
        if (result != WSAEWOULDBLOCK)
        {
            close();
            error(result, errortostring(result));
            return false;
        }
    }
    else
    {
        changestate(sckConnected);
        return true;
    }
    int stime = GetTickCount();
    while (m_state != sckConnected)
    {
        if (msgthreadid == GetCurrentThreadId())
        {//Run message loop so this, and other sockets can continue to function
            MSG sckmessages;
            if (PeekMessage(&sckmessages, NULL, 0, 0, 0))
            {
                GetMessage(&sckmessages, NULL, 0, 0);
                TranslateMessage(&sckmessages);
                DispatchMessage(&sckmessages);
            }
        }
        if (GetTickCount() - stime >= m_timeout)
        {
            close();
            error(WSAETIMEDOUT, errortostring(WSAETIMEDOUT));
            return false;
        }
        Sleep(1);
    }
    return true;
}
//Asynchronous
//Asynchronous connect to zhost:zport with a new timeout (milliseconds).
void zsocket::sckconnect(std::string zhost, int zport, int ztimeout)
{
    m_timeout = ztimeout;
    sckconnect(zhost, zport);
}
void zsocket::sckconnect(int ztimeout)
{
    m_timeout = ztimeout;
    sckconnect();
}
void zsocket::sckconnect(std::string zhost, int zport)
{
    m_remotehost = zhost;
    m_port = zport;
    sckconnect();
}
void zsocket::sckconnect()
{
    changestate(sckConnecting);
    CreateThread(NULL, 0, zsocket::asyncconnect, this, 0, NULL);
}
//Thread entry point used by sckconnect() (static).  param is the zsocket that
//started the thread; the thread performs a blocking syncconnect() on it.
DWORD WINAPI zsocket::asyncconnect(LPVOID param)
{
    static_cast<zsocket *>(param)->syncconnect();
    return 0;
}
//***************************************************************************

//Queues a copy of len bytes starting at buf and tries to send queued data now.
//A NULL buffer or a non-positive length queues nothing; std::string would
//otherwise read from NULL or throw for a negative length.
void zsocket::senddata(void * buf, int len)
{
    if (buf != NULL && len > 0)
    {
        q_str.push(std::string((char *) buf, len));
        q_file.push(false);
        q_pos.push(0);
    }
    sendqueueddata();
}

//Queues a copy of str for sending and tries to send queued data now.
void zsocket::sendstring(std::string str)
{
    //The three queues advance in lock-step: payload, is-a-file flag, bytes already sent
    q_str.push(str);
    q_file.push(false);
    q_pos.push(0LL);
    sendqueueddata();
}

//Queues the file at path `file` for sending; its contents are read when
//sendqueueddata() reaches this entry.  Tries to send queued data now.
void zsocket::sendfile(std::string file)
{
    q_str.push(file);   //path, opened lazily by sendqueueddata()
    q_file.push(true);
    q_pos.push(0LL);
    sendqueueddata();
}

void zsocket::flush()
{
    std::cout << "Flushing" << std::endl;
    while (q_str.size() && m_state == sckConnected)
    { 
        if (msgthreadid == GetCurrentThreadId())
        {//Run message loop so this, and other sockets can continue to function
            MSG sckmessages;
            if (PeekMessage(&sckmessages, NULL, 0, 0, 0))
            {
                GetMessage(&sckmessages, NULL, 0, 0);
                TranslateMessage(&sckmessages);
                DispatchMessage(&sckmessages);
            }
        }
        Sleep(1);
    }
}

void zsocket::sendqueueddata()
{
    if (m_write == false)
        return;

    if (sending++)
    {
        sending--;
        return;
    }

    m_lasttime = GetTickCount();

    while (q_str.size())
    {
        int res = 0;
        long long q_len = 0;
    
        //Open file
        if (q_file.front() && q_hfile == INVALID_HANDLE_VALUE)
        {
            q_hfile = CreateFile(q_str.front().c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
        }

        //Get total enqued data length
        if (q_file.front() && q_hfile != INVALID_HANDLE_VALUE)
        {
            long * cpy = new long[2];
            cpy[0] = GetFileSize(q_hfile, (LPDWORD) &(cpy[1]));
            memcpy(&q_len, cpy, sizeof(long long));
            delete [] cpy;
        }

        else if (!q_file.front())
        {
            q_len = q_str.front().length();
        }

        //While the current sending position of the enqueued data is not the full length
        while (q_pos.front() < q_len)
        {
            m_lasttime = GetTickCount();

            int chunksize;

            if ((q_len - q_pos.front()) > m_sendbuffersize) //Send chunk
                chunksize = m_sendbuffersize;
            else
                chunksize = q_len - q_pos.front();


            if (q_file.front())
            {
                char * buf = new char[chunksize];
                ReadFile(q_hfile, buf, chunksize, (LPDWORD)&chunksize, NULL);
                res = send(sockid, buf, chunksize, 0);
                delete [] buf;
            }
            else
            {
                res = send(sockid, q_str.front().c_str() + q_pos.front(), chunksize, 0);
            }
            if (res == SOCKET_ERROR)
            {
                int err = WSAGetLastError();
                if (err == WSAEWOULDBLOCK)
                {
                    m_write = false;
                }
                else
                {
                    error(err, errortostring(err));
                    
                    if (q_file.front())
                    {
                        CloseHandle(q_hfile);
                        q_hfile = INVALID_HANDLE_VALUE;
                    }
                    while (q_str.size())
                    {
                        q_str.pop();
                        q_file.pop();
                        q_pos.pop();
                    }

                    m_write = false;
                }
            }
            else
            {
                q_pos.front() += res;
            }
            if (q_file.front())
            { //Adjust file pointer back to the bytes that weren't sent
                if (res == SOCKET_ERROR)
                    SetFilePointer(q_hfile, -chunksize, NULL, FILE_CURRENT);
                else
                    SetFilePointer(q_hfile, res - chunksize, NULL, FILE_CURRENT);
            }
            if (res == SOCKET_ERROR)
            {
                sending = 0;
                return;
            }
        }        

        //Enqued data has finished sending, move on to next queued data of finish if no more data is queued
        if (q_file.front())
        {
            CloseHandle(q_hfile);
            q_hfile = INVALID_HANDLE_VALUE;
        }
        
        if (q_str.size()) q_str.pop();
        if (q_file.size()) q_file.pop();
        if (q_pos.size()) q_pos.pop();
    }
    sending = 0;
}

void zsocket::sckaccept(int requestid)
{
    destroysocket();
    sockid = requestid;
    SendMessage(sckwnd, SOCKET_ADDSELECT, (WPARAM)sockid, 0);
    changestate(sckConnected);
}

//Returns the first IPv4 address of the local host as a dotted string, or an
//empty string if the local host name cannot be resolved.
std::string zsocket::localip()
{
    std::string host = localhost();
    hostent * hostinfo = gethostbyname(host.c_str());
    //gethostbyname returns NULL on failure; never dereference it blindly
    if (hostinfo == NULL || hostinfo->h_addr_list[0] == NULL)
        return "";

    std::ostringstream oss;
    for (int i = 0; i < 4; i++)
    {
        oss << (int)((unsigned char)(hostinfo->h_addr_list[0][i]));
        if (i != 3)
            oss << ".";
    }
    return oss.str();
}

//Returns the local machine's host name, or "LocalHost" if gethostname() fails.
std::string zsocket::localhost()
{
    //Stack buffer: the heap buffer previously allocated here was never freed
    char buf[256];

    if (gethostname(buf, sizeof(buf)))
        return "LocalHost";
    buf[sizeof(buf) - 1] = '\0';  //guarantee termination
    return buf;
}

//Dotted IPv4 address of the connected peer; empty when not connected.
std::string zsocket::remoteip()
{
    const std::string & ip = m_remoteip;
    return ip;
}

//Remote host as given to sethost()/the connect calls.  It is set to the peer IP
//on connect when empty, and cleared when the connection ends.
std::string zsocket::remotehost()
{
    const std::string & host = m_remotehost;
    return host;
}

//Port last set by setport()/a connect call/scklisten(int), or, once connected,
//the peer's port.
int zsocket::port()
{
    const int current = m_port;
    return current;
}

//************************************************
//Static members
//************************************************
HWND zsocket::sckwnd(NULL);          //window that receives the Winsock notifications (set by msgloop)
int zsocket::count(0);               //number of live zsocket objects
zsocket * zsocket::firstsck(NULL);   //head of the live-socket list
zsocket * zsocket::lastsck(NULL);    //tail of the live-socket list
DWORD zsocket::msgthreadid(0);       //id of the thread running msgloop

//Body of the shared socket message thread.
//Registers the "zsocket" window class and creates the window that receives the
//WSAAsyncSelect notifications, publishing it in sckwnd.  It then dispatches
//messages until a message arrives after the last zsocket is gone (count == 0).
DWORD WINAPI zsocket::msgloop(LPVOID)
{
    msgthreadid = GetCurrentThreadId();

    WNDCLASSEX wincl;
    ZeroMemory(&wincl, sizeof(wincl));   //hInstance, icons, cursor, menu, brush, extras all 0
    wincl.cbSize = sizeof (WNDCLASSEX);
    wincl.lpszClassName = "zsocket";
    wincl.lpfnWndProc = zsocket::msgproc;
    //Fails harmlessly if an earlier message thread already registered the class
    RegisterClassEx (&wincl);

    HWND mywnd = CreateWindowEx(0, "zsocket", "zsocket message window", 0, 0, 0, 0, 0, 0, 0, 0, NULL);
    sckwnd = mywnd;

    MSG sckmessages;
    //GetMessage returns -1 on error; only a positive result means "got a message"
    while (GetMessage(&sckmessages, NULL, 0, 0) > 0 && count)
    {
        TranslateMessage(&sckmessages);
        DispatchMessage(&sckmessages);
    }
    DestroyWindow(mywnd);
    //Do not leave a dangling handle: a later first zsocket waits for sckwnd to be
    //non-NULL.  Only clear it if a newer message thread has not replaced it.
    if (sckwnd == mywnd)
        sckwnd = NULL;
    return 0;
}

//Window procedure of the shared socket window.
//SOCKET_MESSAGE:   a WSAAsyncSelect notification.  wparam is the socket and
//                  lparam packs the FD_* event and error code; the event is
//                  routed to the zsocket that owns the socket.
//SOCKET_ADDSELECT: registers the socket in wparam for READ/WRITE/ACCEPT/CONNECT/CLOSE events.
LRESULT CALLBACK zsocket::msgproc (HWND hwnd, UINT umsg, WPARAM wparam, LPARAM lparam)
{
    if (umsg == SOCKET_MESSAGE)
    {
        //Find the live zsocket that owns this socket handle
        zsocket * sck = NULL;
        for (zsocket * fsck = firstsck; fsck != NULL; fsck = fsck->nextsck)
        {
            if (fsck->sockid == wparam)
            {
                sck = fsck;
                break;
            }
        }
        if (sck == NULL)
            return 0;  //socket has since been closed or replaced

        int errcode = WSAGETSELECTERROR(lparam);
        int event = WSAGETSELECTEVENT(lparam);

        if (errcode != 0)
        {
            if (sck->state() == sckClosed)  //Ignore errors on an idle socket
                return 0;
            if (errcode != WSAEWOULDBLOCK)
                sck->close();
            sck->error(errcode, errortostring(errcode));
        }
        else
        {
            sck->m_lasttime = GetTickCount();
            switch (event)
            {
                case FD_READ:
                {
                    //Guard against an unknown (0) buffer size, which would read nothing
                    int bufsize = sck->m_recievebuffersize > 0 ? sck->m_recievebuffersize : 8192;
                    std::string buf(bufsize, '\0');
                    int result = recv(sck->sockid, &buf[0], bufsize, 0);
                    if (result == SOCKET_ERROR)
                    {
                        errcode = WSAGetLastError();
                        if (errcode != WSAEWOULDBLOCK)
                        {
                            sck->close();
                            sck->error(errcode, errortostring(errcode));
                        }
                        break;
                    }
                    //0 bytes means the peer closed the connection (FD_CLOSE is expected
                    //to follow), so only announce real data
                    if (result > 0)
                    {
                        sck->recvbuffer.append(buf.data(), result);
                        sck->dataarrival(sck->recvbuffer.length());
                    }
                    break;
                }
                case FD_WRITE:
                    sck->m_write = true;
                    sck->sendqueueddata();
                    break;
                case FD_ACCEPT:
                    //The accepted SOCKET (INVALID_SOCKET on failure) is handed over as an int request id
                    sck->connectionrequest((int)accept(sck->sockid, NULL, 0));
                    break;
                case FD_CONNECT:
                    sck->changestate(sckConnected);
                    break;
                case FD_CLOSE:
                    sck->close();
                    break;
            }
        }
        return 0;
    }
    else if (umsg == SOCKET_ADDSELECT)
    {
        WSAAsyncSelect((SOCKET)wparam, sckwnd, SOCKET_MESSAGE, FD_READ | FD_WRITE | FD_ACCEPT | FD_CONNECT | FD_CLOSE);
    }

    return DefWindowProc(hwnd, umsg, wparam, lparam);
}

//Maps a Winsock error code to a short English description;
//unrecognised codes give "Unknown error.".
std::string zsocket::errortostring(int error)
{
    struct errorname
    {
        int code;
        const char * text;
    };
    static const errorname names[] =
    {
        { WSAEACCES,          "Permission denied." },
        { WSAEADDRINUSE,      "Address already in use." },
        { WSAEADDRNOTAVAIL,   "Cannot assign requested address." },
        { WSAEAFNOSUPPORT,    "Address family not supported by protocol family." },
        { WSAEALREADY,        "Operation already in progress." },
        { WSAECONNABORTED,    "Software caused connection abort." },
        { WSAECONNREFUSED,    "Connection refused." },
        { WSAECONNRESET,      "Connection reset by peer." },
        { WSAEDESTADDRREQ,    "Destination address required." },
        { WSAEFAULT,          "Bad address." },
        { WSAEHOSTUNREACH,    "No route to host." },
        { WSAEINPROGRESS,     "Operation now in progress." },
        { WSAEINTR,           "Interrupted function call." },
        { WSAEINVAL,          "Invalid argument." },
        { WSAEISCONN,         "Socket is already connected." },
        { WSAEMFILE,          "Too many open files." },
        { WSAEMSGSIZE,        "Message too long." },
        { WSAENETDOWN,        "Network is down." },
        { WSAENETRESET,       "Network dropped connection on reset." },
        { WSAENETUNREACH,     "Network is unreachable." },
        { WSAENOBUFS,         "No buffer space available." },
        { WSAENOPROTOOPT,     "Bad protocol option." },
        { WSAENOTCONN,        "Socket is not connected." },
        { WSAENOTSOCK,        "Socket operation on nonsocket." },
        { WSAEOPNOTSUPP,      "Operation not supported." },
        { WSAEPFNOSUPPORT,    "Protocol family not supported." },
        { WSAEPROCLIM,        "Too many processes." },
        { WSAEPROTONOSUPPORT, "Protocol not supported." },
        { WSAEPROTOTYPE,      "Protocol wrong type for socket." },
        { WSAESHUTDOWN,       "Cannot send after socket shutdown." },
        { WSAESOCKTNOSUPPORT, "Socket type not supported." },
        { WSAETIMEDOUT,       "Connection timed out." },
        { WSAEWOULDBLOCK,     "Resource temporarily unavailable." },
        { WSAHOST_NOT_FOUND,  "Host not found." },
        { WSANOTINITIALISED,  "Successful WSAStartup not yet performed." },
        { WSANO_DATA,         "Valid name, no data record of requested type." },
        { WSANO_RECOVERY,     "This is a nonrecoverable error." },
        { WSASYSNOTREADY,     "Network subsystem is unavailable." },
        { WSATRY_AGAIN,       "Nonauthoritative host not found." },
        { WSAVERNOTSUPPORTED, "Winsock.dll version out of range." }
    };

    const int entries = sizeof(names) / sizeof(names[0]);
    for (int i = 0; i < entries; i++)
    {
        if (names[i].code == error)
            return names[i].text;
    }
    return "Unknown error.";
}

#endif