工作生活

C++ 使用 websocket 协议

2019-07-03  本文已影响0人  _给我一支烟_

开发了一款微信小游戏《约战24点》 服务器是用 C++ 写的,与客户端之间采用的是 websocket 协议通信。C++ 中使用 websocket 需要对协议数据进行处理才能使用。

1. websocket 协议数据格式详解

0                   1                   2                   3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len |    Extended payload length    |
|I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
|N|V|V|V|       |S|             |   (if payload len==126/127)   |
| |1|2|3|       |K|             |                               |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|     Extended payload length continued, if payload len == 127  |
+ - - - - - - - - - - - - - - - +-------------------------------+
|                               |Masking-key, if MASK set to 1  |
+-------------------------------+-------------------------------+
| Masking-key (continued)       |          Payload Data         |
+-------------------------------- - - - - - - - - - - - - - - - +
:                     Payload Data continued ...                :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|                     Payload Data continued ...                |
+---------------------------------------------------------------+

•FIN
标识是否为此消息的最后一个数据包,占1bit
•RSV1, RSV2, RSV3: 用于扩展协议,一般为0,各占1bit

•Opcode
数据包类型(frame type),占4bits
0x0:标识一个中间数据包
0x1:标识一个text类型数据包
0x2:标识一个binary类型数据包
0x3-7:保留
0x8:标识一个断开连接类型数据包
0x9:标识一个ping类型数据包
0xA:表示一个pong类型数据包
0xB-F:保留

•MASK:占1bits
用于标识PayloadData是否经过掩码处理。如果是1,Masking-key域的数据即是掩码密钥,用于解码PayloadData。客户端发出的数据帧需要进行掩码处理,所以此位是1。

•Payload length
Payload data的长度,占7bits,7+16bits,7+64bits:
如果其值在0-125,则是payload的真实长度。
如果值是126,则后面2个字节形成的16bits无符号整型数的值是payload的真实长度。注意,网络字节序,需要转换。
如果值是127,则后面8个字节形成的64bits无符号整型数的值是payload的真实长度。注意,网络字节序,需要转换。
这里的长度表示遵循一个原则,用最少的字节表示长度(尽量减少不必要的传输)。举例说,payload真实长度是124,在0-125之间,必须用前7位表示;不允许长度1是126或127,然后长度2是124,这样违反原则。

•Payload data
应用层数据

---------------------server解析client端的数据---------------------------
接收到客户端数据后的解析规则如下:
•1byte
◦1bit: frame-fin,x0表示该message后续还有frame;x1表示是message的最后一个frame
◦3bit: 分别是frame-rsv1、frame-rsv2和frame-rsv3,通常都是x0
◦4bit: frame-opcode,x0表示是延续frame;x1表示文本frame;x2表示二进制frame;x3-7保留给非控制frame;x8表示关 闭连接;x9表示ping;xA表示pong;xB-F保留给控制frame

•2byte
◦1bit: Mask,1表示该frame包含掩码;0表示无掩码
◦7bit、7bit+2byte、7bit+8byte: 7bit取整数值,若在0-125之间,则是负载数据长度;若是126表示,后两个byte取无符号16位整数值,是负载长度;127表示后8个 byte,取64位无符号整数值,是负载长度
◦3-6byte: 这里假定负载长度在0-125之间,并且Mask为1,则这4个byte是掩码
◦7-end byte: 长度是上面取出的负载长度,包括扩展数据和应用数据两部分,通常没有扩展数据;若Mask为1,则此数据需要解码,解码规则为1 -4byte掩码循环和数据byte做异或操作。

2. C++对 websocket 封装处理

WebSocket.h
//
// Description: WebSocket RFC6544 codec, written in C++.
//

#ifndef PROJECT_WEBSOCKET_H
#define PROJECT_WEBSOCKET_H

#include <string>
#include <vector>

enum WSFrameType {
    ERROR_FRAME=0xFF00,
    INCOMPLETE_FRAME=0xFE00,

    OPENING_FRAME=0x3300,
    CLOSING_FRAME=0x3400,

    INCOMPLETE_TEXT_FRAME=0x01,
    INCOMPLETE_BINARY_FRAME=0x02,

    TEXT_FRAME=0x81,
    BINARY_FRAME=0x82,

    PING_FRAME=0x19,
    PONG_FRAME=0x1A
};

enum WSStatus
{
    WS_STATUS_UNCONNECT = 1,
    WS_STATUS_CONNECT = 2,
};

class WebSocket
{
public:
    WebSocket();

    //解析 WebSocket 的握手数据
    bool parseHandshake(const std::string& request);

    //应答 WebSocket 的握手
    std::string respondHandshake();

    //解析 WebSocket 的协议具体数据,客户端-->服务器
    int getWSFrameData(char* msg, int msgLen, std::vector<char>& outBuf, int* outLen);

    //封装 WebSocket 协议的数据,服务器-->客户端
    int makeWSFrameData(char* msg, int msgLen, std::vector<char>& outBuf);

    //封装 WebSocket 协议的数据头(二进制数据)
    static int makeWSFrameDataHeader(int len, std::vector<char>& header);

private:
    std::string websocketKey_;//握手中客户端发来的key
};


#endif //PROJECT_WEBSOCKET_H

2. WebSocket.cpp

#include "WebSocket.h"
#include "BaseFunc.h"
#include <openssl/sha.h>  //for SHA1
#include <arpa/inet.h>    //for ntohl
#include <string.h>


WebSocket::WebSocket()
{
}

bool WebSocket::parseHandshake(const std::string& request)
{
    // 解析WEBSOCKET请求头信息
    bool ret = false;
    std::istringstream stream(request.c_str());
    std::string reqType;
    std::getline(stream, reqType);
    if (reqType.substr(0, 4) != "GET ")
        return ret;

    std::string header;
    std::string::size_type pos = 0;
    while (std::getline(stream, header) && header != "\r")
    {
        header.erase(header.end() - 1);
        pos = header.find(": ", 0);
        if (pos != std::string::npos)
        {
            std::string key = header.substr(0, pos);
            std::string value = header.substr(pos + 2);
            if (key == "Sec-WebSocket-Key")
            {
                ret = true;
                websocketKey_ = value;
                break;
            }
        }
    }

    return ret;
}


std::string WebSocket::respondHandshake()
{
    // 算出WEBSOCKET响应信息
    std::string response = "HTTP/1.1 101 Switching Protocols\r\n";
    response += "Upgrade: websocket\r\n";
    response += "Connection: upgrade\r\n";
    response += "Sec-WebSocket-Accept: ";

    //使用请求传过来的KEY+协议字符串,先用SHA1加密然后使用base64编码算出一个应答的KEY
    const std::string magicKey("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
    std::string serverKey = websocketKey_ + magicKey;
    //LOG_INFO << "serverKey:" << serverKey;

    //SHA1
    unsigned char digest[SHA_DIGEST_LENGTH];
    SHA1((unsigned char*)serverKey.c_str(), serverKey.length(), (unsigned char*)&digest);
    //printf("DIGEST:"); for(int i=0; i<20; i++) printf("%02x ",digest[i]); printf("\n");

    //Base64
    char basestr[1024] = {0};
    base64_encode((char*)digest, SHA_DIGEST_LENGTH, basestr);

    //完整的握手应答
    response = response + std::string(basestr) + "\r\n";
    //LOG_INFO << "RESPONSE:" << response;

    return response;
}


int WebSocket::getWSFrameData(char* msg, int msgLen, std::vector<char>& outBuf, int* outLen)
{
    if(msgLen < 2)
        return INCOMPLETE_FRAME;

    uint8_t fin_ = 0;
    uint8_t opcode_ = 0;
    uint8_t mask_ = 0;
    uint8_t masking_key_[4] = {0,0,0,0};
    uint64_t payload_length_ = 0;
    int pos = 0;
    //FIN
    fin_ = (unsigned char)msg[pos] >> 7;
    //Opcode
    opcode_ = msg[pos] & 0x0f;
    pos++;
    //MASK
    mask_ = (unsigned char)msg[pos] >> 7;
    //Payload length
    payload_length_ = msg[pos] & 0x7f;
    pos++;
    if(payload_length_ == 126)
    {
        uint16_t length = 0;
        memcpy(&length, msg + pos, 2);
        pos += 2;
        payload_length_ = ntohs(length);
    }
    else if(payload_length_ == 127)
    {
        uint32_t length = 0;
        memcpy(&length, msg + pos, 4);
        pos += 4;
        payload_length_ = ntohl(length);
    }
    //Masking-key
    if(mask_ == 1)
    {
        for(int i = 0; i < 4; i++)
            masking_key_[i] = msg[pos + i];
        pos += 4;
    }
    //取出消息数据
    if (msgLen >= pos + payload_length_ )
    {
        //Payload data
        *outLen = pos + payload_length_;
        outBuf.clear();
        if(mask_ != 1)
        {
            char* dataBegin = msg + pos;
            outBuf.insert(outBuf.begin(), dataBegin, dataBegin+payload_length_);
        }
        else
        {
            for(uint i = 0; i < payload_length_; i++)
            {
                int j = i % 4;
                outBuf.push_back(msg[pos + i] ^ masking_key_[j]);
            }
        }
    }
    else
    {
        return INCOMPLETE_FRAME;
    }

//    printf("WEBSOCKET PROTOCOL\n"
//            "FIN: %d\n"
//            "OPCODE: %d\n"
//            "MASK: %d\n"
//            "PAYLOADLEN: %d\n"
//            "outLen:%d\n",
//            fin_, opcode_, mask_, payload_length_, *outLen);

    //断开连接类型数据包
    if ((int)opcode_ == 0x8)
        return -1;

    return 0;
}


int WebSocket::makeWSFrameData(char* msg, int msgLen, std::vector<char>& outBuf)
{
    std::vector<char> header;
    makeWSFrameDataHeader(msgLen, header);
    outBuf.insert(outBuf.begin(), header.begin(), header.end());
    outBuf.insert(outBuf.end(), msg, msg+msgLen);
    return 0;
}

int WebSocket::makeWSFrameDataHeader(int len, std::vector<char>& header)
{
    header.push_back((char)BINARY_FRAME);
    if(len <= 125)
    {
        header.push_back((char)len);
    }
    else if(len <= 65535)
    {
        header.push_back((char)126);//16 bit length follows
        header.push_back((char)((len >> 8) & 0xFF));// leftmost first
        header.push_back((char)(len & 0xFF));
    }
    else // >2^16-1 (65535)
    {
        header.push_back((char)127);//64 bit length follows

        // write 8 bytes length (significant first)
        // since msg_length is int it can be no longer than 4 bytes = 2^32-1
        // padd zeroes for the first 4 bytes
        for(int i=3; i>=0; i--)
        {
            header.push_back((char)0);
        }
        // write the actual 32bit msg_length in the next 4 bytes
        for(int i=3; i>=0; i--)
        {
            header.push_back((char)((len >> 8*i) & 0xFF));
        }
    }

    return 0;
}

3. BaseFunc.h

#include <openssl/pem.h>
#include <openssl/bio.h>
#include <openssl/evp.h>


// base64 编码
int base64_encode(char *in_str, int in_len, char *out_str)
{
    BIO *b64, *bio;
    BUF_MEM *bptr = NULL;
    size_t size = 0;

    if (in_str == NULL || out_str == NULL)
        return -1;

    b64 = BIO_new(BIO_f_base64());
    bio = BIO_new(BIO_s_mem());
    bio = BIO_push(b64, bio);

    BIO_write(bio, in_str, in_len);
    BIO_flush(bio);

    BIO_get_mem_ptr(bio, &bptr);
    memcpy(out_str, bptr->data, bptr->length);
    out_str[bptr->length] = '\0';
    size = bptr->length;

    BIO_free_all(bio);
    return 0;
}

// base64 解码
int base64_decode(char *in_str, int in_len, char *out_str)
{
    BIO *b64, *bio;
    BUF_MEM *bptr = NULL;
    //int counts;
    int size = 0;

    if (in_str == NULL || out_str == NULL)
        return -1;

    b64 = BIO_new(BIO_f_base64());
    BIO_set_flags(b64, BIO_FLAGS_BASE64_NO_NL);

    bio = BIO_new_mem_buf(in_str, in_len);
    bio = BIO_push(b64, bio);

    size = BIO_read(bio, out_str, in_len);
    out_str[size] = '\0';

    BIO_free_all(bio);
    return 0;
}
上一篇 下一篇

猜你喜欢

热点阅读