From 6373040a4d0047b3218587acfa83e3ed89b91ded Mon Sep 17 00:00:00 2001 From: zonyitoo Date: Sun, 15 Mar 2020 23:00:52 +0800 Subject: [PATCH] Customized UnixStream for protect() RPC ref #209 --- Cargo.toml | 1 + rustfmt.toml | 1 + src/context.rs | 4 +- src/crypto/cipher.rs | 3 +- src/relay/sys/unix/mod.rs | 31 +++---- src/relay/sys/unix/uds.rs | 173 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 192 insertions(+), 21 deletions(-) create mode 100644 src/relay/sys/unix/uds.rs diff --git a/Cargo.toml b/Cargo.toml index 18485d1c..ded14189 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -94,6 +94,7 @@ bloomfilter = "^1.0.2" spin = "0.5" # mio = { version = "0.7", features = ["udp"] } mio = "0.6" +mio-uds = "0.6" serde_json = "1.0" regex = "1" strum = "0.18" diff --git a/rustfmt.toml b/rustfmt.toml index a2241e92..bfb0dca0 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,3 +1,4 @@ +edition = "2018" max_width = 120 #indent_style = "Visual" #fn_call_width = 120 diff --git a/src/context.rs b/src/context.rs index 987f48c8..a017f763 100644 --- a/src/context.rs +++ b/src/context.rs @@ -296,9 +296,7 @@ impl Context { match self.config.acl { // Proxy everything by default None => false, - Some(ref a) => { - a.check_qname_in_proxy_list(qname).await - } + Some(ref a) => a.check_qname_in_proxy_list(qname).await, } } diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index e2a4822f..aadd55cb 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -3,7 +3,8 @@ use std::{ convert::From, fmt::{self, Debug, Display}, - io, mem, + io, + mem, str::{self, FromStr}, }; diff --git a/src/relay/sys/unix/mod.rs b/src/relay/sys/unix/mod.rs index 2a7d5fbb..986217dd 100644 --- a/src/relay/sys/unix/mod.rs +++ b/src/relay/sys/unix/mod.rs @@ -42,15 +42,16 @@ pub fn sockaddr_to_std(saddr: &libc::sockaddr_storage) -> io::Result cfg_if! { if #[cfg(target_os = "android")] { + mod uds; + /// This is a RPC for Android to `protect()` socket for connecting to remote servers /// /// https://developer.android.com/reference/android/net/VpnService#protect(java.net.Socket) /// /// More detail could be found in [shadowsocks-android](https://github.com/shadowsocks/shadowsocks-android) project. - fn protect(protect_path: &Option, fd: RawFd) -> io::Result<()> { - use std::{io::Read, os::unix::net::UnixStream, time::Duration}; - - use sendfd::{SendWithFd}; + async fn protect(protect_path: &Option, fd: RawFd) -> io::Result<()> { + use std::{io::Read, time::Duration}; + use tokio::io::AsyncReadExt; // ignore if protect_path is not specified let path = match protect_path { @@ -58,23 +59,19 @@ cfg_if! { None => return Ok(()), }; - // it's safe to use blocking socket here - let mut stream = UnixStream::connect(path)?; - stream - .set_read_timeout(Some(Duration::new(1, 0))) - .expect("couldn't set read timeout"); - stream - .set_write_timeout(Some(Duration::new(1, 0))) - .expect("couldn't set write timeout"); + let timeout = Some(Duration::new(1, 0)); + + let mut stream = self::uds::UnixStream::connect(path).await?; // send fds let dummy: [u8; 1] = [1]; let fds: [RawFd; 1] = [fd]; - stream.send_with_fd(&dummy, &fds)?; + stream.send_with_fd(&dummy, &fds).await?; // receive the return value let mut response = [0; 1]; - stream.read_exact(&mut response)?; + stream.read_exact(&mut response).await?; + if response[0] == 0xFF { return Err(Error::new(ErrorKind::Other, "protect() failed")); } @@ -83,7 +80,7 @@ cfg_if! { } } else { #[inline(always)] - fn protect(_protect_path: &Option, _fd: RawFd) -> io::Result<()> { + async fn protect(_protect_path: &Option, _fd: RawFd) -> io::Result<()> { Ok(()) } } @@ -103,7 +100,7 @@ pub async fn tcp_stream_connect(saddr: &SocketAddr, context: &Context) -> io::Re // Any traffic to localhost should not be protected // This is a workaround for VPNService if cfg!(target_os = "android") && !saddr.ip().is_loopback() { - protect(&context.config().protect_path, socket.as_raw_fd())?; + protect(&context.config().protect_path, socket.as_raw_fd()).await?; } // it's important that the socket is protected before connecting @@ -119,7 +116,7 @@ pub async fn create_udp_socket_with_context(addr: &SocketAddr, context: &Context // Any traffic to localhost should be protected // This is a workaround for VPNService if cfg!(target_os = "android") && !addr.ip().is_loopback() { - protect(&context.config().protect_path, socket.as_raw_fd())?; + protect(&context.config().protect_path, socket.as_raw_fd()).await?; } Ok(socket) diff --git a/src/relay/sys/unix/uds.rs b/src/relay/sys/unix/uds.rs new file mode 100644 index 00000000..69056c61 --- /dev/null +++ b/src/relay/sys/unix/uds.rs @@ -0,0 +1,173 @@ +//! Android specific features + +use std::{ + convert::TryInto, + io::{self, Error, ErrorKind, Read, Write}, + mem::{self, MaybeUninit}, + net::Shutdown, + os::unix::io::{AsRawFd, RawFd}, + path::Path, + pin::Pin, + ptr, + slice, + task::{Context, Poll}, +}; + +use futures::{future, ready}; +use mio_uds::UnixStream as MioUnixStream; +use tokio::io::{AsyncRead, AsyncWrite, PollEvented}; + +/// A UnixStream supports transferring FDs between processes +pub struct UnixStream { + io: PollEvented, +} + +impl UnixStream { + /// Connects to the socket named by `path`. + pub async fn connect>(path: P) -> io::Result { + let uds = MioUnixStream::connect(path)?; + let io = PollEvented::new(uds)?; + + future::poll_fn(|cx| io.poll_write_ready(cx)).await?; + Ok(UnixStream { io }) + } + + fn poll_send_with_fd(&self, cx: &mut Context, buf: &[u8], fds: &[RawFd]) -> Poll> { + ready!(self.io.poll_write_ready(cx))?; + + let fd = self.io.get_ref().as_raw_fd(); + match send_with_fd(fd, buf, fds) { + Err(ref err) if err.kind() == ErrorKind::WouldBlock => Poll::Pending, + x => Poll::Ready(x), + } + } + + /// Send data with file descriptors + pub async fn send_with_fd(&mut self, buf: &[u8], fds: &[RawFd]) -> io::Result { + future::poll_fn(|cx| self.poll_send_with_fd(cx, buf, fds)).await + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of `Shutdown`). + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.io.get_ref().shutdown(how) + } +} + +impl AsyncRead for UnixStream { + unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit]) -> bool { + false + } + + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + self.poll_read_priv(cx, buf) + } +} + +impl AsyncWrite for UnixStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + self.poll_write_priv(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + self.shutdown(Shutdown::Write)?; + Poll::Ready(Ok(())) + } +} + +impl UnixStream { + // == Poll IO functions that takes `&self` == + // + // They are not public because (taken from the doc of `PollEvented`): + // + // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the + // caller must ensure that there are at most two tasks that use a + // `PollEvented` instance concurrently. One for reading and one for writing. + // While violating this requirement is "safe" from a Rust memory model point + // of view, it will result in unexpected behavior in the form of lost + // notifications and tasks hanging. + + pub(crate) fn poll_read_priv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll> { + ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?; + + match self.io.get_ref().read(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_read_ready(cx, mio::Ready::readable())?; + Poll::Pending + } + x => Poll::Ready(x), + } + } + + pub(crate) fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + ready!(self.io.poll_write_ready(cx))?; + + match self.io.get_ref().write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.clear_write_ready(cx)?; + Poll::Pending + } + x => Poll::Ready(x), + } + } +} + +/// A common implementation of `sendmsg` that sends provided bytes with ancillary file descriptors +/// over either a datagram or stream unix socket. +/// +/// Borrowed from: https://github.com/Standard-Cognition/sendfd +fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result { + unsafe { + let mut iov = libc::iovec { + // NB: this casts *const to *mut, and in doing so we trust the OS to be a good citizen + // and not mutate our buffer. This is the API we have to live with. + iov_base: bs.as_ptr() as *const _ as *mut _, + iov_len: bs.len(), + }; + + // Construct msghdr + // + // 1. Allocate memory for msg_control + let cmsg_fd_len = fds.len() * mem::size_of::(); + let cmsg_buffer_len = libc::CMSG_SPACE(cmsg_fd_len as u32) as usize; + let mut cmsg_buffer = Vec::with_capacity(cmsg_buffer_len); + cmsg_buffer.set_len(cmsg_buffer_len); + + let mut msghdr = libc::msghdr { + msg_name: ptr::null_mut(), + msg_namelen: 0, + msg_iov: &mut iov as *mut _, + msg_iovlen: 1, + msg_control: cmsg_buffer.as_mut_ptr(), + msg_controllen: cmsg_buffer_len.try_into().unwrap(), + ..mem::zeroed() + }; + + // Fill cmsg with the file descriptors we are sending. + let cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _); + cmsg_header.write(libc::cmsghdr { + cmsg_level: libc::SOL_SOCKET, + cmsg_type: libc::SCM_RIGHTS, + cmsg_len: libc::CMSG_LEN(cmsg_fd_len as u32).try_into().unwrap(), + }); + + let cmsg_data = libc::CMSG_DATA(cmsg_header); + let cmsg_data_slice = slice::from_raw_parts_mut(cmsg_data as *mut RawFd, fds.len()); + cmsg_data_slice.copy_from_slice(fds); + + let count = libc::sendmsg(socket, &msghdr as *const _, 0); + if count < 0 { + let err = Error::last_os_error(); + Err(err) + } else { + Ok(count as usize) + } + } +}