From db5a2ca3de4385862b18ca1b6526a80de35021c6 Mon Sep 17 00:00:00 2001 From: zonyitoo Date: Sat, 4 Jan 2020 22:26:10 +0800 Subject: [PATCH] [#184] Uses blocking ConnectEx on Windows (FIXME) --- Cargo.toml | 1 + build/build-release | 4 +- src/relay/tcprelay/utils/tfo/bsd.rs | 20 ++-- src/relay/tcprelay/utils/tfo/windows.rs | 153 +++++++++++++++++++++--- 4 files changed, 151 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 970901cc..18644d94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ cfg-if = "0.1" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["mswsock", "winsock2"] } +lazy_static = "1.4" # [patch.crates-io] # libc = { git = "https://github.com/zonyitoo/libc.git", branch = "feature-linux-fastopen-connect", optional = true } diff --git a/build/build-release b/build/build-release index d952667c..803ede83 100755 --- a/build/build-release +++ b/build/build-release @@ -57,5 +57,5 @@ function build() { echo "* Done build package ${PKG_NAME}" } -build "x86_64-unknown-linux-musl" -#build "x86_64-pc-windows-gnu" +#build "x86_64-unknown-linux-musl" +build "x86_64-pc-windows-gnu" diff --git a/src/relay/tcprelay/utils/tfo/bsd.rs b/src/relay/tcprelay/utils/tfo/bsd.rs index 13c63a17..12dba73d 100644 --- a/src/relay/tcprelay/utils/tfo/bsd.rs +++ b/src/relay/tcprelay/utils/tfo/bsd.rs @@ -3,13 +3,13 @@ use std::{ io::{self, Error}, mem, - net::{self, SocketAddr}, + net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream}, os::unix::io::AsRawFd, }; use libc; use log::error; -use tokio::net::{TcpListener, TcpStream}; +use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream}; fn create_socket(domain: libc::c_int) -> io::Result { unsafe { @@ -87,7 +87,7 @@ pub async fn bind_listener(addr: &SocketAddr) -> io::Result { return Err(Error::last_os_error()); } - TcpListener::from_std(net::TcpListener::from_raw_fd(sockfd)) + TcpListener::from_std(StdTcpListener::from_raw_fd(sockfd)) } } @@ -124,7 +124,7 @@ impl ConnectContext { } } -pub async fn connect_stream(addr: &SocketAddr) -> io::Result { +pub async fn connect_stream(addr: &SocketAddr) -> io::Result<(TcpStream, ConnectContext)> { let domain = match addr { SocketAddr::V4(..) => libc::AF_INET, SocketAddr::V6(..) => libc::AF_INET6, @@ -176,10 +176,16 @@ pub async fn connect_stream(addr: &SocketAddr) -> io::Result { return Err(Error::last_os_error()); } - TcpStream::from_std(net::TcpStream::from_raw_fd(sockfd)) + TcpStream::from_std(StdTcpStream::from_raw_fd(sockfd)).map(|s| { + ( + s, + ConnectContext { + socket: sockfd, + remote_addr: *addr, + }, + ) + }) } - - TcpStream::from_std(stream) } // Borrowed from net2 diff --git a/src/relay/tcprelay/utils/tfo/windows.rs b/src/relay/tcprelay/utils/tfo/windows.rs index bda81ad1..a07a9bbf 100644 --- a/src/relay/tcprelay/utils/tfo/windows.rs +++ b/src/relay/tcprelay/utils/tfo/windows.rs @@ -5,18 +5,45 @@ use std::{ mem, net::{self, IpAddr, SocketAddr}, os::windows::io::AsRawSocket, + ptr, }; -use log::error; +use lazy_static::lazy_static; +use log::{error, warn}; use net2::TcpBuilder; use tokio::net::{TcpListener, TcpStream}; use winapi::{ ctypes::{c_char, c_int}, shared::{ - minwindef::DWORD, - ws2def::{ADDRESS_FAMILY, AF_INET, AF_INET6, IPPROTO_TCP, SOCKADDR, SOCKADDR_IN}, + minwindef::{BOOL, DWORD, FALSE, LPDWORD, LPVOID, TRUE}, + ws2def::{ + ADDRESS_FAMILY, + AF_INET, + AF_INET6, + IPPROTO_TCP, + SIO_GET_EXTENSION_FUNCTION_POINTER, + SOCKADDR, + SOCKADDR_IN, + }, + }, + um::{ + minwinbase::OVERLAPPED, + mswsock::{LPFN_CONNECTEX, WSAID_CONNECTEX}, + winnt::PVOID, + winsock2::{ + bind, + closesocket, + setsockopt, + socket, + WSAGetLastError, + WSAGetOverlappedResult, + WSAIoctl, + INVALID_SOCKET, + SOCKET, + SOCKET_ERROR, + SOCK_STREAM, + }, }, - um::winsock2::{bind, connect, setsockopt, WSAGetLastError, SOCKET, SOCKET_ERROR}, }; // ws2ipdef.h @@ -61,7 +88,101 @@ pub async fn bind_listener(addr: &SocketAddr) -> io::Result { TcpListener::from_std(listener) } -pub async fn connect_stream(addr: &SocketAddr) -> io::Result { +lazy_static! { + static ref PFN_CONNECTEX_OPT: LPFN_CONNECTEX = unsafe { + let socket = socket(AF_INET, SOCK_STREAM, 0); + if socket == INVALID_SOCKET { + return None; + } + + let mut guid = WSAID_CONNECTEX; + let mut num_bytes: DWORD = 0; + + let mut connectex: LPFN_CONNECTEX = None; + + let ret = WSAIoctl( + socket, + SIO_GET_EXTENSION_FUNCTION_POINTER, + &mut guid as *mut _ as LPVOID, + mem::size_of_val(&guid) as DWORD, + &mut connectex as *mut _ as LPVOID, + mem::size_of_val(&connectex) as DWORD, + &mut num_bytes as *mut _, + ptr::null_mut(), + None, + ); + + if ret != 0 { + let err = WSAGetLastError(); + let e = Error::from_raw_os_error(err); + + warn!("Failed to get ConnectEx function from WSA extension, error: {}", e); + } + + let _ = closesocket(socket); + + connectex + }; +} + +pub struct ConnectContext { + // Reference to the partial connected socket fd + // This struct doesn't own the HANDLE, so do not close it while dropping + socket: SOCKET, + + // Target address for calling `ConnectEx` + remote_addr: SocketAddr, +} + +impl ConnectContext { + /// Performing actual connect operation + pub fn connect_with_data(self, buf: &[u8]) -> io::Result { + unsafe { + // https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_connectex + let connect_ex = PFN_CONNECTEX_OPT.expect("LPFN_CONNECTEX function doesn't exists"); + let (saddr, saddr_len) = addr2raw(&self.remote_addr); + + let mut overlapped: OVERLAPPED = mem::zeroed(); + + let mut bytes_sent: DWORD = 0; + let ret: BOOL = connect_ex( + self.socket, + saddr, + saddr_len, + buf.as_ptr() as PVOID, + buf.len() as DWORD, + &mut bytes_sent as *mut _ as LPDWORD, + &mut overlapped as *mut _, + ); + + if ret == FALSE { + let mut bytes_sent: DWORD = 0; + let mut flags: DWORD = 0; + + // FIXME: Blocking call. + let ret: BOOL = WSAGetOverlappedResult( + self.socket, + &mut overlapped as *mut _, + &mut bytes_sent as LPDWORD, + TRUE, + &mut flags as LPDWORD, + ); + + if ret == TRUE { + Ok(bytes_sent as usize) + } else { + let err = WSAGetLastError(); + Err(Error::from_raw_os_error(err)) + } + } else { + // Connect succeeded + Ok(bytes_sent as usize) + } + } + } +} + +pub async fn connect_stream(addr: &SocketAddr) -> io::Result<(TcpStream, ConnectContext)> { let builder = match addr.ip() { IpAddr::V4(..) => TcpBuilder::new_v4()?, IpAddr::V6(..) => TcpBuilder::new_v6()?, @@ -113,21 +234,17 @@ pub async fn connect_stream(addr: &SocketAddr) -> io::Result { let err = WSAGetLastError(); return Err(Error::from_raw_os_error(err)); } - - // FIXME: MSDN suggests to use ConnectEx instead of connect - // But it requires dynamic load from WSAIoctl and cache it in a global variable - // That sucks. - - let (saddr, saddr_len) = addr2raw(addr); - let ret = connect(socket, saddr, saddr_len); - - if ret == SOCKET_ERROR { - let err = WSAGetLastError(); - return Err(Error::from_raw_os_error(err)); - } } - TcpStream::from_std(stream) + TcpStream::from_std(stream).map(|s| { + ( + s, + ConnectContext { + socket, + remote_addr: *addr, + }, + ) + }) } // Borrowed from net2