// Socket SSL/TLS之iOS/Mac  (Socket SSL/TLS on iOS/Mac)
//
// 2019-03-21  本文已影响0人  WorldPeace_hp  (blog post metadata: publication date, reader count, author)
#import <Foundation/Foundation.h>

NS_ASSUME_NONNULL_BEGIN

@protocol SocketDelegate;

/// Minimal TLS client socket built on BSD sockets and SecureTransport.
///
/// Every method runs synchronously on the calling thread (waits use select(2) with a
/// 30-second timeout) and reports its outcome through the SocketDelegate callbacks.
/// NOTE(review): no locking is visible; presumably meant to be driven from one thread.
@interface HPSSLSocket : NSObject

/// @param delegate Receives connection, I/O, trust and error callbacks; held weakly.
- (instancetype)initWithDelegate:(id<SocketDelegate>)delegate;

/// Resolves @c host to an IPv4 address and opens a TCP connection to @c port.
/// Reports -didConnectSocket:host:port: or -didDisconnectSocket:error:.
- (void)connectToHost:(NSString *)host onPort:(int)port;

/// Runs the client-side TLS handshake on the connected socket. Server trust is decided by
/// -didReceiveChallenge:trust:completion:; success is reported with -didSecure:.
/// NOTE(review): @c sslSettings is only nil-checked; its entries are not applied.
- (void)startSSL:(NSDictionary *)sslSettings;

/// Encrypts and sends @c data; reports -didWriteData: on success. nil is ignored.
- (void)writeData:(NSData *)data;
/// Blocks until decrypted data is available and reports it with -didReadSocket:data:.
- (void)readData;

@end

/// Callbacks from HPSSLSocket. Each one is invoked synchronously on the thread that called
/// the HPSSLSocket method which triggered it.
@protocol SocketDelegate <NSObject>
@required
/// The TCP connection to @c host:@c port was established.
- (void)didConnectSocket:(HPSSLSocket *)socket host:(NSString *)host port:(int)port;
/// The socket was closed; @c error describes why.
- (void)didDisconnectSocket:(HPSSLSocket *)socket error:(nullable NSError *)error;

/// SSLWrite() accepted all the data passed to -writeData:.
- (void)didWriteData:(HPSSLSocket *)socket;
/// Decrypted data from -readData. @c data is nil when the peer closed the TLS session
/// gracefully, so implementations must handle nil.
- (void)didReadSocket:(HPSSLSocket *)socket data:(nullable NSData *)data;

/// Server-trust decision during the handshake. @c completion must be called before this
/// method returns: the answer is read immediately afterwards, and a missing or later call
/// rejects the peer.
- (void)didReceiveChallenge:(HPSSLSocket *)socket trust:(SecTrustRef)trust completion:(void (^)(BOOL shouldTrustPeer))completion;
/// The TLS handshake finished successfully; -writeData: / -readData may now be used.
- (void)didSecure:(HPSSLSocket *)socket;

@end

NS_ASSUME_NONNULL_END
#import "HPSSLSocket.h"

#include <netdb.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <netinet/in.h>

@interface HPSSLSocket()
{
    // Not retained; presumably so an owner that is also the delegate does not form a cycle.
    __weak id<SocketDelegate> _delegate;
    
    int _socketFD;          // TCP descriptor; -closeWithError: resets it to -1
    NSString *_host;        // last host passed to -connectToHost:onPort:
    int _port;              // last port passed to -connectToHost:onPort:
    
    SSLContextRef _sslContext;  // created in -startSSL:, released in -dealloc
    BOOL _isSecure;             // YES once the TLS handshake has succeeded
}

/// NOTE(review): created in -initWithDelegate: but not referenced anywhere else in this file.
@property (nonatomic, strong, readonly) NSMutableData *sslWriteData;
/// Raw (still encrypted) bytes received from the socket and not yet consumed by SecureTransport.
@property (nonatomic, strong, readonly) NSMutableData *sslReadData;

@end

@implementation HPSSLSocket

#pragma mark -
#pragma mark -- Life Cycle
/// Designated initializer.
/// @param aDelegate Receiver of connection, I/O, trust and error callbacks; held weakly.
- (instancetype)initWithDelegate:(id<SocketDelegate>)aDelegate {
    self = [super init];
    if(self) {
        _delegate = aDelegate;
        
        // alloc zero-fills ivars, and 0 is a valid descriptor (stdin). Mark "no socket yet"
        // explicitly so -dealloc / -closeWithError: never close a descriptor we do not own.
        _socketFD = -1;
        
        _sslWriteData = [[NSMutableData alloc] init];
        _sslReadData = [[NSMutableData alloc] init];
    }
    
    return self;
}

/// Releases the non-memory resources this object owns: the socket descriptor and the
/// SecureTransport context.
- (void)dealloc {
    NSLog(@"%s",__func__);
    
    BOOL ownsSocket = (_socketFD != -1);
    if (ownsSocket) {
        close(_socketFD);
    }
    
    SSLContextRef context = _sslContext;
    if (context != NULL) {
        _sslContext = NULL;
        CFRelease(context);
    }
}

#pragma mark -
#pragma mark -- Connect
/// Opens an IPv4 TCP connection to @c host:@c port.
///
/// The socket is made non-blocking so connect() returns immediately. Completion is then
/// awaited with -select:read:write:complete: (30-second timeout) and verified through
/// SO_ERROR. Reports -didConnectSocket:host:port: on success and
/// -didDisconnectSocket:error: otherwise.
/// @param host Host name or IPv4 literal; only IPv4 addresses are tried.
/// @param port TCP port in host byte order; 0 is rejected.
- (void)connectToHost:(NSString *)host onPort:(int)port {
    _host = host;
    _port = port;
    
    if (!host || !port) {
        [self closeWithError:[self otherError:@"host/port error"]];
        return;
    }
    
    struct sockaddr_in connectAddr;
    if (![self getIPByName:[host UTF8String] addr:&connectAddr]) {
        [self closeWithError:[self otherError:@"connect address error"]];
        return;
    }
    
    _socketFD = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (_socketFD == -1) {
        [self closeWithError:[self otherError:@"socket error"]];
        return;
    }
    
    // Prevent SIGPIPE signals
    int nosigpipe = 1;
    setsockopt(_socketFD, SOL_SOCKET, SO_NOSIGPIPE, &nosigpipe, sizeof(nosigpipe));
    
    // Only add O_NONBLOCK when the current flags could be read; F_SETFL with
    // (-1 | O_NONBLOCK) would switch on every flag bit.
    int flag = fcntl(_socketFD, F_GETFL);
    if (flag != -1) {
        fcntl(_socketFD, F_SETFL, flag | O_NONBLOCK);
    }
    
    connectAddr.sin_port = htons(port);
    int result = connect(_socketFD, (struct sockaddr*)&connectAddr, sizeof(connectAddr));
    if (result == 0) {
        [self didConnect];
        return;
    }
    
    if (errno != EINPROGRESS) {
        [self closeWithError:[self otherError:@"connect error"]];
        return;
    }
    
    // A non-blocking connect has finished once the socket becomes writable.
    [self select:_socketFD read:NO write:YES complete:^(int selectResult) {
        if (selectResult != 0) {
            [self closeWithError:[self otherError:@"select socket error"]];
            return;
        }
        
        // Writability only says the attempt is over: a refused or unreachable connection is
        // writable too. SO_ERROR holds the real outcome (0 on success).
        int soError = 0;
        socklen_t soErrorLength = sizeof(soError);
        if (getsockopt(self->_socketFD, SOL_SOCKET, SO_ERROR, &soError, &soErrorLength) == -1 || soError != 0) {
            [self closeWithError:[self otherError:@"connect error"]];
            return;
        }
        
        [self didConnect];
    }];
}

/// Notifies the delegate that the TCP connection to _host:_port is established.
- (void)didConnect {
    id<SocketDelegate> delegate = _delegate;
    if (![delegate respondsToSelector:@selector(didConnectSocket:host:port:)]) {
        return;
    }
    [delegate didConnectSocket:self host:_host port:_port];
}

/// Resolves @c hostname to its first IPv4 (AF_INET) address.
///
/// No service is passed to getaddrinfo(), so the caller fills in @c sin_port afterwards.
/// @param hostname NUL-terminated host name or dotted-quad literal; NULL yields NO.
/// @param addr Receives the resolved address; must not be NULL.
/// @return YES if an IPv4 address was found and copied into @c addr.
- (BOOL)getIPByName:(const char*)hostname addr:(struct sockaddr_in*)addr {
    if (hostname == NULL || addr == NULL) {
        return NO;
    }
    
    struct addrinfo hint;
    memset(&hint, 0, sizeof(hint));
    hint.ai_family = AF_INET;           // IPv4 only
    hint.ai_socktype = SOCK_STREAM;
    hint.ai_protocol = IPPROTO_TCP;
    hint.ai_flags = AI_ADDRCONFIG;
    
    struct addrinfo* addrInfo = NULL;
    if (getaddrinfo(hostname, NULL, &hint, &addrInfo) != 0 || addrInfo == NULL) {
        return NO;
    }
    
    BOOL ret = NO;
    for (struct addrinfo* a = addrInfo; a != NULL; a = a->ai_next) {
        if (a->ai_family == AF_INET && a->ai_addr != NULL) {
            // Never copy more than the caller's struct can hold, whatever ai_addrlen says.
            memset(addr, 0, sizeof(*addr));
            memcpy(addr, a->ai_addr, MIN((size_t)a->ai_addrlen, sizeof(*addr)));
            ret = YES;
            break;
        }
    }
    freeaddrinfo(addrInfo);
    return ret;
}

/// Waits until @c sock is readable and/or writable, with a 30-second timeout.
///
/// @c complete is invoked synchronously, exactly once, with:
///   - 0 when a requested readiness condition is met;
///   - -1 on timeout, or when the descriptor is flagged in the exception set;
///   - the errno from select(2) on failure, or EBADF if @c sock cannot be placed in an fd_set.
/// A wait interrupted by a signal (EINTR) is restarted with a fresh 30-second timeout.
- (void)select:(int)sock read:(BOOL)read write:(BOOL)write complete:(void (^)(int result))complete {
    // fd_set is a fixed-size bitmap: FD_SET on a negative or >= FD_SETSIZE descriptor
    // writes out of bounds.
    if (sock < 0 || sock >= FD_SETSIZE) {
        complete(EBADF);
        return;
    }
    
    fd_set readSet;
    fd_set writeSet;
    fd_set exceptSet;
    int ret;
    
    do {
        // Zero every set, requested or not: the checks below read all of them, and their
        // contents are unspecified after a failed select().
        FD_ZERO(&readSet);
        FD_ZERO(&writeSet);
        FD_ZERO(&exceptSet);
        if (read) {
            FD_SET(sock, &readSet);
        }
        if (write) {
            FD_SET(sock, &writeSet);
        }
        FD_SET(sock, &exceptSet);
        
        struct timeval timeout;
        timeout.tv_sec = 30;
        timeout.tv_usec = 0;
        ret = select(sock+1, read ? &readSet : NULL, write ? &writeSet : NULL, &exceptSet, &timeout);
    } while (ret < 0 && errno == EINTR);
    
    if (ret == 0) {
        //timeout
        complete(-1);
    }
    else if (ret < 0) {
        //handle error
        complete(errno);
    }
    else if (FD_ISSET(sock, &exceptSet)) {
        complete(-1);
    }
    else if (FD_ISSET(sock, &readSet) || FD_ISSET(sock, &writeSet)) {
        complete(0);
    }
    else {
        complete(-1);
    }
}

#pragma mark -
#pragma mark -- Inner Read/Write
/// Sends @c data[offset..] on @c sock, waiting for writability before each send(2).
///
/// Partial sends resume from the new offset. @c complete is called exactly once (synchronously,
/// since -select:read:write:complete: is synchronous) with 0 once every byte has been sent.
/// Otherwise it receives a non-zero code: the select result, -1 if send(2) unexpectedly
/// returns 0, or the send(2) errno.
- (void)_send:(int)sock data:(NSData*)data offset:(size_t)offset complete:(void (^)(int err))complete {
    // Nothing (left) to send, including empty data: that is success. Issuing a zero-length
    // send() would return 0 and be misreported as a failure below.
    if (offset >= data.length) {
        complete(0);
        return;
    }
    
    [self select:sock read:NO write:YES complete:^(int result) {
        if (result) {
            //error
            complete(result);
            return;
        }
        
        size_t length = data.length - offset;
        ssize_t nsend = send(sock, (const char*)data.bytes + offset, length, 0);
        if (nsend > 0) {
            if ((size_t)nsend < length) {
                //partial send: continue with the remainder
                [self _send:sock data:data offset:offset + (size_t)nsend complete:complete];
            }
            else {
                //all send
                complete(0);
            }
            return;
        }
        
        if (nsend == 0) {
            // should not happen for a non-empty buffer
            complete(-1);
            return;
        }
        
        int err = errno;
        if (err == EAGAIN || err == EWOULDBLOCK || err == EINTR) {
            // Transient: wait for writability again and retry the same range.
            [self _send:sock data:data offset:offset complete:complete];
        }
        else {
            //notify error happened
            complete(err);
        }
    }];
}

/// Waits (up to 30 s, via -select:read:write:complete:) for _socketFD to become readable,
/// then performs one recv(2) of at most 16 KiB.
///
/// @c complete is called exactly once, synchronously, with one of:
///   - (0, bytes, n)            n > 0 bytes received; @c bytes is only valid during the call;
///   - (0, NULL, 0)             the peer closed the connection;
///   - (errno, NULL, -1)        recv() failed;
///   - (selectResult, NULL, 0)  the wait failed or timed out.
- (void)_recive:(void (^)(int err, const void* data, ssize_t length))complete {
    [self select:_socketFD read:YES write:NO complete:^(int result) {
        if (result != 0) {
            NSLog(@"select result = %d",result);
            complete(result, NULL, 0);
            return;
        }
        
        char data[16*1024];
        ssize_t nread;
        do {
            nread = recv(self->_socketFD, data, sizeof(data), 0);
        } while (nread < 0 && errno == EINTR);  // interrupted by a signal: retry
        
        // errno is only meaningful when recv() failed; after success or EOF it is stale.
        int error = (nread < 0) ? errno : 0;
        NSLog(@"self->_socketFD = %d, nread = %zd, error = %d",self->_socketFD,nread,error);
        if (nread > 0) {
            NSLog(@"nread_1 = %zd, error = %d",nread,error);
            complete(0, data, nread);
        }
        else {
            NSLog(@"nread_2 = %zd, error = %d",nread,error);
            complete(error, NULL, nread);
        }
    }];
}

#pragma mark -
#pragma mark -- TLS/SSL
// The SecureTransport I/O callbacks are defined further down in this @implementation.
// C needs a declaration in scope before -startSSL: can take their addresses.
static OSStatus SSLReadFunction(SSLConnectionRef connection, void *buffer, size_t *bufferLength);
static OSStatus SSLWriteFunction(SSLConnectionRef connection, const void *buffer, size_t *bufferLength);

/// Creates a client-side SecureTransport context on the connected socket and runs the
/// TLS handshake synchronously.
///
/// Server authentication is delegated: kSSLSessionOptionBreakOnServerAuth makes SSLHandshake()
/// pause with errSSLPeerAuthCompleted, and -peerTrust then asks
/// -didReceiveChallenge:trust:completion:. Success is reported with -didSecure:, any failure
/// with -didDisconnectSocket:error:.
/// @param sslSettings NOTE(review): only nil-checked (together with the delegate); none of its
///                    entries are applied to the context.
/// @note SecureTransport (SSLCreateContext and related calls) is deprecated since iOS 13 /
///       macOS 10.15.
- (void)startSSL:(NSDictionary *)sslSettings {
    if (!_delegate && !sslSettings) {
        return;
    }
    
    // A repeated call must not leak the previous session; the new one is not secure until
    // its own handshake completes.
    if (_sslContext) {
        CFRelease(_sslContext);
        _sslContext = NULL;
    }
    _isSecure = NO;
    
    _sslContext = SSLCreateContext(kCFAllocatorDefault, kSSLClientSide, kSSLStreamType);
    if (_sslContext == NULL) {
        [self closeWithError:[self otherError:@"SSLCreateContext Error"]];
        return;
    }
    
    OSStatus status = SSLSetIOFuncs(_sslContext, &SSLReadFunction, &SSLWriteFunction);
    if (status != noErr) {
        [self closeWithError:[self otherError:@"SSLSetIOFuncs Error"]];
        return;
    }
    
    // Unretained (__bridge) back-pointer: self owns the context and releases it in -dealloc,
    // so the context cannot outlive the object it points to.
    status = SSLSetConnection(_sslContext, (__bridge SSLConnectionRef)self);
    if (status != noErr) {
        [self closeWithError:[self otherError:@"SSLSetConnection Error"]];
        return;
    }
    
    status = SSLSetSessionOption(_sslContext, kSSLSessionOptionBreakOnServerAuth, true);
    if (status != noErr)
    {
        [self closeWithError:[self otherError:@"Error in SSLSetSessionOption"]];
        return;
    }
    
    [self sslHandshake];
}

/// Drives SSLHandshake() until it stops returning errSSLWouldBlock.
///
/// errSSLPeerAuthCompleted hands the server trust to -peerTrust, whose result decides whether
/// the loop continues. noErr reports success via -didSSLHandshake; any other status closes
/// the socket with "handshake failed".
- (void)sslHandshake {
    OSStatus status;
    for (;;) {
        status = SSLHandshake(_sslContext);
        if (status == errSSLPeerAuthCompleted) {
            status = [self peerTrust];
        }
        if (status != errSSLWouldBlock) {
            break;
        }
    }
    
    if (status != noErr) {
        // handshake failed
        [self closeWithError:[self otherError:@"handshake failed"]];
        return;
    }
    
    [self didSSLHandshake];
}

/// Asks the delegate whether to trust the server after SSLHandshake() paused with
/// errSSLPeerAuthCompleted.
///
/// @return errSSLWouldBlock to resume the handshake when the delegate trusts the peer;
///         errSSLBadCert when it does not, when it does not implement the callback, or when
///         no trust object is available; otherwise the error from SSLCopyPeerTrust().
/// @note The delegate's completion block must be invoked before
///       -didReceiveChallenge:trust:completion: returns. The answer is read right afterwards,
///       so a later call has no effect and the peer is rejected.
- (OSStatus)peerTrust {
    SecTrustRef trustRef = NULL;
    OSStatus status = SSLCopyPeerTrust(_sslContext, &trustRef);
    if (status != noErr) {
        if (trustRef) {
            CFRelease(trustRef);
        }
        return status;
    }
    
    if (trustRef == NULL) {
        // Nothing to evaluate: fail closed. Returning noErr here would end the handshake loop
        // as if the handshake had succeeded.
        return errSSLBadCert;
    }
    
    __block BOOL trustPeer = NO;
    if (_delegate && [_delegate respondsToSelector:@selector(didReceiveChallenge:trust:completion:)]) {
        [_delegate didReceiveChallenge:self trust:trustRef completion:^(BOOL shouldTrustPeer) {
            trustPeer = shouldTrustPeer;
        }];
    }
    
    CFRelease(trustRef);
    
    return trustPeer ? errSSLWouldBlock : errSSLBadCert;
}

/// Marks the connection as secured and notifies the delegate via -didSecure:.
- (void)didSSLHandshake {
    _isSecure = YES;
    
    id<SocketDelegate> delegate = _delegate;
    if ([delegate respondsToSelector:@selector(didSecure:)]) {
        [delegate didSecure:self];
    }
}

/// SecureTransport read callback: fills @c buffer with up to @c *bufferLength raw TLS bytes.
///
/// Bytes are served from @c _sslReadData.
/// - If the cache holds fewer bytes than requested, it hands over what it has, updates
///   @c *bufferLength and returns errSSLWouldBlock.
/// - If the cache is empty, it performs one receive into the cache and sets @c *bufferLength
///   to 0. It then returns errSSLWouldBlock if data arrived or recv() failed transiently, and
///   errSSLClosedAbort if the peer closed, the wait failed or timed out, or recv() failed hard.
- (OSStatus)sslReadWithBuffer:(void *)buffer length:(size_t *)bufferLength {
    if (*bufferLength == 0) {
        return noErr;
    }
    
    NSUInteger cachedLength = _sslReadData.length;
    if (cachedLength > 0) {
        if (*bufferLength > cachedLength) {
            // Partial fill: hand over everything cached and ask to be called again.
            memcpy(buffer, _sslReadData.bytes, cachedLength);
            *bufferLength = cachedLength;
            _sslReadData.length = 0;
            return errSSLWouldBlock;
        }
        memcpy(buffer, _sslReadData.bytes, *bufferLength);
        [_sslReadData replaceBytesInRange:NSMakeRange(0, *bufferLength) withBytes:NULL length:0];
        return noErr;
    }
    
    // Cache empty: deliver nothing this time and block for more bytes.
    *bufferLength = 0;
    
    __block ssize_t blockLength = 0;
    __block int blockError = 0;
    [self _recive:^(int err, const void *data, ssize_t length) {
        NSLog(@"3.reciving length:%zd callbacek err:%d",length,err);
        blockLength = length;
        blockError = err;
        if (err || length <= 0) {
            return;
        }
        
        [self->_sslReadData appendBytes:data length:(NSUInteger)length];
        NSLog(@"3.1.reciving length:%zd callbacek err:%d",length,err);
    }];
    
    if (blockLength > 0) {
        return errSSLWouldBlock;
    }
    if (blockLength == 0) {
        return errSSLClosedAbort;
    }
    // recv() failed. Use the errno that _recive captured right after recv(): the global errno
    // may have been overwritten by the NSLog calls since then.
    if (blockError == EINTR || blockError == EWOULDBLOCK || blockError == EAGAIN) {
        return errSSLWouldBlock;
    }
    return errSSLClosedAbort;
}

/// SecureTransport write callback: sends all @c *bufferLength bytes from @c buffer on the socket.
///
/// @return noErr once every byte has been sent (also for an empty buffer). On failure,
///         @c *bufferLength is set to 0 and errSSLClosedAbort is returned.
- (OSStatus)sslWriteWithBuffer:(const void *)buffer length:(size_t *)bufferLength {
    if (*bufferLength == 0) {
        return noErr;
    }
    
    NSData *data = [NSData dataWithBytes:buffer length:*bufferLength];
    __block int sendError = 0;
    [self _send:_socketFD data:data offset:0 complete:^(int err) {
        sendError = err;
    }];
    
    if (sendError != 0) {
        // _send: reports errno / select codes, which are not OSStatus values. Map them to a
        // SecureTransport error so the failure propagates out of SSLHandshake()/SSLWrite()
        // instead of being masked.
        *bufferLength = 0;
        return errSSLClosedAbort;
    }
    
    // The same answer applies during and after the handshake: the bytes really were sent.
    return noErr;
}

/// SSLSetIOFuncs read trampoline: forwards to the HPSSLSocket registered with SSLSetConnection().
static OSStatus SSLReadFunction(SSLConnectionRef connection, void *buffer, size_t *bufferLength) {
    HPSSLSocket *socket = (__bridge HPSSLSocket *)connection;
    OSStatus result = [socket sslReadWithBuffer:buffer length:bufferLength];
    return result;
}

/// SSLSetIOFuncs write trampoline: forwards to the HPSSLSocket registered with SSLSetConnection().
static OSStatus SSLWriteFunction(SSLConnectionRef connection, const void *buffer, size_t *bufferLength) {
    HPSSLSocket *socket = (__bridge HPSSLSocket *)connection;
    OSStatus result = [socket sslWriteWithBuffer:buffer length:bufferLength];
    return result;
}

#pragma mark -
#pragma mark -- Write/Read
/// Encrypts and sends @c data over the TLS session; a nil @c data is ignored.
- (void)writeData:(NSData*)data {
    if (data != nil) {
        [self _writeData:data offset:0];
    }
}

/// Passes @c data (starting at @c offset) to SSLWrite().
///
/// A nil @c data issues an empty SSLWrite(), which is how this method retries after
/// errSSLWouldBlock. Reports -didWriteData: on success and closes with "write failed" on any
/// other error.
- (void)_writeData:(NSData*)data offset:(size_t)offset {
    const char *bytes = NULL;
    size_t sendLength = 0;
    if (data) {
        bytes = (const char*)data.bytes + offset;
        sendLength = data.length - offset;
    }
    
    size_t processed = 0;
    OSStatus status = SSLWrite(_sslContext, bytes, sendLength, &processed);
    switch (status) {
        case noErr:
            if (processed >= sendLength) {
                // all data is sent
                [self _didWrite];
            }
            else {
                NSAssert(false, @"i'm not sure this could happen");
                [self _writeData:data offset:offset + processed];    //try again with updated offset
            }
            break;
        case errSSLWouldBlock:
            // Retry with an empty write; `processed` is not consulted here.
            [self _writeData:nil offset:0];
            break;
        default:
            // other error
            [self closeWithError:[self otherError:@"write failed"]];
            break;
    }
}

/// Tells the delegate (via -didWriteData:) that SSLWrite() accepted all data from -writeData:.
- (void)_didWrite {
    id<SocketDelegate> delegate = _delegate;
    if (delegate != nil && [delegate respondsToSelector:@selector(didWriteData:)]) {
        [delegate didWriteData:self];
    }
}

/// Reads one chunk of decrypted application data and hands it to -didReadSocket:data:.
///
/// Blocks (through the socket I/O callbacks) while SSLRead() returns errSSLWouldBlock and
/// nothing has been decrypted yet. Bytes reported as processed in any attempt are kept, not
/// overwritten by the next attempt. On errSSLClosedGraceful the delegate receives any bytes
/// already decrypted, then nil. Any other error closes the socket with "read failed".
- (void)readData {
    NSMutableData *received = [NSMutableData data];
    char buf[16*1024];
    OSStatus status;
    
    do {
        size_t processed = 0;
        status = SSLRead(_sslContext, buf, sizeof(buf), &processed);
        // Keep this attempt's output; the next iteration reuses buf.
        if (processed > 0) {
            [received appendBytes:buf length:processed];
        }
    } while (status == errSSLWouldBlock && received.length == 0);
    
    if (status == noErr || status == errSSLWouldBlock) {
        // errSSLWouldBlock only reaches this point with bytes already decrypted: deliver them
        // instead of blocking for more.
        [self _didRead:[received copy]];
        return;
    }
    
    if (status == errSSLClosedGraceful) {
        //            NSLog(@"recv, SSLRead return closed graceful");
        if (received.length > 0) {
            [self _didRead:[received copy]];
        }
        [self _didRead:nil];
        return;
    }
    
    [self closeWithError:[self otherError:@"read failed"]];
}

/// Forwards @c data (nil on graceful TLS close) to the delegate's -didReadSocket:data:.
- (void)_didRead:(NSData *)data {
    id<SocketDelegate> delegate = _delegate;
    if (delegate == nil || ![delegate respondsToSelector:@selector(didReadSocket:data:)]) {
        return;
    }
    [delegate didReadSocket:self data:data];
}

#pragma mark -
#pragma mark -- Disconnect
/// Closes the TCP descriptor, marks it as -1, and reports -didDisconnectSocket:error:.
/// @param error Reason forwarded to the delegate; may be nil.
/// NOTE(review): the SecureTransport context is not released here, only in -dealloc.
- (void)closeWithError:(NSError *)error {
    int descriptor = _socketFD;
    _socketFD = -1;
    close(descriptor);
    
    id<SocketDelegate> delegate = _delegate;
    if ([delegate respondsToSelector:@selector(didDisconnectSocket:error:)]) {
        [delegate didDisconnectSocket:self error:error];
    }
}

#pragma mark -
#pragma mark -- Error
/// Builds an error in "SocketErrorDomain" (code 5) whose localized description is @c errMsg.
- (NSError *)otherError:(NSString *)errMsg {
    return [NSError errorWithDomain:@"SocketErrorDomain"
                               code:5
                           userInfo:@{NSLocalizedDescriptionKey: errMsg}];
}

@end
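// 上一篇 下一篇
//
// 猜你喜欢
//
// 热点阅读
// (Blog page navigation residue, not part of the source: previous / next post, "you may also like", "trending".)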