diff --git a/src/relay/tcprelay/local.rs b/src/relay/tcprelay/local.rs index 007a0ee6..6e4fb2e7 100644 --- a/src/relay/tcprelay/local.rs +++ b/src/relay/tcprelay/local.rs @@ -32,7 +32,7 @@ use tokio_core::net::{TcpStream, TcpListener}; use tokio_core::reactor::Handle; use tokio_core::io::Io; use tokio_core::io::{ReadHalf, WriteHalf}; -use tokio_core::io::{flush, copy, write_all}; +use tokio_core::io::{flush, write_all, copy}; use hyper::method::Method; @@ -44,6 +44,7 @@ use relay::loadbalancing::server::RoundRobin; use relay::loadbalancing::server::LoadBalancer; use super::http::{self, HttpRequestFut, HttpResponseFut}; +use super::tunnel; /// TCP relay local server pub struct TcpRelayLocal { @@ -104,11 +105,7 @@ impl Socks5RelayLocal { let rhalf = svr_r.and_then(move |svr_r| copy(svr_r, w)); let whalf = svr_w.and_then(move |svr_w| copy(r, svr_w)); - rhalf.join(whalf) - .then(move |_| { - trace!("Relay to {} is finished", cloned_addr); - Ok(()) - }) + tunnel(cloned_addr, whalf, rhalf) }) }) .boxed() @@ -258,13 +255,10 @@ impl HttpRelayServer { .and_then(move |(svr_s, w)| { super::proxy_server_handshake(svr_s, cloned_svr_cfg, addr).and_then(move |(svr_r, svr_w)| { let rhalf = svr_r.and_then(move |svr_r| copy(svr_r, w)); - let whalf = svr_w.and_then(move |svr_w| copy(r, svr_w)); + let whalf = svr_w.and_then(move |svr_w| write_all(svr_w, remains)) + .and_then(move |(svr_w, _)| copy(r, svr_w)); - rhalf.join(whalf) - .then(move |_| { - trace!("Relay to {} is finished", cloned_addr); - Ok(()) - }) + tunnel(cloned_addr, whalf, rhalf) }) }) .boxed() diff --git a/src/relay/tcprelay/mod.rs b/src/relay/tcprelay/mod.rs index 46fc65c4..55a78a1c 100644 --- a/src/relay/tcprelay/mod.rs +++ b/src/relay/tcprelay/mod.rs @@ -50,20 +50,28 @@ pub mod server; mod stream; mod http; +#[derive(Debug, Copy, Clone)] +pub enum TunnelDirection { + Client2Server, + Server2Client, +} + type DecryptedHalf = DecryptedReader>; type EncryptedHalf = EncryptedWriter>; type DecryptedHalfFut = BoxFuture; type EncryptedHalfFut = BoxFuture; -fn connect_proxy_server(handle: &Handle, svr_cfg: Arc) -> BoxFuture { +pub type BoxIoFuture = BoxFuture; + +fn connect_proxy_server(handle: &Handle, svr_cfg: Arc) -> BoxIoFuture { TcpStream::connect(&svr_cfg.addr, handle).boxed() } fn proxy_server_handshake(remote_stream: TcpStream, svr_cfg: Arc, relay_addr: Address) - -> BoxFuture<(DecryptedHalfFut, EncryptedHalfFut), io::Error> { + -> BoxIoFuture<(DecryptedHalfFut, EncryptedHalfFut)> { futures::lazy(move || { let (r, w) = remote_stream.split(); @@ -199,7 +207,7 @@ impl Future for CopyExact // If our buffer has some data, let's write it out! while *pos < *cap { - let i = try_nb!(writer.write(&buf[*pos..*cap])); + let i = try_nb!(writer.write(&buf[*pos..*cap]).and_then(|x| writer.flush().map(|_| x))); *pos += i; } @@ -227,3 +235,60 @@ pub fn copy_exact(r: R, w: W, amt: usize) -> CopyExact { CopyExact::new(r, w, amt) } + +pub fn tunnel(addr: Address, c2s: CF, s2c: SF) -> BoxIoFuture<()> + where CF: Future + Send + 'static, + SF: Future + Send + 'static +{ + let addr = Arc::new(addr); + + let cloned_addr = addr.clone(); + let c2s = c2s.then(move |res| { + match res { + Ok(amt) => { + // Continue reading response from remote server + trace!("Relay {} client -> server is finished, relayed {} bytes", + cloned_addr, + amt); + + Ok(TunnelDirection::Client2Server) + } + Err(err) => { + error!("Relay {} client -> server aborted: {}", cloned_addr, err); + Err(err) + } + } + }); + + let cloned_addr = addr.clone(); + let s2c = s2c.then(move |res| { + match res { + Ok(amt) => { + trace!("Relay {} client <- server is finished, relayed {} bytes", + cloned_addr, + amt); + + Ok(TunnelDirection::Server2Client) + } + Err(err) => { + error!("Relay {} client <- server aborted: {}", cloned_addr, err); + Err(err) + } + } + }); + + c2s.select(s2c) + .map_err(|(err, _)| err) + .and_then(move |(dir, next)| { + match dir { + TunnelDirection::Client2Server => next.map(move |_| ()).boxed(), + // Shutdown connection directly because remote server has disconnected + TunnelDirection::Server2Client => futures::finished(()).boxed(), + } + }) + .and_then(move |_| { + trace!("Relay {} client <-> server are all finished, closing", addr); + Ok(()) + }) + .boxed() +} diff --git a/src/relay/tcprelay/server.rs b/src/relay/tcprelay/server.rs index f50e3a1c..3c0177a1 100644 --- a/src/relay/tcprelay/server.rs +++ b/src/relay/tcprelay/server.rs @@ -48,6 +48,8 @@ use tokio_core::io::{read_exact, write_all, copy, flush}; use ip::IpAddr; +use super::tunnel; + type ClientRead = ReadHalf; type ClientWrite = WriteHalf; @@ -179,16 +181,13 @@ impl TcpRelayServer { r_fut.and_then(move |(r, addr)| { info!("Connecting {}", addr); let cloned_addr = addr.clone(); - TcpRelayServer::connect_remote(cpu_pool, cloned_handle, addr, forbidden_ip).and_then(|svr_s| { - let (svr_r, svr_w) = svr_s.split(); - let c2s = copy(r, svr_w); - let s2c = w_fut.and_then(|w| copy(svr_r, w)); - c2s.join(s2c) - .then(move |_| { - trace!("Relay {} is finished", cloned_addr); - Ok(()) - }) - }) + TcpRelayServer::connect_remote(cpu_pool, cloned_handle.clone(), addr, forbidden_ip) + .and_then(move |svr_s| { + let (svr_r, svr_w) = svr_s.split(); + tunnel(cloned_addr, + copy(r, svr_w), + w_fut.and_then(|w| copy(svr_r, w))) + }) }) });