//
//  SocketHelper.swift
//  MpAccSocks5SDK
//

import Foundation
import Network

// Reference: https://www.cnblogs.com/renleiguanchashi/p/16592411.html (see also RFC 1928 for SOCKS5 and RFC 1929 for username/password auth)
/// SOCKS5 request header for an IPv4 destination (RFC 1928 §4).
///
/// Fields are declared in wire order: VER, CMD, RSV, ATYP, DST.ADDR, DST.PORT.
/// - Note: If any caller sends this struct's raw memory (for example via
///   `withUnsafeBytes(of:)`), that relies on the fields being laid out without
///   padding — do not reorder or retype them.
private struct ConnectReq {
    let ver: UInt8 = 0x05 // Protocol version: 0x04 = SOCKS4, 0x05 = SOCKS5
    let cmd: UInt8        // Command: 1 = TCP CONNECT, 2 = BIND, 3 = UDP ASSOCIATE
    let rsv: UInt8 = 0x00 // Reserved field
    let addrType: UInt8 = 0x01 // Address type: 1 = IPv4, 3 = domain name, 4 = IPv6
    let addr: UInt32 // Destination address; must be in network byte order
    let port: UInt16 // Destination port, in network byte order
}

/// SOCKS5 reply header carrying an IPv4 bound address (RFC 1928 §6).
///
/// Every field defaults to zero. Fields are declared in wire order:
/// VER, REP, RSV, ATYP, BND.ADDR, BND.PORT (10 bytes for IPv4).
/// - Note: If reply bytes are ever copied straight into this struct's memory,
///   that relies on this exact field order and these types — keep them.
private struct ConnectRsp {
    var ver: UInt8 = 0x00 // Protocol version: 0x04 = SOCKS4, 0x05 = SOCKS5
    var rsp: UInt8 = 0x00 // Reply code; 0x00 means success
    var rsv: UInt8 = 0x00 // Reserved field
    var addrType: UInt8 = 0x00 // Address type: 1 = IPv4, 3 = domain name, 4 = IPv6
    var addr: UInt32 = 0x00 // Bound address; kept in network byte order
    var port: UInt16 = 0x00 // Bound port, in network byte order
}

/// Failures raised while opening a socket to the SOCKS5 proxy or negotiating with it.
///
/// Cases carrying `socketId` are raised after `socket(2)` succeeded, so whoever
/// handles them owns that descriptor and should close it.
enum SocketError: Error {
    // Raised before any descriptor exists.
    case portError, createError
    // Raised once a descriptor is open.
    case connectError(socketId: Int32)
    case sendError(socketId: Int32, message: String),
         receiveError(socketId: Int32, message: String)
}


// MARK: - SOCKS5 entry points (TCP CONNECT and UDP ASSOCIATE)
/// Opens SOCKS5 sessions through a proxy listening on `127.0.0.1`.
///
/// Both public entry points follow the Objective-C error convention: on
/// failure they fill `error` and return a sentinel (`-1` or `nil`).
public class SocketHelper:NSObject {

    private let socksProxyHost: String = "127.0.0.1"

    /// Opens a TCP connection to `ipv4:port` tunnelled through the local SOCKS5
    /// proxy (SOCKS5 CONNECT).
    ///
    /// - Parameters:
    ///   - ipv4: Destination IPv4 address, in a form `inet_addr` accepts.
    ///   - port: Destination port in host byte order.
    ///   - socksProxyPort: Port of the SOCKS5 proxy; must be in `1...65535`.
    ///   - username: Username for RFC 1929 authentication if the proxy selects it.
    ///   - password: Password for RFC 1929 authentication if the proxy selects it.
    ///   - error: Receives an `NSError` when the call fails.
    /// - Returns: The connected socket descriptor, or `-1` on failure.
    @objc
    public func connectToTcpSocket(ipv4: String, port: UInt16,
                                   withSocksProxyPort socksProxyPort: Int,
                                   username: String?,
                                   password: String?,
                                   error: NSErrorPointer) -> Int32 {
        do {
            guard isValidProxyPort(socksProxyPort) else { throw SocketError.portError }
            guard isParsableIPv4(ipv4) else {
                setupError(errPtr: error, message: "invalid target ipv4 address: \(ipv4)")
                return -1
            }
            // Connect to the local proxy (authenticating if it asks for it).
            let socketFd = try connectToLocalSocket(
                proxyHost: socksProxyHost,
                proxyPort: socksProxyPort,
                username: username,
                password: password)
            // Ask the proxy to CONNECT to the target server.
            try setupTargetServerInfo(socketFd: socketFd, serverType: .tcp, ipv4: ipv4, port: port)
            return socketFd
        } catch let err {
            setupError(errPtr: error, error: err)
            return -1
        }
    }

    /// Sets up a SOCKS5 UDP ASSOCIATE session for datagrams to `ipv4:port`.
    ///
    /// - Parameters:
    ///   - ipv4: Destination IPv4 address, in a form `inet_addr` accepts.
    ///   - port: Destination port in host byte order.
    ///   - socksProxyPort: Port of the SOCKS5 proxy; must be in `1...65535`.
    ///   - username: Username for RFC 1929 authentication if the proxy selects it.
    ///   - password: Password for RFC 1929 authentication if the proxy selects it.
    ///   - error: Receives an `NSError` when the call fails.
    /// - Returns: A token holding the TCP control socket, the relay address the
    ///   proxy reported (`proxyIp`, and `proxyPort` in host byte order), and the
    ///   destination (`dstAddr`, `dstPort` in network byte order); `nil` on failure.
    @objc
    public func createUdpSocket(targetIPv4 ipv4: String, port: UInt16,
                                withSocksProxyPort socksProxyPort: Int,
                                username: String?,
                                password: String?,
                                error: NSErrorPointer) -> UDPSocksToken? {
        do {
            guard isValidProxyPort(socksProxyPort) else { throw SocketError.portError }
            guard isParsableIPv4(ipv4) else {
                setupError(errPtr: error, message: "invalid target ipv4 address: \(ipv4)")
                return nil
            }
            // Connect to the local proxy (authenticating if it asks for it).
            let socketFd = try connectToLocalSocket(
                proxyHost: socksProxyHost,
                proxyPort: socksProxyPort,
                username: username,
                password: password)

            // Ask the proxy for a UDP relay; the reply carries the relay's address.
            let conRsp = try setupTargetServerInfo(socketFd: socketFd, serverType: .udp, ipv4: ipv4, port: port)

            return UDPSocksToken(
                socketFd: socketFd,
                proxyIp: dottedQuad(fromNetworkOrder: conRsp.addr),
                // BND.PORT is stored in network order; hand it out in host order.
                proxyPort: UInt16(bigEndian: conRsp.port),
                dstAddr: inet_addr(ipv4),
                dstPort: port.bigEndian
            )
        } catch let err {
            setupError(errPtr: error, error: err)
            return nil
        }
    }

    /// Returns whether `port` can be the proxy's TCP port.
    /// Out-of-range values are rejected up front because narrowing them to
    /// `UInt16` would trap.
    private func isValidProxyPort(_ port: Int) -> Bool {
        (1...Int(UInt16.max)).contains(port)
    }

    /// Returns whether `inet_addr` can parse `ipv4`.
    ///
    /// `inet_addr` reports failure as `INADDR_NONE` (all bits set). That is also
    /// the value of the broadcast address, so the literal spelling is accepted
    /// explicitly. Without this check, a malformed string would silently target
    /// 255.255.255.255.
    private func isParsableIPv4(_ ipv4: String) -> Bool {
        inet_addr(ipv4) != in_addr_t.max || ipv4 == "255.255.255.255"
    }

    /// Formats a network-byte-order IPv4 address as dotted-quad text.
    /// The output matches `inet_ntoa`, but without that function's shared
    /// static result buffer, which is unsafe across threads.
    private func dottedQuad(fromNetworkOrder addr: in_addr_t) -> String {
        withUnsafeBytes(of: addr) { bytes in
            bytes.map { String($0) }.joined(separator: ".")
        }
    }
}

// MARK: - Shared SOCKS5 handshake helpers
extension SocketHelper {
    /// SOCKS5 CMD values used by this helper (RFC 1928 §4).
    private enum ServerType: UInt8 {
        case tcp = 1 // CONNECT
        case udp = 3 // UDP ASSOCIATE
    }

    /// Sends a SOCKS5 request (CONNECT or UDP ASSOCIATE) for `ipv4:port` and
    /// parses the proxy's 10-byte IPv4 reply.
    ///
    /// - Parameters:
    ///   - socketFd: Socket to the proxy that has finished method negotiation.
    ///   - serverType: SOCKS5 command to issue.
    ///   - ipv4: Destination IPv4 address, in a form `inet_addr` accepts.
    ///   - port: Destination port in host byte order.
    /// - Returns: The parsed reply. `addr` / `port` hold BND.ADDR / BND.PORT in
    ///   network byte order.
    /// - Throws: `SocketError.sendError` / `.receiveError` on I/O failure, and
    ///   `.receiveError` when the reply is not a successful SOCKS5 IPv4 reply.
    @discardableResult
    private func setupTargetServerInfo(socketFd: Int32, serverType: ServerType, ipv4: String, port: UInt16) throws -> ConnectRsp {
        let request = ConnectReq(cmd: serverType.rawValue, addr: inet_addr(ipv4), port: port.bigEndian)

        // Serialize field by field instead of dumping the struct's memory, since
        // Swift does not guarantee struct layout. `addr` and `port` already hold
        // network-order values, so their in-memory bytes are the wire bytes.
        var requestBytes: [UInt8] = [request.ver, request.cmd, request.rsv, request.addrType]
        requestBytes += withUnsafeBytes(of: request.addr) { Array($0) }
        requestBytes += withUnsafeBytes(of: request.port) { Array($0) }

        let reply = try sendMessage(socketId: socketFd, message: &requestBytes, receiveLen: 10)

        var connectRsp = ConnectRsp()
        connectRsp.ver = reply[0]
        connectRsp.rsp = reply[1]
        connectRsp.rsv = reply[2]
        connectRsp.addrType = reply[3]
        // Copy BND.ADDR / BND.PORT byte-for-byte so they stay in network order.
        withUnsafeMutableBytes(of: &connectRsp.addr) { $0.copyBytes(from: reply[4..<8]) }
        withUnsafeMutableBytes(of: &connectRsp.port) { $0.copyBytes(from: reply[8..<10]) }

        guard connectRsp.ver == 0x05,      // SOCKS5
              connectRsp.rsp == 0x00,      // proxy reached the target successfully
              connectRsp.addrType == 0x01  // IPv4 bound address
        else {
            throw SocketError.receiveError(socketId: socketFd, message: "target ip: port 设置失败, reply: \(reply)")
        }
        return connectRsp
    }

    /// Opens a TCP connection to the SOCKS5 proxy and performs method
    /// negotiation, plus RFC 1929 username/password authentication when the
    /// proxy selects it.
    ///
    /// - Parameters:
    ///   - proxyHost: IPv4 address of the proxy.
    ///   - proxyPort: TCP port of the proxy.
    ///   - username: Required only if the proxy selects username/password auth.
    ///   - password: Required only if the proxy selects username/password auth.
    /// - Returns: The connected, negotiated socket descriptor.
    /// - Throws: `SocketError.portError` if `proxyPort` does not fit in 16 bits,
    ///   `.createError` / `.connectError` on socket setup failure, and
    ///   `.sendError` / `.receiveError` on negotiation failure.
    private func connectToLocalSocket(proxyHost: String,
                                      proxyPort: Int,
                                      username: String?,
                                      password: String?) throws -> Int32 {
        // `UInt16(_:)` would trap on an out-of-range port; report it instead.
        guard let proxyPort16 = UInt16(exactly: proxyPort) else {
            throw SocketError.portError
        }

        let socketFd = socket(AF_INET, SOCK_STREAM, 0)
        guard socketFd != -1 else {
            throw SocketError.createError
        }

        var proxyAddr = sockaddr_in()
        proxyAddr.sin_family = sa_family_t(AF_INET)
        proxyAddr.sin_addr.s_addr = inet_addr(proxyHost)
        proxyAddr.sin_port = proxyPort16.bigEndian
        let conResult = withUnsafePointer(to: &proxyAddr) { ptr in
            ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) {
                connect(socketFd, $0, socklen_t(MemoryLayout<sockaddr_in>.size))
            }
        }
        guard conResult != -1 else {
            throw SocketError.connectError(socketId: socketFd)
        }

        // Method selection: offer "no authentication" (0x00) and
        // "username/password" (0x02).
        var methodRequest: [UInt8] = [0x05, 0x02, 0x00, 0x02]
        let methodReply = try sendMessage(socketId: socketFd, message: &methodRequest, receiveLen: 2)

        // The proxy chose "no authentication": the session is ready.
        if methodReply[1] == 0x00 {
            return socketFd
        }

        guard methodReply[1] == 0x02, let username, let password else {
            // Never put the password itself into the error message.
            throw SocketError.receiveError(
                socketId: socketFd,
                message: """
                    select valid method error >
                    rsp: \(methodReply),
                    username: \(String(describing: username)),
                    password: \(password == nil ? "nil" : "<redacted>")
                """
            )
        }

        // RFC 1929 encodes ULEN / PLEN in one byte, so credentials longer than
        // 255 UTF-8 bytes cannot be sent (and `UInt8(_:)` would trap on them).
        let userBytes = Array(username.utf8)
        let passBytes = Array(password.utf8)
        guard let userLen = UInt8(exactly: userBytes.count),
              let passLen = UInt8(exactly: passBytes.count) else {
            throw SocketError.sendError(
                socketId: socketFd,
                message: "username or password longer than 255 bytes"
            )
        }

        // Sub-negotiation: VER(0x01), ULEN, UNAME, PLEN, PASSWD.
        var authRequest: [UInt8] = [0x01, userLen]
        authRequest += userBytes
        authRequest.append(passLen)
        authRequest += passBytes

        let authReply = try sendMessage(socketId: socketFd, message: &authRequest, receiveLen: 2)
        // STATUS 0x00 means authentication succeeded.
        guard authReply[1] == 0x00 else {
            throw SocketError.receiveError(
                socketId: socketFd,
                message: "select user valid error: \(authReply[1])"
            )
        }
        return socketFd
    }

    /// Stores a description of `error` (or of `message`) into `errPtr`.
    ///
    /// For a `SocketError` that carries a descriptor, the descriptor is closed
    /// first, so a failed handshake does not leak it.
    ///
    /// - Parameters:
    ///   - errPtr: Objective-C out-parameter; nothing is stored when it is `nil`.
    ///   - error: Error to report. A `SocketError` becomes domain `socksHelper`,
    ///     code `-2`; any other error is bridged unchanged.
    ///   - message: Used only when `error` is `nil` (domain `socksHelper`, code `-1`).
    private func setupError(errPtr: NSErrorPointer, error: Error? = nil, message: String? = nil) {
        if let err = error as? SocketError {
            let msg: String
            var socketFd: Int32? = nil
            switch err {
            case .portError:
                msg = "please set socks host & port first"
            case .createError:
                msg = "create socket error"
            case .connectError(let socketId):
                msg = "connect socket error"
                socketFd = socketId
            case .sendError(let socketId, let msgStr):
                socketFd = socketId
                msg = msgStr
            case .receiveError(let socketId, let msgStr):
                socketFd = socketId
                msg = msgStr
            }
            if let socketFd {
                close(socketFd)
            }
            errPtr?.pointee = NSError(domain: "socksHelper", code: -2, userInfo: [
                NSLocalizedFailureErrorKey : msg
            ])
        } else if let error {
            errPtr?.pointee = error as NSError
        } else {
            errPtr?.pointee = NSError(domain: "socksHelper", code: -1, userInfo: [
                NSLocalizedFailureErrorKey : message ?? "empty"
            ])
        }
    }

    /// Writes all of `message` to `socketId`, then reads exactly `receiveLen`
    /// reply bytes.
    ///
    /// TCP is a byte stream, so both directions loop until the full amount has
    /// been transferred. Calls interrupted by a signal (`EINTR`) are retried.
    ///
    /// - Parameters:
    ///   - socketId: Connected socket descriptor.
    ///   - message: Bytes to send. `inout` only so callers can pass `&bytes`;
    ///     the array is not modified.
    ///   - receiveLen: Exact number of reply bytes expected.
    /// - Returns: The `receiveLen` reply bytes.
    /// - Throws: `SocketError.sendError` if `send` fails, and
    ///   `SocketError.receiveError` if `recv` fails or the peer closes the
    ///   connection before `receiveLen` bytes arrive.
    private func sendMessage(socketId: Int32, message: inout [UInt8], receiveLen: Int) throws -> [UInt8] {
        // Errors report errno text and byte counts only — never the payload,
        // which may contain RFC 1929 credentials.
        var sent = 0
        while sent < message.count {
            let written = message.withUnsafeBytes { bytes in
                send(socketId, bytes.baseAddress! + sent, bytes.count - sent, 0)
            }
            if written < 0 {
                let code = errno
                if code == EINTR { continue }
                throw SocketError.sendError(
                    socketId: socketId,
                    message: "send method error: \(String(cString: strerror(code)))"
                )
            }
            guard written > 0 else {
                throw SocketError.sendError(socketId: socketId, message: "send method error: no bytes written")
            }
            sent += written
        }

        var rsp = [UInt8](repeating: 0, count: receiveLen)
        var received = 0
        while received < receiveLen {
            let count = rsp.withUnsafeMutableBytes { bytes in
                recv(socketId, bytes.baseAddress! + received, bytes.count - received, 0)
            }
            if count < 0 {
                let code = errno
                if code == EINTR { continue }
                throw SocketError.receiveError(
                    socketId: socketId,
                    message: "receive error: \(String(cString: strerror(code)))"
                )
            }
            if count == 0 { break } // peer closed the connection early
            received += count
        }
        if received != receiveLen {
            throw SocketError.receiveError(
                socketId: socketId,
                message: "receive bytes error > expectation: \(receiveLen) actual: \(received)"
            )
        }
        return rsp
    }

}

extension SocketHelper {
    /// Asks the system for a currently unused local UDP port.
    ///
    /// Starts an `NWListener` with Wi-Fi and cellular interfaces prohibited,
    /// records the port it was assigned once it reports `.ready`, and then
    /// cancels it. The port is released before this returns, so another socket
    /// could claim it in the meantime.
    ///
    /// - Parameter timeout: How long to wait for the listener to become ready.
    ///   Without a bound, a listener stuck in `.waiting` would block the calling
    ///   thread forever.
    /// - Returns: The port number the listener was assigned.
    /// - Throws: Any error from `NWListener.init`, or an `NSError` in domain
    ///   `com.tencent.mpacc.demo` (code `-9999`) when no port was obtained.
    func getSocksPort(timeout: DispatchTimeInterval = .seconds(5)) throws -> UInt16 {
        var assignedPort: UInt16?

        let params = NWParameters.udp
        params.prohibitedInterfaceTypes = [.wifi, .cellular]
        let readySignal = DispatchSemaphore(value: 0)
        let stoppedSignal = DispatchSemaphore(value: 0)
        let listener = try NWListener(using: params)
        listener.newConnectionHandler = { _ in }
        listener.stateUpdateHandler = { [weak listener] newState in
            NSLog("\(#function): \(newState)")
            switch newState {
            case .setup, .waiting:
                break
            case .ready:
                assignedPort = listener?.port?.rawValue
                readySignal.signal()
            case .cancelled:
                stoppedSignal.signal()
            case .failed:
                fallthrough
            @unknown default:
                readySignal.signal()
                stoppedSignal.signal()
            }
        }
        listener.start(queue: .global())
        // On timeout, fall through: cancelling below still lets the listener
        // reach `.cancelled`, and the missing port is reported as an error.
        _ = readySignal.wait(timeout: .now() + timeout)
        listener.cancel()
        stoppedSignal.wait()
        if let assignedPort {
            return assignedPort
        }
        throw NSError(domain: "com.tencent.mpacc.demo", code: -9999)
    }
}
