// IOCP model 2: AcceptEx
// (Article metadata: original read on 2023-07-13; read count: 3)

// IOCP2.cpp : Defines the entry point for the console application.
//

#include "stdafx.h"
#include <winsock2.h>
#include <mswsock.h>
#include <process.h>
#include <stdio.h>
#pragma comment(lib, "WS2_32.lib")

// Size in bytes of the data buffer embedded in every per-I/O record
// (PERIODATA::buf, also used as the initial wsaBuf.len).
#define MAX_BUFFER 256
// Timeout handed to GetQueuedCompletionStatus by the worker threads, so that
// they wake up periodically and re-check the exit flag.
#define MAX_TIMEOUT 1000
// NOTE(review): not referenced anywhere in the visible source - confirm
// whether this constant is still needed.
#define MAX_SOCKET 1024
// Capacity of the hThread array that stores the worker-thread handles.
#define MAX_THREAD 64
// Number of accept requests posted at startup (capacity of pAcceptData).
#define MAX_ACCEPT 5

/* Kind of overlapped operation a per-I/O record is currently tracking. */
typedef enum _OPERATION_INFO_
{
    OP_NULL   = 0,  /* no operation assigned (state after clean()) */
    OP_ACCEPT = 1,  /* accept request posted on the listening socket */
    OP_READ   = 2,  /* receive posted on a client socket */
    OP_WRITE  = 3   /* send posted on a client socket */
} OPERATIONINFO;

// Per-socket context: a socket handle together with an IPv4 address.
// Construction and destruction both just reset the fields; the destructor
// does NOT close the socket.
typedef struct _PER_HANDLE_DATA_
{
public:
    SOCKET      sock;   // socket handle; INVALID_SOCKET after a reset
    SOCKADDR_IN addr;   // IPv4 address; AF_INET / INADDR_ANY / port 0 after a reset

    _PER_HANDLE_DATA_()  { clean(); }
    ~_PER_HANDLE_DATA_() { clean(); }

protected:
    // Put both members back into their "empty" state.
    void clean()
    {
        memset(&addr, 0, sizeof(addr));
        addr.sin_family           = AF_INET;
        addr.sin_port             = htons(0);
        addr.sin_addr.S_un.S_addr = INADDR_ANY;

        sock = INVALID_SOCKET;
    }
}PERHANDLEDATA, *PPERHANDLEDATA;

// Per-operation context handed to overlapped socket calls.
//
// The OVERLAPPED member must stay first: the worker thread maps an
// OVERLAPPED pointer back to this record with CONTAINING_RECORD and treats a
// NULL result as "no record", which only holds when the member offset is 0.
typedef struct _PER_IO_DTATA_
{
public:
    WSAOVERLAPPED ol;               // overlapped block passed to the socket call
    SOCKET        sAccept;          // Only valid with AcceptEx
    SOCKET        sListen;          // Only valid with AcceptEx
    WSABUF        wsaBuf;           // descriptor pointing at buf
    char          buf[MAX_BUFFER];  // data storage for the operation
    OPERATIONINFO opType;           // which operation this record tracks

    _PER_IO_DTATA_()  { clean(); }
    ~_PER_IO_DTATA_() { clean(); }

    // Reset the record: zero the overlapped block and the buffer, point the
    // WSABUF at the whole buffer, and mark both sockets and the operation as
    // unset.
    void clean()
    {
        opType  = OP_NULL;
        sListen = INVALID_SOCKET;
        sAccept = INVALID_SOCKET;

        memset(buf, 0, sizeof(buf));
        wsaBuf.len = MAX_BUFFER;
        wsaBuf.buf = buf;

        ZeroMemory(&ol, sizeof(ol));
    }
}PERIODATA, *PPERIODATA;

// Handles of the worker threads started by _tmain via _beginthreadex.
HANDLE hThread[MAX_THREAD] = {0};
// Per-I/O records allocated by _tmain for the accept requests it posts.
PERIODATA* pAcceptData[MAX_ACCEPT] = {0};
// Number of worker threads that were started (valid entries in hThread).
int g_nThread = 0;
// Set to TRUE by _tmain on shutdown; polled by the ThreadProc loop.
// NOTE(review): plain BOOL shared between threads without any visible
// synchronization - confirm this is acceptable for the target compiler.
BOOL g_bExitThread = FALSE;
// Extension function pointers, filled in by _tmain through
// WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER); NULL until then.
LPFN_ACCEPTEX lpfnAcceptEx = NULL;
LPFN_GETACCEPTEXSOCKADDRS lpfnGetAcceptExSockAddrs = NULL;
// GUIDs passed to WSAIoctl to look up the two extension functions above.
GUID GuidAcceptEx = WSAID_ACCEPTEX;
GUID GuidGetAcceptExSockAddrs = WSAID_GETACCEPTEXSOCKADDRS;

// Worker thread body; lParam carries the completion-port handle.
unsigned __stdcall ThreadProc(LPVOID lParam);
// Posts an accept request for the given per-I/O record.
BOOL PostAccept(PERIODATA* pIoData);

int _tmain(int argc, _TCHAR* argv[])
{
WORD wVersionRequested = MAKEWORD(2, 2);
WSADATA wsaData;
if(0 != WSAStartup(wVersionRequested, &wsaData))
{
printf("WSAStartup failed with error code: %d/n", GetLastError());
return EXIT_FAILURE;
}

if(2 != HIBYTE(wsaData.wVersion) || 2 != LOBYTE(wsaData.wVersion))  
{  
    printf("Socket version not supported./n");  
    WSACleanup();  
    return EXIT\_FAILURE;  
}

// Create IOCP  
HANDLE hIOCP = CreateIoCompletionPort(INVALID\_HANDLE\_VALUE, NULL, NULL, 0);  
if(NULL == hIOCP)  
{  
    printf("CreateIoCompletionPort failed with error code: %d/n", WSAGetLastError());  
    WSACleanup();  
    return EXIT\_FAILURE;  
}

// Create worker thread  
SYSTEM\_INFO si = {0};  
GetSystemInfo(&si);  
for(int i = 0; i < (int)si.dwNumberOfProcessors+2; i++)  
{  
    hThread\[g\_nThread\] = (HANDLE)\_beginthreadex(NULL, 0, ThreadProc, (LPVOID)hIOCP, 0, NULL);  
    if(NULL == hThread\[g\_nThread\])  
    {  
        printf("\_beginthreadex failed with error code: %d/n", GetLastError());  
        continue;  
    }  
    ++g\_nThread;

    if(g\_nThread > MAX\_THREAD)  
    {  
        break;  
    }  
}

// Create listen SOCKET  
SOCKET sListen = WSASocket(AF\_INET, SOCK\_STREAM, IPPROTO\_TCP, NULL, 0, WSA\_FLAG\_OVERLAPPED);  
if(INVALID\_SOCKET == sListen)  
{  
    printf("WSASocket failed with error code: %d/n", WSAGetLastError());  
    goto EXIT\_CODE;  
}

// Associate SOCKET with IOCP  
if(NULL == CreateIoCompletionPort((HANDLE)sListen, hIOCP, NULL, 0))  
{  
    printf("CreateIoCompletionPort failed with error code: %d/n", WSAGetLastError());  
    if(INVALID\_SOCKET != sListen)  
    {  
        closesocket(sListen);  
        sListen = INVALID\_SOCKET;  
    }  
    goto EXIT\_CODE;  
}

// Bind SOCKET  
SOCKADDR\_IN addr;  
memset(&addr, 0, sizeof(addr));  
addr.sin\_family = AF\_INET;  
addr.sin\_addr.S\_un.S\_addr = inet\_addr("127.0.0.1");  
addr.sin\_port = htons(5050);  
if(SOCKET\_ERROR == bind(sListen, (LPSOCKADDR)&addr, sizeof(addr)))  
{  
    printf("bind failed with error code: %d/n", WSAGetLastError());  
    if(INVALID\_SOCKET != sListen)  
    {  
        closesocket(sListen);  
        sListen = INVALID\_SOCKET;  
    }  
    goto EXIT\_CODE;  
}

// Start Listen  
if(SOCKET\_ERROR == listen(sListen, 200))  
{  
    printf("listen failed with error code: %d/n", WSAGetLastError());  
    if(INVALID\_SOCKET != sListen)  
    {  
        closesocket(sListen);  
        sListen = INVALID\_SOCKET;  
    }  
    goto EXIT\_CODE;  
}

printf("Server start, wait for client to connect .../n");

DWORD dwBytes = 0;  
if(SOCKET\_ERROR == WSAIoctl(sListen, SIO\_GET\_EXTENSION\_FUNCTION\_POINTER, &GuidAcceptEx, sizeof(GuidAcceptEx), &lpfnAcceptEx,  
    sizeof(lpfnAcceptEx), &dwBytes, NULL, NULL))  
{  
    printf("WSAIoctl failed with error code: %d/n", WSAGetLastError());  
    if(INVALID\_SOCKET != sListen)  
    {  
        closesocket(sListen);  
        sListen = INVALID\_SOCKET;  
    }  
    goto EXIT\_CODE;  
}

if(SOCKET\_ERROR == WSAIoctl(sListen, SIO\_GET\_EXTENSION\_FUNCTION\_POINTER, &GuidGetAcceptExSockAddrs,  
    sizeof(GuidGetAcceptExSockAddrs), &lpfnGetAcceptExSockAddrs, sizeof(lpfnGetAcceptExSockAddrs),  
    &dwBytes, NULL, NULL))  
{  
    printf("WSAIoctl failed with error code: %d/n", WSAGetLastError());  
    if(INVALID\_SOCKET != sListen)  
    {  
        closesocket(sListen);  
        sListen = INVALID\_SOCKET;  
    }  
    goto EXIT\_CODE;  
}

// Post MAX\_ACCEPT accept  
for(int i=0; i<MAX\_ACCEPT; i++)  
{  
    pAcceptData\[i\] = new PERIODATA;  
    pAcceptData\[i\]->sListen = sListen;  
    PostAccept(pAcceptData\[i\]);  
}  
// After 1 hour later, Server shutdown.  
Sleep(1000 \* 60 \*60);

EXIT_CODE:
g_bExitThread = TRUE;

PostQueuedCompletionStatus(hIOCP, 0, NULL, NULL);  
WaitForMultipleObjects(g\_nThread, hThread, TRUE, INFINITE);  
for(int i = 0; i < g\_nThread; i++)  
{  
    CloseHandle(hThread\[g\_nThread\]);  
}

for(int i=0; i<MAX\_ACCEPT; i++)  
{  
    if(pAcceptData\[i\])  
    {  
        delete pAcceptData\[i\];  
        pAcceptData\[i\] = NULL;  
    }  
}

if(INVALID\_SOCKET != sListen)  
{  
    closesocket(sListen);  
    sListen = INVALID\_SOCKET;  
}  
CloseHandle(hIOCP); // Close IOCP

WSACleanup();  
return 0;  

}

// Post one overlapped accept request on pIoData->sListen.
//
// The record is reset for a fresh operation (overlapped block and buffer
// zeroed, WSABUF pointed at the whole buffer, opType = OP_ACCEPT), a new
// accept socket is created and stored in pIoData->sAccept, and the request
// is issued through lpfnAcceptEx. The data portion of the buffer is the
// buffer length minus two address slots of sizeof(SOCKADDR_IN)+16 bytes.
//
// Returns TRUE when the request completed immediately or is pending.
// Returns FALSE when pIoData is NULL, the extension function has not been
// loaded, the listening socket is invalid, the accept socket cannot be
// created, or the request fails; in the last case the accept socket is
// closed again and pIoData->sAccept is set to INVALID_SOCKET.
BOOL PostAccept(PERIODATA* pIoData)
{
    const DWORD addrLen = sizeof(SOCKADDR_IN) + 16;
    DWORD dwBytes = 0;

    if(NULL == pIoData || NULL == lpfnAcceptEx)
    {
        return FALSE;
    }
    if(INVALID_SOCKET == pIoData->sListen)
    {
        return FALSE;
    }

    // Start from a clean record so a re-posted accept never sees leftovers
    // from the previous connection.
    ZeroMemory(&(pIoData->ol), sizeof(pIoData->ol));
    memset(pIoData->buf, 0, sizeof(pIoData->buf));
    pIoData->wsaBuf.buf = pIoData->buf;
    pIoData->wsaBuf.len = MAX_BUFFER;
    pIoData->opType = OP_ACCEPT;

    pIoData->sAccept = WSASocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED);
    if(INVALID_SOCKET == pIoData->sAccept)
    {
        printf("WSASocket failed with error code: %d\n", WSAGetLastError());
        return FALSE;
    }

    if(FALSE == lpfnAcceptEx(pIoData->sListen, pIoData->sAccept, pIoData->wsaBuf.buf,
        pIoData->wsaBuf.len - (addrLen * 2), addrLen, addrLen, &dwBytes, &(pIoData->ol)))
    {
        int err = WSAGetLastError();
        if(WSA_IO_PENDING != err)
        {
            printf("lpfnAcceptEx failed with error code: %d\n", err);

            // Nothing is pending, so the socket created above would leak.
            closesocket(pIoData->sAccept);
            pIoData->sAccept = INVALID_SOCKET;
            return FALSE;
        }
    }
    return TRUE;
}

unsigned __stdcall ThreadProc(LPVOID lParam)
{
HANDLE hIOCP = (HANDLE)lParam;

PERHANDLEDATA\* pPerHandleData = NULL;  
PERIODATA\* pPerIoData = NULL;  
WSAOVERLAPPED\* lpOverlapped = NULL;  
DWORD dwTrans = 0;  
DWORD dwFlags = 0;  
while(!g\_bExitThread)  
{  
    BOOL bRet = GetQueuedCompletionStatus(hIOCP, &dwTrans, (PULONG\_PTR)&pPerHandleData, &lpOverlapped, MAX\_TIMEOUT);  
    if(!bRet)  
    {  
        // Timeout and exit thread  
        if(WAIT\_TIMEOUT == GetLastError())  
        {  
            continue;  
        }  
        // Error  
        printf("GetQueuedCompletionStatus failed with error: %d/n", GetLastError());  
        continue;  
    }  
    else  
    {  
        pPerIoData = CONTAINING\_RECORD(lpOverlapped, PERIODATA, ol);  
        if(NULL == pPerIoData)  
        {  
            // Exit thread  
            break;  
        }

        if((0 == dwTrans) && (OP\_READ == pPerIoData->opType || OP\_WRITE == pPerIoData->opType))  
        {  
            // Client leave.  
            printf("Client: <%s : %d> leave./n", inet\_ntoa(pPerHandleData->addr.sin\_addr), ntohs(pPerHandleData->addr.sin\_port));  
            closesocket(pPerHandleData->sock);  
            delete pPerHandleData;  
            delete pPerIoData;  
            continue;  
        }  
        else  
        {  
            switch(pPerIoData->opType)  
            {  
            case OP\_ACCEPT: // Accept  
                {  
                    SOCKADDR\_IN\* remote = NULL;  
                    SOCKADDR\_IN\* local = NULL;  
                    int remoteLen = sizeof(SOCKADDR\_IN);  
                    int localLen = sizeof(SOCKADDR\_IN);  
                    lpfnGetAcceptExSockAddrs(pPerIoData->wsaBuf.buf, pPerIoData->wsaBuf.len - ((sizeof(SOCKADDR\_IN)+16)\*2),  
                        sizeof(SOCKADDR\_IN)+16, sizeof(SOCKADDR\_IN)+16, (LPSOCKADDR\*)&local, &localLen, (LPSOCKADDR\*)&remote, &remoteLen);  
                    printf("Client <%s : %d> come in./n", inet\_ntoa(remote->sin\_addr), ntohs(remote->sin\_port));  
                    printf("Recv Data: <%s : %d> %s./n", inet\_ntoa(remote->sin\_addr), ntohs(remote->sin\_port), pPerIoData->wsaBuf.buf);

                    if(NULL != pPerHandleData)  
                    {  
                        delete pPerHandleData;  
                        pPerHandleData = NULL;  
                    }  
                    pPerHandleData = new PERHANDLEDATA;  
                    pPerHandleData->sock = pPerIoData->sAccept;

                    PERHANDLEDATA\* pPerHandle = new PERHANDLEDATA;  
                    pPerHandle->sock = pPerIoData->sAccept;  
                    PERIODATA\* pPerIo = new PERIODATA;  
                    pPerIo->opType = OP\_WRITE;  
                    strcpy\_s(pPerIo->buf, MAX\_BUFFER, pPerIoData->buf);  
                    DWORD dwTrans = strlen(pPerIo->buf);  
                    memcpy(&(pPerHandleData->addr), remote, sizeof(SOCKADDR\_IN));  
                    // Associate with IOCP  
                    if(NULL == CreateIoCompletionPort((HANDLE)(pPerHandleData->sock), hIOCP, (ULONG\_PTR)pPerHandleData, 0))  
                    {  
                        printf("CreateIoCompletionPort failed with error code: %d/n", GetLastError());  
                        closesocket(pPerHandleData->sock);  
                        delete pPerHandleData;  
                        continue;  
                    }

                    // Post Accept  
                    memset(&(pPerIoData->ol), 0, sizeof(pPerIoData->ol));  
                    PostAccept(pPerIoData);

                    // Post Receive  
                    DWORD dwFlags = 0;  
                    if(SOCKET\_ERROR == WSASend(pPerHandle->sock, &(pPerIo->wsaBuf), 1,  
                        &dwTrans, dwFlags, &(pPerIo->ol), NULL))  
                    {  
                        if(WSA\_IO\_PENDING != WSAGetLastError())  
                        {  
                            printf("WSASend failed with error code: %d/n", WSAGetLastError());  
                            closesocket(pPerHandle->sock);  
                            delete pPerHandle;  
                            delete pPerIo;  
                            continue;  
                        }  
                    }  
                }  
                break;

            case OP\_READ: // Read  
                printf("recv client <%s : %d> data: %s/n", inet\_ntoa(pPerHandleData->addr.sin\_addr), ntohs(pPerHandleData->addr.sin\_port), pPerIoData->buf);  
                pPerIoData->opType = OP\_WRITE;  
                memset(&(pPerIoData->ol), 0, sizeof(pPerIoData->ol));  
                if(SOCKET\_ERROR == WSASend(pPerHandleData->sock, &(pPerIoData->wsaBuf), 1, &dwTrans, dwFlags, &(pPerIoData->ol), NULL))  
                {  
                    if(WSA\_IO\_PENDING != WSAGetLastError())  
                    {  
                        printf("WSASend failed with error code: %d./n", WSAGetLastError());  
                        continue;  
                    }  
                }  
                break;

            case OP\_WRITE: // Write  
                {  
                    pPerIoData->opType = OP\_READ;  
                    dwFlags = 0;  
                    memset(&(pPerIoData->ol), 0, sizeof(pPerIoData->ol));  
                    memset(pPerIoData->buf, 0, sizeof(pPerIoData->buf));  
                    pPerIoData->wsaBuf.buf = pPerIoData->buf;  
                    dwTrans = pPerIoData->wsaBuf.len = MAX\_BUFFER;  
                    if(SOCKET\_ERROR == WSARecv(pPerHandleData->sock, &(pPerIoData->wsaBuf), 1, &dwTrans, &dwFlags, &(pPerIoData->ol), NULL))  
                    {  
                        if(WSA\_IO\_PENDING != WSAGetLastError())  
                        {  
                            printf("WSARecv failed with error code: %d./n", WSAGetLastError());  
                            continue;  
                        }  
                    }  
                }  
                break;

            default:  
                break;  
            }  
        }  
    }  
}  
return 0;  

}