Customized UnixStream for protect() RPC

ref #209
This commit is contained in:
zonyitoo
2020-03-15 23:00:52 +08:00
parent bd2fe561bc
commit 6373040a4d
6 changed files with 192 additions and 21 deletions

View File

@@ -94,6 +94,7 @@ bloomfilter = "^1.0.2"
spin = "0.5"
# mio = { version = "0.7", features = ["udp"] }
mio = "0.6"
mio-uds = "0.6"
serde_json = "1.0"
regex = "1"
strum = "0.18"

View File

@@ -1,3 +1,4 @@
edition = "2018"
max_width = 120
#indent_style = "Visual"
#fn_call_width = 120

View File

@@ -296,9 +296,7 @@ impl Context {
match self.config.acl {
// Proxy everything by default
None => false,
Some(ref a) => {
a.check_qname_in_proxy_list(qname).await
}
Some(ref a) => a.check_qname_in_proxy_list(qname).await,
}
}

View File

@@ -3,7 +3,8 @@
use std::{
convert::From,
fmt::{self, Debug, Display},
io, mem,
io,
mem,
str::{self, FromStr},
};

View File

@@ -42,15 +42,16 @@ pub fn sockaddr_to_std(saddr: &libc::sockaddr_storage) -> io::Result<SocketAddr>
cfg_if! {
if #[cfg(target_os = "android")] {
mod uds;
/// This is a RPC for Android to `protect()` socket for connecting to remote servers
///
/// https://developer.android.com/reference/android/net/VpnService#protect(java.net.Socket)
///
/// More detail could be found in [shadowsocks-android](https://github.com/shadowsocks/shadowsocks-android) project.
fn protect(protect_path: &Option<String>, fd: RawFd) -> io::Result<()> {
use std::{io::Read, os::unix::net::UnixStream, time::Duration};
use sendfd::{SendWithFd};
async fn protect(protect_path: &Option<String>, fd: RawFd) -> io::Result<()> {
use std::{io::Read, time::Duration};
use tokio::io::AsyncReadExt;
// ignore if protect_path is not specified
let path = match protect_path {
@@ -58,23 +59,19 @@ cfg_if! {
None => return Ok(()),
};
// it's safe to use blocking socket here
let mut stream = UnixStream::connect(path)?;
stream
.set_read_timeout(Some(Duration::new(1, 0)))
.expect("couldn't set read timeout");
stream
.set_write_timeout(Some(Duration::new(1, 0)))
.expect("couldn't set write timeout");
let timeout = Some(Duration::new(1, 0));
let mut stream = self::uds::UnixStream::connect(path).await?;
// send fds
let dummy: [u8; 1] = [1];
let fds: [RawFd; 1] = [fd];
stream.send_with_fd(&dummy, &fds)?;
stream.send_with_fd(&dummy, &fds).await?;
// receive the return value
let mut response = [0; 1];
stream.read_exact(&mut response)?;
stream.read_exact(&mut response).await?;
if response[0] == 0xFF {
return Err(Error::new(ErrorKind::Other, "protect() failed"));
}
@@ -83,7 +80,7 @@ cfg_if! {
}
} else {
#[inline(always)]
fn protect(_protect_path: &Option<String>, _fd: RawFd) -> io::Result<()> {
async fn protect(_protect_path: &Option<String>, _fd: RawFd) -> io::Result<()> {
Ok(())
}
}
@@ -103,7 +100,7 @@ pub async fn tcp_stream_connect(saddr: &SocketAddr, context: &Context) -> io::Re
// Any traffic to localhost should not be protected
// This is a workaround for VPNService
if cfg!(target_os = "android") && !saddr.ip().is_loopback() {
protect(&context.config().protect_path, socket.as_raw_fd())?;
protect(&context.config().protect_path, socket.as_raw_fd()).await?;
}
// it's important that the socket is protected before connecting
@@ -119,7 +116,7 @@ pub async fn create_udp_socket_with_context(addr: &SocketAddr, context: &Context
// Any traffic to localhost should be protected
// This is a workaround for VPNService
if cfg!(target_os = "android") && !addr.ip().is_loopback() {
protect(&context.config().protect_path, socket.as_raw_fd())?;
protect(&context.config().protect_path, socket.as_raw_fd()).await?;
}
Ok(socket)

173
src/relay/sys/unix/uds.rs Normal file
View File

@@ -0,0 +1,173 @@
//! Android specific features
use std::{
convert::TryInto,
io::{self, Error, ErrorKind, Read, Write},
mem::{self, MaybeUninit},
net::Shutdown,
os::unix::io::{AsRawFd, RawFd},
path::Path,
pin::Pin,
ptr,
slice,
task::{Context, Poll},
};
use futures::{future, ready};
use mio_uds::UnixStream as MioUnixStream;
use tokio::io::{AsyncRead, AsyncWrite, PollEvented};
/// A UnixStream supports transferring FDs between processes
pub struct UnixStream {
io: PollEvented<MioUnixStream>,
}
impl UnixStream {
/// Connects to the socket named by `path`.
pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
let uds = MioUnixStream::connect(path)?;
let io = PollEvented::new(uds)?;
future::poll_fn(|cx| io.poll_write_ready(cx)).await?;
Ok(UnixStream { io })
}
fn poll_send_with_fd(&self, cx: &mut Context, buf: &[u8], fds: &[RawFd]) -> Poll<io::Result<usize>> {
ready!(self.io.poll_write_ready(cx))?;
let fd = self.io.get_ref().as_raw_fd();
match send_with_fd(fd, buf, fds) {
Err(ref err) if err.kind() == ErrorKind::WouldBlock => Poll::Pending,
x => Poll::Ready(x),
}
}
/// Send data with file descriptors
pub async fn send_with_fd(&mut self, buf: &[u8], fds: &[RawFd]) -> io::Result<usize> {
future::poll_fn(|cx| self.poll_send_with_fd(cx, buf, fds)).await
}
/// Shuts down the read, write, or both halves of this connection.
///
/// This function will cause all pending and future I/O calls on the
/// specified portions to immediately return with an appropriate value
/// (see the documentation of `Shutdown`).
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.io.get_ref().shutdown(how)
}
}
impl AsyncRead for UnixStream {
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
false
}
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
self.poll_read_priv(cx, buf)
}
}
impl AsyncWrite for UnixStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
self.poll_write_priv(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
self.shutdown(Shutdown::Write)?;
Poll::Ready(Ok(()))
}
}
impl UnixStream {
// == Poll IO functions that takes `&self` ==
//
// They are not public because (taken from the doc of `PollEvented`):
//
// While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the
// caller must ensure that there are at most two tasks that use a
// `PollEvented` instance concurrently. One for reading and one for writing.
// While violating this requirement is "safe" from a Rust memory model point
// of view, it will result in unexpected behavior in the form of lost
// notifications and tasks hanging.
pub(crate) fn poll_read_priv(&self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
ready!(self.io.poll_read_ready(cx, mio::Ready::readable()))?;
match self.io.get_ref().read(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_read_ready(cx, mio::Ready::readable())?;
Poll::Pending
}
x => Poll::Ready(x),
}
}
pub(crate) fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
ready!(self.io.poll_write_ready(cx))?;
match self.io.get_ref().write(buf) {
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.io.clear_write_ready(cx)?;
Poll::Pending
}
x => Poll::Ready(x),
}
}
}
/// A common implementation of `sendmsg` that sends provided bytes with ancillary file descriptors
/// over either a datagram or stream unix socket.
///
/// Borrowed from: https://github.com/Standard-Cognition/sendfd
fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result<usize> {
unsafe {
let mut iov = libc::iovec {
// NB: this casts *const to *mut, and in doing so we trust the OS to be a good citizen
// and not mutate our buffer. This is the API we have to live with.
iov_base: bs.as_ptr() as *const _ as *mut _,
iov_len: bs.len(),
};
// Construct msghdr
//
// 1. Allocate memory for msg_control
let cmsg_fd_len = fds.len() * mem::size_of::<RawFd>();
let cmsg_buffer_len = libc::CMSG_SPACE(cmsg_fd_len as u32) as usize;
let mut cmsg_buffer = Vec::with_capacity(cmsg_buffer_len);
cmsg_buffer.set_len(cmsg_buffer_len);
let mut msghdr = libc::msghdr {
msg_name: ptr::null_mut(),
msg_namelen: 0,
msg_iov: &mut iov as *mut _,
msg_iovlen: 1,
msg_control: cmsg_buffer.as_mut_ptr(),
msg_controllen: cmsg_buffer_len.try_into().unwrap(),
..mem::zeroed()
};
// Fill cmsg with the file descriptors we are sending.
let cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
cmsg_header.write(libc::cmsghdr {
cmsg_level: libc::SOL_SOCKET,
cmsg_type: libc::SCM_RIGHTS,
cmsg_len: libc::CMSG_LEN(cmsg_fd_len as u32).try_into().unwrap(),
});
let cmsg_data = libc::CMSG_DATA(cmsg_header);
let cmsg_data_slice = slice::from_raw_parts_mut(cmsg_data as *mut RawFd, fds.len());
cmsg_data_slice.copy_from_slice(fds);
let count = libc::sendmsg(socket, &msghdr as *const _, 0);
if count < 0 {
let err = Error::last_os_error();
Err(err)
} else {
Ok(count as usize)
}
}
}