[#184] Uses blocking ConnectEx on Windows (FIXME)

This commit is contained in:
zonyitoo
2020-01-04 22:26:10 +08:00
parent b35758c2c6
commit db5a2ca3de
4 changed files with 151 additions and 27 deletions

View File

@@ -79,6 +79,7 @@ cfg-if = "0.1"
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["mswsock", "winsock2"] }
lazy_static = "1.4"
# [patch.crates-io]
# libc = { git = "https://github.com/zonyitoo/libc.git", branch = "feature-linux-fastopen-connect", optional = true }

View File

@@ -57,5 +57,5 @@ function build() {
echo "* Done build package ${PKG_NAME}"
}
build "x86_64-unknown-linux-musl"
#build "x86_64-pc-windows-gnu"
#build "x86_64-unknown-linux-musl"
build "x86_64-pc-windows-gnu"

View File

@@ -3,13 +3,13 @@
use std::{
io::{self, Error},
mem,
net::{self, SocketAddr},
net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream},
os::unix::io::AsRawFd,
};
use libc;
use log::error;
use tokio::net::{TcpListener, TcpStream};
use tokio::net::{TcpListener as TokioTcpListener, TcpStream as TokioTcpStream};
fn create_socket(domain: libc::c_int) -> io::Result<libc::c_int> {
unsafe {
@@ -87,7 +87,7 @@ pub async fn bind_listener(addr: &SocketAddr) -> io::Result<TcpListener> {
return Err(Error::last_os_error());
}
TcpListener::from_std(net::TcpListener::from_raw_fd(sockfd))
TcpListener::from_std(StdTcpListener::from_raw_fd(sockfd))
}
}
@@ -124,7 +124,7 @@ impl ConnectContext {
}
}
pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
pub async fn connect_stream(addr: &SocketAddr) -> io::Result<(TcpStream, ConnectContext)> {
let domain = match addr {
SocketAddr::V4(..) => libc::AF_INET,
SocketAddr::V6(..) => libc::AF_INET6,
@@ -176,10 +176,16 @@ pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
return Err(Error::last_os_error());
}
TcpStream::from_std(net::TcpStream::from_raw_fd(sockfd))
TcpStream::from_std(StdTcpStream::from_raw_fd(sockfd)).map(|s| {
(
s,
ConnectContext {
socket: sockfd,
remote_addr: *addr,
},
)
})
}
TcpStream::from_std(stream)
}
// Borrowed from net2

View File

@@ -5,18 +5,45 @@ use std::{
mem,
net::{self, IpAddr, SocketAddr},
os::windows::io::AsRawSocket,
ptr,
};
use log::error;
use lazy_static::lazy_static;
use log::{error, warn};
use net2::TcpBuilder;
use tokio::net::{TcpListener, TcpStream};
use winapi::{
ctypes::{c_char, c_int},
shared::{
minwindef::DWORD,
ws2def::{ADDRESS_FAMILY, AF_INET, AF_INET6, IPPROTO_TCP, SOCKADDR, SOCKADDR_IN},
minwindef::{BOOL, DWORD, FALSE, LPDWORD, LPVOID, TRUE},
ws2def::{
ADDRESS_FAMILY,
AF_INET,
AF_INET6,
IPPROTO_TCP,
SIO_GET_EXTENSION_FUNCTION_POINTER,
SOCKADDR,
SOCKADDR_IN,
},
},
um::{
minwinbase::OVERLAPPED,
mswsock::{LPFN_CONNECTEX, WSAID_CONNECTEX},
winnt::PVOID,
winsock2::{
bind,
closesocket,
setsockopt,
socket,
WSAGetLastError,
WSAGetOverlappedResult,
WSAIoctl,
INVALID_SOCKET,
SOCKET,
SOCKET_ERROR,
SOCK_STREAM,
},
},
um::winsock2::{bind, connect, setsockopt, WSAGetLastError, SOCKET, SOCKET_ERROR},
};
// ws2ipdef.h
@@ -61,7 +88,101 @@ pub async fn bind_listener(addr: &SocketAddr) -> io::Result<TcpListener> {
TcpListener::from_std(listener)
}
pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
lazy_static! {
static ref PFN_CONNECTEX_OPT: LPFN_CONNECTEX = unsafe {
let socket = socket(AF_INET, SOCK_STREAM, 0);
if socket == INVALID_SOCKET {
return None;
}
let mut guid = WSAID_CONNECTEX;
let mut num_bytes: DWORD = 0;
let mut connectex: LPFN_CONNECTEX = None;
let ret = WSAIoctl(
socket,
SIO_GET_EXTENSION_FUNCTION_POINTER,
&mut guid as *mut _ as LPVOID,
mem::size_of_val(&guid) as DWORD,
&mut connectex as *mut _ as LPVOID,
mem::size_of_val(&connectex) as DWORD,
&mut num_bytes as *mut _,
ptr::null_mut(),
None,
);
if ret != 0 {
let err = WSAGetLastError();
let e = Error::from_raw_os_error(err);
warn!("Failed to get ConnectEx function from WSA extension, error: {}", e);
}
let _ = closesocket(socket);
connectex
};
}
pub struct ConnectContext {
// Reference to the partial connected socket fd
// This struct doesn't own the HANDLE, so do not close it while dropping
socket: SOCKET,
// Target address for calling `ConnectEx`
remote_addr: SocketAddr,
}
impl ConnectContext {
/// Performing actual connect operation
pub fn connect_with_data(self, buf: &[u8]) -> io::Result<usize> {
unsafe {
// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nc-mswsock-lpfn_connectex
let connect_ex = PFN_CONNECTEX_OPT.expect("LPFN_CONNECTEX function doesn't exists");
let (saddr, saddr_len) = addr2raw(&self.remote_addr);
let mut overlapped: OVERLAPPED = mem::zeroed();
let mut bytes_sent: DWORD = 0;
let ret: BOOL = connect_ex(
self.socket,
saddr,
saddr_len,
buf.as_ptr() as PVOID,
buf.len() as DWORD,
&mut bytes_sent as *mut _ as LPDWORD,
&mut overlapped as *mut _,
);
if ret == FALSE {
let mut bytes_sent: DWORD = 0;
let mut flags: DWORD = 0;
// FIXME: Blocking call.
let ret: BOOL = WSAGetOverlappedResult(
self.socket,
&mut overlapped as *mut _,
&mut bytes_sent as LPDWORD,
TRUE,
&mut flags as LPDWORD,
);
if ret == TRUE {
Ok(bytes_sent as usize)
} else {
let err = WSAGetLastError();
Err(Error::from_raw_os_error(err))
}
} else {
// Connect succeeded
Ok(bytes_sent as usize)
}
}
}
}
pub async fn connect_stream(addr: &SocketAddr) -> io::Result<(TcpStream, ConnectContext)> {
let builder = match addr.ip() {
IpAddr::V4(..) => TcpBuilder::new_v4()?,
IpAddr::V6(..) => TcpBuilder::new_v6()?,
@@ -113,21 +234,17 @@ pub async fn connect_stream(addr: &SocketAddr) -> io::Result<TcpStream> {
let err = WSAGetLastError();
return Err(Error::from_raw_os_error(err));
}
// FIXME: MSDN suggests to use ConnectEx instead of connect
// But it requires dynamic load from WSAIoctl and cache it in a global variable
// That sucks.
let (saddr, saddr_len) = addr2raw(addr);
let ret = connect(socket, saddr, saddr_len);
if ret == SOCKET_ERROR {
let err = WSAGetLastError();
return Err(Error::from_raw_os_error(err));
}
}
TcpStream::from_std(stream)
TcpStream::from_std(stream).map(|s| {
(
s,
ConnectContext {
socket,
remote_addr: *addr,
},
)
})
}
// Borrowed from net2