set default timeout for shadowsocks connection tunnel

- default timeout is 24 hours
- ref #490
- refactored tunnel copying with copy_bidirectional to avoid unnecessary
  reader/writer splits
zonyitoo
2021-06-04 13:46:09 +08:00
parent e2ac20de66
commit 2819c47158
18 changed files with 526 additions and 251 deletions

Cargo.lock (generated)
View File

@@ -522,9 +522,9 @@ checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04"
[[package]]
name = "heck"
version = "0.3.2"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cbf45460356b7deeb5e3415b5563308c0a9b057c85e12b06ad551f98d0a6ac"
checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c"
dependencies = [
"unicode-segmentation",
]
@@ -1018,12 +1018,11 @@ dependencies = [
[[package]]
name = "ordered-float"
version = "2.5.0"
version = "2.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "809348965973b261c3e504c8d0434e465274f78c880e10039914f2c5dcf49461"
checksum = "f100fcfb41e5385e0991f74981732049f9b896821542a219420491046baafdc2"
dependencies = [
"num-traits",
"rand",
]
[[package]]
@@ -1483,7 +1482,7 @@ dependencies = [
[[package]]
name = "shadowsocks"
version = "1.11.0"
version = "1.11.1"
dependencies = [
"arc-swap 1.3.0",
"async-trait",
@@ -1527,7 +1526,7 @@ dependencies = [
[[package]]
name = "shadowsocks-rust"
version = "1.11.1"
version = "1.11.2"
dependencies = [
"byte_string",
"byteorder",
@@ -1550,7 +1549,7 @@ dependencies = [
[[package]]
name = "shadowsocks-service"
version = "1.11.1"
version = "1.11.2"
dependencies = [
"async-trait",
"byte_string",
@@ -1591,9 +1590,9 @@ dependencies = [
[[package]]
name = "signal-hook-registry"
version = "1.3.0"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "16f1d0fef1604ba8f7a073c7e701f213e056707210e9020af4528e0101ce11a6"
checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0"
dependencies = [
"libc",
]
@@ -2061,9 +2060,9 @@ dependencies = [
[[package]]
name = "unicode-normalization"
version = "0.1.18"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33717dca7ac877f497014e10d73f3acf948c342bee31b5ca7892faf94ccc6b49"
checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9"
dependencies = [
"tinyvec",
]

View File

@@ -1,6 +1,6 @@
[package]
name = "shadowsocks-rust"
version = "1.11.1"
version = "1.11.2"
authors = ["Shadowsocks Contributors"]
description = "shadowsocks is a fast tunnel proxy that helps you bypass firewalls."
repository = "https://github.com/shadowsocks/shadowsocks-rust"

View File

@@ -1,6 +1,6 @@
[package]
name = "shadowsocks-service"
version = "1.11.1"
version = "1.11.2"
authors = ["Shadowsocks Contributors"]
description = "shadowsocks is a fast tunnel proxy that helps you bypass firewalls."
repository = "https://github.com/shadowsocks/shadowsocks-rust"
@@ -84,7 +84,7 @@ byteorder = "1.3"
rand = { version = "0.8", optional = true }
futures = "0.3"
tokio = { version = "1", features = ["io-util", "macros", "net", "parking_lot", "rt", "sync", "time"] }
tokio = { version = "1.5", features = ["io-util", "macros", "net", "parking_lot", "rt", "sync", "time"] }
tokio-native-tls = { version = "0.3", optional = true }
native-tls = { version = "0.2.7", optional = true, features = ["alpn"] }
tokio-rustls = { version = "0.22", optional = true }
@@ -108,7 +108,7 @@ regex = "1.4"
serde = { version = "1.0", features = ["derive"] }
json5 = "0.3"
shadowsocks = { version = "1.10.2", path = "../shadowsocks" }
shadowsocks = { version = "1.11.1", path = "../shadowsocks" }
# Just for the ioctl call macro
[target.'cfg(any(target_os = "macos", target_os = "ios", target_os = "freebsd", target_os = "netbsd", target_os = "openbsd"))'.dependencies]

View File

@@ -101,7 +101,7 @@ impl HttpDispatcher {
//
// FIXME: What STATUS should I return for connection error?
let server = self.balancer.best_tcp_server();
let stream = AutoProxyClientStream::connect(self.context, server.as_ref(), &host).await?;
let mut stream = AutoProxyClientStream::connect(self.context, server.as_ref(), &host).await?;
debug!("CONNECT relay connected {} <-> {}", self.client_addr, host);
@@ -114,20 +114,13 @@ impl HttpDispatcher {
let client_addr = self.client_addr;
tokio::spawn(async move {
match upgrade::on(req).await {
Ok(upgraded) => {
Ok(mut upgraded) => {
trace!("CONNECT tunnel upgrade success, {} <-> {}", client_addr, host);
use tokio::io::split;
let (mut plain_reader, mut plain_writer) = split(upgraded);
let (mut shadow_reader, mut shadow_writer) = stream.into_split();
let _ = establish_tcp_tunnel(
server.server_config(),
&mut plain_reader,
&mut plain_writer,
&mut shadow_reader,
&mut shadow_writer,
&mut upgraded,
&mut stream,
client_addr,
&host,
)

View File

@@ -19,7 +19,7 @@ use crate::{
local::{
context::ServiceContext,
loadbalancing::PingBalancer,
net::{AutoProxyClientStream, AutoProxyIo},
net::AutoProxyClientStream,
redir::{
redir_ext::{TcpListenerRedirExt, TcpStreamRedirExt},
to_ipv4_mapped,
@@ -44,37 +44,13 @@ async fn establish_client_tcp_redir<'a>(
let server = balancer.best_tcp_server();
let svr_cfg = server.server_config();
let remote = AutoProxyClientStream::connect(context, &server, addr).await?;
let mut remote = AutoProxyClientStream::connect(context, &server, addr).await?;
if nodelay {
remote.set_nodelay(true)?;
}
if remote.is_proxied() {
debug!(
"established tcp redir tunnel {} <-> {} through sever {} (outbound: {})",
peer_addr,
addr,
svr_cfg.external_addr(),
svr_cfg.addr(),
);
} else {
debug!("established tcp redir tunnel {} <-> {}", peer_addr, addr);
}
let (mut plain_reader, mut plain_writer) = stream.split();
let (mut shadow_reader, mut shadow_writer) = remote.into_split();
establish_tcp_tunnel(
svr_cfg,
&mut plain_reader,
&mut plain_writer,
&mut shadow_reader,
&mut shadow_writer,
peer_addr,
addr,
)
.await
establish_tcp_tunnel(svr_cfg, &mut stream, &mut remote, peer_addr, addr).await
}
async fn handle_redir_client(

View File

@@ -121,18 +121,6 @@ impl Socks4TcpHandler {
// UNWRAP.
let mut stream = stream.into_inner();
let (mut plain_reader, mut plain_writer) = stream.split();
let (mut shadow_reader, mut shadow_writer) = remote.into_split();
establish_tcp_tunnel(
svr_cfg,
&mut plain_reader,
&mut plain_writer,
&mut shadow_reader,
&mut shadow_writer,
peer_addr,
&target_addr,
)
.await
establish_tcp_tunnel(svr_cfg, &mut stream, &mut remote, peer_addr, &target_addr).await
}
}

View File

@@ -137,7 +137,7 @@ impl Socks5TcpHandler {
let server = self.balancer.best_tcp_server();
let svr_cfg = server.server_config();
let remote = match AutoProxyClientStream::connect(self.context.clone(), &server, &target_addr).await {
let mut remote = match AutoProxyClientStream::connect(self.context.clone(), &server, &target_addr).await {
Ok(remote) => {
// Tell the client that we are ready
let header =
@@ -167,19 +167,7 @@ impl Socks5TcpHandler {
remote.set_nodelay(true)?;
}
let (mut plain_reader, mut plain_writer) = stream.split();
let (mut shadow_reader, mut shadow_writer) = remote.into_split();
establish_tcp_tunnel(
svr_cfg,
&mut plain_reader,
&mut plain_writer,
&mut shadow_reader,
&mut shadow_writer,
peer_addr,
&target_addr,
)
.await
establish_tcp_tunnel(svr_cfg, &mut stream, &mut remote, peer_addr, &target_addr).await
}
async fn handle_udp_associate(self, mut stream: TcpStream, client_addr: Address) -> io::Result<()> {

View File

@@ -78,23 +78,11 @@ async fn handle_tcp_client(
svr_cfg.addr(),
);
let remote = AutoProxyClientStream::connect_proxied(context, &server, &forward_addr).await?;
let mut remote = AutoProxyClientStream::connect_proxied(context, &server, &forward_addr).await?;
if nodelay {
remote.set_nodelay(true)?;
}
let (mut plain_reader, mut plain_writer) = stream.split();
let (mut shadow_reader, mut shadow_writer) = remote.into_split();
establish_tcp_tunnel(
svr_cfg,
&mut plain_reader,
&mut plain_writer,
&mut shadow_reader,
&mut shadow_writer,
peer_addr,
&forward_addr,
)
.await
establish_tcp_tunnel(svr_cfg, &mut stream, &mut remote, peer_addr, &forward_addr).await
}

View File

@@ -2,39 +2,31 @@
use std::{io, net::SocketAddr, time::Duration};
use futures::future::{self, Either};
use log::trace;
use log::{debug, trace};
use shadowsocks::{
config::ServerConfig,
relay::{
socks5::Address,
tcprelay::utils::{copy_from_encrypted, copy_to_encrypted},
},
relay::{socks5::Address, tcprelay::utils::copy_encrypted_bidirectional},
};
use tokio::{
io::{copy, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
io::{copy_bidirectional, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
time,
};
use crate::local::net::AutoProxyIo;
pub async fn establish_tcp_tunnel<PR, PW, SR, SW>(
pub(crate) async fn establish_tcp_tunnel<P, S>(
svr_cfg: &ServerConfig,
plain_reader: &mut PR,
plain_writer: &mut PW,
shadow_reader: &mut SR,
shadow_writer: &mut SW,
plain: &mut P,
shadow: &mut S,
peer_addr: SocketAddr,
target_addr: &Address,
) -> io::Result<()>
where
PR: AsyncRead + Unpin,
PW: AsyncWrite + Unpin,
SR: AsyncRead + AutoProxyIo + Unpin,
SW: AsyncWrite + AutoProxyIo + Unpin,
P: AsyncRead + AsyncWrite + Unpin,
S: AsyncRead + AsyncWrite + AutoProxyIo + Unpin,
{
if shadow_reader.is_proxied() && shadow_writer.is_proxied() {
trace!(
if shadow.is_proxied() {
debug!(
"established tcp tunnel {} <-> {} through sever {} (outbound: {})",
peer_addr,
target_addr,
@@ -42,16 +34,8 @@ where
svr_cfg.addr(),
);
} else {
trace!("established tcp tunnel {} <-> {} bypassed", peer_addr, target_addr);
return establish_tcp_tunnel_bypassed(
plain_reader,
plain_writer,
shadow_reader,
shadow_writer,
peer_addr,
target_addr,
)
.await;
debug!("established tcp tunnel {} <-> {} bypassed", peer_addr, target_addr);
return establish_tcp_tunnel_bypassed(plain, shadow, peer_addr, target_addr).await;
}
// https://github.com/shadowsocks/shadowsocks-rust/issues/232
@@ -61,22 +45,22 @@ where
// Wait at most 500ms, then send the handshake packet to the remote server.
{
let mut buffer = [0u8; 8192];
match time::timeout(Duration::from_millis(500), plain_reader.read(&mut buffer)).await {
match time::timeout(Duration::from_millis(500), plain.read(&mut buffer)).await {
Ok(Ok(0)) => {
// EOF. Just terminate right here.
return Ok(());
}
Ok(Ok(n)) => {
// Send the first packet.
shadow_writer.write_all(&buffer[..n]).await?;
shadow.write_all(&buffer[..n]).await?;
}
Ok(Err(err)) => return Err(err),
Err(..) => {
// Timeout. Send handshake to server.
shadow_writer.write(&[]).await?;
shadow.write(&[]).await?;
trace!(
"tcp tunnel {} -> {} sent handshake without data",
"tcp tunnel {} -> {} (proxied) sent handshake without data",
peer_addr,
target_addr
);
@@ -84,62 +68,56 @@ where
}
}
let l2r = copy_to_encrypted(svr_cfg.method(), plain_reader, shadow_writer);
let r2l = copy_from_encrypted(svr_cfg.method(), shadow_reader, plain_writer);
tokio::pin!(l2r);
tokio::pin!(r2l);
match future::select(l2r, r2l).await {
Either::Left((Ok(..), ..)) => {
trace!("tcp tunnel {} -> {} closed", peer_addr, target_addr);
match copy_encrypted_bidirectional(svr_cfg.method(), shadow, plain).await {
Ok((wn, rn)) => {
trace!(
"tcp tunnel {} <-> {} (proxied) closed, L2R {} bytes, R2L {} bytes",
peer_addr,
target_addr,
rn,
wn
);
}
Either::Left((Err(err), ..)) => {
trace!("tcp tunnel {} -> {} closed with error: {}", peer_addr, target_addr, err);
}
Either::Right((Ok(..), ..)) => {
trace!("tcp tunnel {} <- {} closed", peer_addr, target_addr);
}
Either::Right((Err(err), ..)) => {
trace!("tcp tunnel {} <- {} closed with error: {}", peer_addr, target_addr, err);
Err(err) => {
trace!(
"tcp tunnel {} <-> {} (proxied) closed with error: {}",
peer_addr,
target_addr,
err
);
}
}
Ok(())
}
async fn establish_tcp_tunnel_bypassed<PR, PW, SR, SW>(
plain_reader: &mut PR,
plain_writer: &mut PW,
shadow_reader: &mut SR,
shadow_writer: &mut SW,
async fn establish_tcp_tunnel_bypassed<P, S>(
plain: &mut P,
shadow: &mut S,
peer_addr: SocketAddr,
target_addr: &Address,
) -> io::Result<()>
where
PR: AsyncRead + Unpin,
PW: AsyncWrite + Unpin,
SR: AsyncRead + Unpin,
SW: AsyncWrite + Unpin,
P: AsyncRead + AsyncWrite + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
let l2r = copy(plain_reader, shadow_writer);
let r2l = copy(shadow_reader, plain_writer);
tokio::pin!(l2r);
tokio::pin!(r2l);
match future::select(l2r, r2l).await {
Either::Left((Ok(..), ..)) => {
trace!("tcp tunnel {} -> {} closed", peer_addr, target_addr);
match copy_bidirectional(plain, shadow).await {
Ok((rn, wn)) => {
trace!(
"tcp tunnel {} <-> {} (bypassed) closed, L2R {} bytes, R2L {} bytes",
peer_addr,
target_addr,
rn,
wn
);
}
Either::Left((Err(err), ..)) => {
trace!("tcp tunnel {} -> {} closed with error: {}", peer_addr, target_addr, err);
}
Either::Right((Ok(..), ..)) => {
trace!("tcp tunnel {} <- {} closed", peer_addr, target_addr);
}
Either::Right((Err(err), ..)) => {
trace!("tcp tunnel {} <- {} closed with error: {}", peer_addr, target_addr, err);
Err(err) => {
trace!(
"tcp tunnel {} <-> {} (bypassed) closed with error: {}",
peer_addr,
target_addr,
err
);
}
}
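Why the empty shadow.write(&[]) above works as a handshake: ProxyClientStream (touched later in this commit) holds the serialized target address in a Connect write state and prepends it to the first write, so even a zero-length payload flushes the header packet to the server. A minimal sketch of that idea, using hypothetical types rather than the crate's real ones:

enum WriteState {
    Connect(Vec<u8>), // serialized target address, not yet sent
    Connected,
}

// The first write sends the pending header even when `payload` is empty.
fn first_write(state: &mut WriteState, payload: &[u8]) -> Vec<u8> {
    match std::mem::replace(state, WriteState::Connected) {
        WriteState::Connect(header) => [header.as_slice(), payload].concat(),
        WriteState::Connected => payload.to_vec(),
    }
}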

View File

@@ -8,17 +8,13 @@ use std::{
time::Duration,
};
use futures::future::{self, Either};
use log::{debug, error, info, trace, warn};
use shadowsocks::{
crypto::v1::CipherKind,
net::{AcceptOpts, TcpStream as OutboundTcpStream},
relay::{
socks5::{Address, Error as Socks5Error},
tcprelay::{
utils::{copy_from_encrypted, copy_to_encrypted},
ProxyServerStream,
},
tcprelay::{utils::copy_encrypted_bidirectional, ProxyServerStream},
},
ProxyListener,
ServerConfig,
@@ -192,15 +188,6 @@ impl TcpServerClient {
}
}
let (mut lr, mut lw) = self.stream.into_split();
let (mut rr, mut rw) = remote_stream.split();
let l2r = copy_to_encrypted(self.method, &mut lr, &mut rw);
let r2l = copy_from_encrypted(self.method, &mut rr, &mut lw);
tokio::pin!(l2r);
tokio::pin!(r2l);
debug!(
"established tcp tunnel {} <-> {} with {:?}",
self.peer_addr,
@@ -208,24 +195,19 @@ impl TcpServerClient {
self.context.connect_opts_ref()
);
match future::select(l2r, r2l).await {
Either::Left((Ok(..), ..)) => {
trace!("tcp tunnel {} -> {} closed", self.peer_addr, target_addr);
}
Either::Left((Err(err), ..)) => {
match copy_encrypted_bidirectional(self.method, &mut self.stream, &mut remote_stream).await {
Ok((rn, wn)) => {
trace!(
"tcp tunnel {} -> {} closed with error: {}",
"tcp tunnel {} <-> {} closed, L2R {} bytes, R2L {} bytes",
self.peer_addr,
target_addr,
err
rn,
wn
);
}
Either::Right((Ok(..), ..)) => {
trace!("tcp tunnel {} <- {} closed", self.peer_addr, target_addr);
}
Either::Right((Err(err), ..)) => {
Err(err) => {
trace!(
"tcp tunnel {} <- {} closed with error: {}",
"tcp tunnel {} <-> {} closed with error: {}",
self.peer_addr,
target_addr,
err

View File

@@ -1,6 +1,6 @@
[package]
name = "shadowsocks"
version = "1.11.0"
version = "1.11.1"
authors = ["Shadowsocks Contributors"]
description = "shadowsocks is a fast tunnel proxy that helps you bypass firewalls."
repository = "https://github.com/shadowsocks/shadowsocks-rust"

View File

@@ -155,7 +155,7 @@ pub struct ServerConfig {
method: CipherKind,
/// Encryption key
enc_key: Box<[u8]>,
/// Handshake timeout
/// Handshake timeout (connect)
timeout: Option<Duration>,
/// Plugin config
@@ -279,6 +279,15 @@ impl ServerConfig {
self.timeout
}
/// Timeout for established tunnels (connection)
pub fn connection_timeout(&self) -> Duration {
// Connections should be kept alive for at least 24 hours.
// Otherwise a connection would be closed accidentally when no data is exchanged in either direction.
static MIN_CONNECTION_TIMEOUT: Duration = Duration::from_secs(24 * 60 * 60);
std::cmp::max(MIN_CONNECTION_TIMEOUT, self.timeout.unwrap_or(Duration::from_secs(0)))
}
/// Get server's remark
pub fn remarks(&self) -> Option<&str> {
self.remarks.as_ref().map(AsRef::as_ref)
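A standalone sketch of how the effective timeout resolves under this rule: the 24-hour floor wins unless the configured timeout is even longer (example values are assumed, the function mirrors ServerConfig::connection_timeout above):

use std::time::Duration;

// Never below 24 hours, regardless of the configured handshake timeout.
fn connection_timeout(configured: Option<Duration>) -> Duration {
    const MIN: Duration = Duration::from_secs(24 * 60 * 60);
    MIN.max(configured.unwrap_or(Duration::from_secs(0)))
}

fn main() {
    // Not configured -> 24 hours.
    assert_eq!(connection_timeout(None), Duration::from_secs(86_400));
    // Shorter than 24 hours -> raised to the 24-hour floor.
    assert_eq!(connection_timeout(Some(Duration::from_secs(300))), Duration::from_secs(86_400));
    // Longer than 24 hours -> kept as configured.
    assert_eq!(connection_timeout(Some(Duration::from_secs(48 * 3600))), Duration::from_secs(172_800));
}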

View File

@@ -1,6 +1,6 @@
//! A TCP listener for accepting shadowsocks' client connection
use std::{io, net::SocketAddr};
use std::{io, net::SocketAddr, time::Duration};
use once_cell::sync::Lazy;
use tokio::{
@@ -21,6 +21,7 @@ pub struct ProxyListener {
listener: TcpListener,
method: CipherKind,
key: Box<[u8]>,
connection_timeout: Duration,
context: SharedContext,
}
@@ -56,6 +57,7 @@ impl ProxyListener {
listener,
method: svr_cfg.method(),
key: svr_cfg.key().to_vec().into_boxed_slice(),
connection_timeout: svr_cfg.connection_timeout(),
context,
}
}
@@ -76,7 +78,13 @@ impl ProxyListener {
let stream = map_fn(stream);
// Create a ProxyServerStream and read the target address from it
let stream = ProxyServerStream::from_stream(self.context.clone(), stream, self.method, &self.key);
let stream = ProxyServerStream::from_stream(
self.context.clone(),
stream,
self.method,
&self.key,
self.connection_timeout,
);
Ok((stream, peer_addr))
}

View File

@@ -26,6 +26,8 @@ use crate::{
},
};
use super::timeout::TimedStream;
enum ProxyClientStreamWriteState {
Connect(Address),
Connecting(BytesMut),
@@ -36,7 +38,7 @@ enum ProxyClientStreamWriteState {
#[pin_project]
pub struct ProxyClientStream<S> {
#[pin]
stream: CryptoStream<S>,
stream: CryptoStream<TimedStream<S>>,
state: ProxyClientStreamWriteState,
context: SharedContext,
}
@@ -139,7 +141,13 @@ where
A: Into<Address>,
{
let addr = addr.into();
let stream = CryptoStream::from_stream(&context, stream, svr_cfg.method(), svr_cfg.key());
let stream = CryptoStream::from_stream(
&context,
// NOTE: All streams get a default timeout even if `svr_cfg.timeout()` is None
TimedStream::new(stream, Some(svr_cfg.connection_timeout())),
svr_cfg.method(),
svr_cfg.key(),
);
ProxyClientStream {
stream,
@@ -150,17 +158,17 @@ where
/// Get reference to the underlying stream
pub fn get_ref(&self) -> &S {
self.stream.get_ref()
self.stream.get_ref().get_ref()
}
/// Get mutable reference to the underlying stream
pub fn get_mut(&mut self) -> &mut S {
self.stream.get_mut()
self.stream.get_mut().get_mut()
}
/// Consumes the `ProxyClientStream` and return the underlying stream
pub fn into_inner(self) -> S {
self.stream.into_inner()
self.stream.into_inner().into_inner()
}
}
@@ -266,7 +274,7 @@ where
#[pin_project]
pub struct ProxyClientStreamReadHalf<S> {
#[pin]
reader: CryptoStreamReadHalf<S>,
reader: CryptoStreamReadHalf<TimedStream<S>>,
context: SharedContext,
}
@@ -285,7 +293,7 @@ where
#[pin_project]
pub struct ProxyClientStreamWriteHalf<S> {
#[pin]
writer: CryptoStreamWriteHalf<S>,
writer: CryptoStreamWriteHalf<TimedStream<S>>,
state: ProxyClientStreamWriteState,
}

View File

@@ -7,3 +7,4 @@ pub use self::{
pub mod client;
pub mod server;
mod timeout;

View File

@@ -4,6 +4,7 @@ use std::{
io,
pin::Pin,
task::{self, Poll},
time::Duration,
};
use pin_project::pin_project;
@@ -15,11 +16,13 @@ use crate::{
relay::tcprelay::crypto_io::{CryptoStream, CryptoStreamReadHalf, CryptoStreamWriteHalf},
};
use super::timeout::TimedStream;
/// A stream for communicating with shadowsocks' proxy client
#[pin_project]
pub struct ProxyServerStream<S> {
#[pin]
stream: CryptoStream<S>,
stream: CryptoStream<TimedStream<S>>,
context: SharedContext,
}
@@ -29,26 +32,33 @@ impl<S> ProxyServerStream<S> {
stream: S,
method: CipherKind,
key: &[u8],
connection_timeout: Duration,
) -> ProxyServerStream<S> {
ProxyServerStream {
stream: CryptoStream::from_stream(&context, stream, method, key),
stream: CryptoStream::from_stream(
&context,
// NOTE: All streams get a default timeout even if `svr_cfg.timeout()` is None
TimedStream::new(stream, Some(connection_timeout)),
method,
key,
),
context,
}
}
/// Get reference of the internal stream
pub fn get_ref(&self) -> &S {
self.stream.get_ref()
self.stream.get_ref().get_ref()
}
/// Get mutable reference of the internal stream
pub fn get_mut(&mut self) -> &mut S {
self.stream.get_mut()
self.stream.get_mut().get_mut()
}
/// Consumes the object and return the internal stream
pub fn into_inner(self) -> S {
self.stream.into_inner()
self.stream.into_inner().into_inner()
}
}
@@ -105,7 +115,7 @@ where
#[pin_project]
pub struct ProxyServerStreamReadHalf<S> {
#[pin]
reader: CryptoStreamReadHalf<S>,
reader: CryptoStreamReadHalf<TimedStream<S>>,
context: SharedContext,
}
@@ -124,7 +134,7 @@ where
#[pin_project]
pub struct ProxyServerStreamWriteHalf<S> {
#[pin]
writer: CryptoStreamWriteHalf<S>,
writer: CryptoStreamWriteHalf<TimedStream<S>>,
}
impl<S> AsyncWrite for ProxyServerStreamWriteHalf<S>

View File

@@ -0,0 +1,200 @@
//! Asynchronous stream with a unified timeout for both Read and Write
use std::{
future::Future,
io::{self, IoSlice},
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use pin_project::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
time::{self, Instant, Sleep},
};
#[derive(Debug)]
struct TimeoutState {
timeout: Option<Duration>,
cur: Pin<Box<Sleep>>,
active: bool,
}
impl TimeoutState {
#[inline]
fn new() -> TimeoutState {
TimeoutState {
timeout: None,
cur: Box::pin(time::sleep_until(Instant::now())),
active: false,
}
}
#[inline]
fn timeout(&self) -> Option<Duration> {
self.timeout
}
#[inline]
fn set_timeout(&mut self, timeout: Option<Duration>) {
// since this takes &mut self, we can't yet be active
self.timeout = timeout;
}
#[inline]
fn set_timeout_pinned(mut self: Pin<&mut Self>, timeout: Option<Duration>) {
self.timeout = timeout;
self.reset();
}
#[inline]
fn reset(mut self: Pin<&mut Self>) {
if self.active {
self.active = false;
self.cur.as_mut().reset(Instant::now());
}
}
#[inline]
fn poll_check(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Result<()> {
let timeout = match self.timeout {
Some(timeout) => timeout,
None => return Ok(()),
};
if !self.active {
self.cur.as_mut().reset(Instant::now() + timeout);
self.active = true;
}
match self.cur.as_mut().poll(cx) {
Poll::Ready(()) => Err(io::Error::from(io::ErrorKind::TimedOut)),
Poll::Pending => Ok(()),
}
}
}
/// A stream that times out if both Read and Write are pending
///
/// IMPLEMENTATION NOTE:
///
/// `TimedStream` internally shares a single `tokio::time::Sleep` state, which can
/// only remember one `Waker`. This means the timeout event can only notify one
/// task, either `poll_read` or `poll_write`.
///
/// If this behavior is not expected, use the `tokio-io-timeout` crate instead.
///
/// If this stream is used split into halves (ReadHalf and WriteHalf), you should
/// close both halves once `poll_read` or `poll_write` returns `ErrorKind::TimedOut`.
/// In other words, it should work like a bidirectional tunnel.
#[pin_project]
pub struct TimedStream<S> {
#[pin]
stream: S,
#[pin]
timeout_state: TimeoutState,
}
impl<S> TimedStream<S> {
/// Create a new `TimedStream` with optional timeout
pub fn new(stream: S, timeout: Option<Duration>) -> TimedStream<S> {
let mut timeout_state = TimeoutState::new();
if timeout.is_some() {
timeout_state.set_timeout(timeout);
}
TimedStream { stream, timeout_state }
}
/// Get timeout
#[inline]
#[allow(dead_code)]
pub fn timeout(&self) -> Option<Duration> {
self.timeout_state.timeout()
}
/// Set timeout exclusively
#[inline]
#[allow(dead_code)]
pub fn set_timeout(&mut self, timeout: Option<Duration>) {
self.timeout_state.set_timeout(timeout)
}
/// Set timeout exclusively with Pinned self
#[inline]
#[allow(dead_code)]
pub fn set_timeout_pinned(self: Pin<&mut Self>, timeout: Option<Duration>) {
self.project().timeout_state.set_timeout_pinned(timeout)
}
/// Get immutable reference of internal stream
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Get mutable reference of internal stream
pub fn get_mut(&mut self) -> &mut S {
&mut self.stream
}
/// Consumes the `TimedStream` and return the internal stream
pub fn into_inner(self) -> S {
self.stream
}
}
impl<S> AsyncRead for TimedStream<S>
where
S: AsyncRead + Unpin,
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
let this = self.project();
let r = this.stream.poll_read(cx, buf);
match r {
Poll::Ready(..) => this.timeout_state.reset(),
Poll::Pending => this.timeout_state.poll_check(cx)?,
}
r
}
}
impl<S> AsyncWrite for TimedStream<S>
where
S: AsyncWrite + Unpin,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
let this = self.project();
let r = this.stream.poll_write(cx, buf);
match r {
Poll::Ready(..) => this.timeout_state.reset(),
Poll::Pending => this.timeout_state.poll_check(cx)?,
}
r
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().stream.poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
let this = self.project();
let r = this.stream.poll_write_vectored(cx, bufs);
match r {
Poll::Ready(..) => this.timeout_state.reset(),
Poll::Pending => this.timeout_state.poll_check(cx)?,
}
r
}
}
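A usage sketch for the new type (assumes TimedStream is in scope; the address and 30-second timeout are arbitrary). Because poll_read/poll_write reset the timer whenever they make progress, this acts as an idle timeout rather than a hard deadline:

use std::time::Duration;
use tokio::{io::AsyncReadExt, net::TcpStream};

async fn read_with_idle_timeout(addr: &str) -> std::io::Result<Vec<u8>> {
    let stream = TcpStream::connect(addr).await?;
    // Fail only if the connection makes no progress for 30 seconds straight.
    let mut timed = TimedStream::new(stream, Some(Duration::from_secs(30)));

    let mut buf = Vec::new();
    match timed.read_to_end(&mut buf).await {
        Ok(_) => Ok(buf),
        // An idle period past the deadline surfaces as ErrorKind::TimedOut;
        // per the note above, a split reader/writer pair should then both be dropped.
        Err(err) => Err(err),
    }
}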

View File

@@ -1,4 +1,7 @@
//! Utilities for TCP relay
//!
//! The `CopyBuffer`, `Copy` and `CopyBidirectional` implementations are borrowed
//! from the [tokio](https://github.com/tokio-rs/tokio) project (MIT license).
use std::{
future::Future,
@@ -12,35 +15,43 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::crypto::v1::{CipherCategory, CipherKind};
/// A future that asynchronously copies the entire contents of a reader into a
/// writer.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Copy<'a, R: ?Sized, W: ?Sized> {
reader: &'a mut R,
struct CopyBuffer {
read_done: bool,
writer: &'a mut W,
pos: usize,
cap: usize,
amt: u64,
buf: Box<[u8]>,
}
impl<R, W> Future for Copy<'_, R, W>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<u64>;
impl CopyBuffer {
fn new(buffer_size: usize) -> Self {
Self {
read_done: false,
pos: 0,
cap: 0,
amt: 0,
buf: vec![0; buffer_size].into_boxed_slice(),
}
}
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
fn poll_copy<R, W>(
&mut self,
cx: &mut Context<'_>,
mut reader: Pin<&mut R>,
mut writer: Pin<&mut W>,
) -> Poll<io::Result<u64>>
where
R: AsyncRead + ?Sized,
W: AsyncWrite + ?Sized,
{
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
let me = &mut *self;
let mut buf = ReadBuf::new(&mut me.buf);
ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?;
ready!(reader.as_mut().poll_read(cx, &mut buf))?;
let n = buf.filled().len();
if n == 0 {
self.read_done = true;
@@ -53,7 +64,7 @@ where
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let me = &mut *self;
let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, &me.buf[me.pos..me.cap]))?;
let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
if i == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
@@ -68,14 +79,38 @@ where
// If we've written all the data and we've seen EOF, flush out the
// data and finish the transfer.
if self.pos == self.cap && self.read_done {
let me = &mut *self;
ready!(Pin::new(&mut *me.writer).poll_flush(cx))?;
ready!(writer.as_mut().poll_flush(cx))?;
return Poll::Ready(Ok(self.amt));
}
}
}
}
/// A future that asynchronously copies the entire contents of a reader into a
/// writer.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct Copy<'a, R: ?Sized, W: ?Sized> {
reader: &'a mut R,
writer: &'a mut W,
buf: CopyBuffer,
}
impl<R, W> Future for Copy<'_, R, W>
where
R: AsyncRead + Unpin + ?Sized,
W: AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<u64>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
let me = &mut *self;
me.buf
.poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
}
}
/// Copy data from encrypted reader to plain writer
pub async fn copy_from_encrypted<ER, PW>(method: CipherKind, reader: &mut ER, writer: &mut PW) -> io::Result<u64>
where
@@ -84,12 +119,8 @@ where
{
Copy {
reader,
read_done: false,
writer,
amt: 0,
pos: 0,
cap: 0,
buf: alloc_encrypted_read_buffer(method),
buf: CopyBuffer::new(encrypted_read_buffer_size(method)),
}
.await
}
@@ -102,32 +133,148 @@ where
{
Copy {
reader,
read_done: false,
writer,
amt: 0,
pos: 0,
cap: 0,
buf: alloc_plain_read_buffer(method),
buf: CopyBuffer::new(plain_read_buffer_size(method)),
}
.await
}
/// Create a buffer for reading from shadowsocks' encrypted channel
pub fn alloc_encrypted_read_buffer(method: CipherKind) -> Box<[u8]> {
fn encrypted_read_buffer_size(method: CipherKind) -> usize {
match method.category() {
CipherCategory::Aead => vec![0u8; super::aead::MAX_PACKET_SIZE + method.tag_len()].into_boxed_slice(),
CipherCategory::Aead => super::aead::MAX_PACKET_SIZE + method.tag_len(),
#[cfg(feature = "stream-cipher")]
CipherCategory::Stream => vec![0u8; 1 << 14].into_boxed_slice(),
CipherCategory::None => vec![0u8; 1 << 14].into_boxed_slice(),
CipherCategory::Stream => 1 << 14,
CipherCategory::None => 1 << 14,
}
}
/// Create a buffer for reading from plain channel (not encrypted), for copying data into encrypted channel
pub fn alloc_plain_read_buffer(method: CipherKind) -> Box<[u8]> {
fn plain_read_buffer_size(method: CipherKind) -> usize {
match method.category() {
CipherCategory::Aead => vec![0u8; super::aead::MAX_PACKET_SIZE].into_boxed_slice(),
CipherCategory::Aead => super::aead::MAX_PACKET_SIZE,
#[cfg(feature = "stream-cipher")]
CipherCategory::Stream => vec![0u8; 1 << 14].into_boxed_slice(),
CipherCategory::None => vec![0u8; 1 << 14].into_boxed_slice(),
CipherCategory::Stream => 1 << 14,
CipherCategory::None => 1 << 14,
}
}
/// Create a buffer for reading from shadowsocks' encrypted channel
#[inline]
pub fn alloc_encrypted_read_buffer(method: CipherKind) -> Box<[u8]> {
vec![0u8; encrypted_read_buffer_size(method)].into_boxed_slice()
}
/// Create a buffer for reading from plain channel (not encrypted), for copying data into encrypted channel
#[inline]
pub fn alloc_plain_read_buffer(method: CipherKind) -> Box<[u8]> {
vec![0u8; plain_read_buffer_size(method)].into_boxed_slice()
}
enum TransferState {
Running(CopyBuffer),
ShuttingDown(u64),
Done(u64),
}
struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
a: &'a mut A,
b: &'a mut B,
a_to_b: TransferState,
b_to_a: TransferState,
}
fn transfer_one_direction<A, B>(
cx: &mut Context<'_>,
state: &mut TransferState,
r: &mut A,
w: &mut B,
) -> Poll<io::Result<u64>>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
let mut r = Pin::new(r);
let mut w = Pin::new(w);
loop {
match state {
TransferState::Running(buf) => {
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
*state = TransferState::ShuttingDown(count);
}
TransferState::ShuttingDown(count) => {
ready!(w.as_mut().poll_shutdown(cx))?;
*state = TransferState::Done(*count);
}
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
}
}
}
impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
where
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
type Output = io::Result<(u64, u64)>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Unpack self into mut refs to each field to avoid borrow check issues.
let CopyBidirectional { a, b, a_to_b, b_to_a } = &mut *self;
let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?;
let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?;
// It is not a problem if ready! returns early because transfer_one_direction for the
// other direction will keep returning TransferState::Done(count) in future calls to poll
let a_to_b = ready!(a_to_b);
let b_to_a = ready!(b_to_a);
Poll::Ready(Ok((a_to_b, b_to_a)))
}
}
/// Copies data in both directions between `encrypted` stream and `plain` stream.
///
/// This function returns a future that will read from both streams,
/// writing any data read to the opposing stream.
/// This happens in both directions concurrently.
///
/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on
/// the other, and reading from that stream will stop. Copying of data in
/// the other direction will continue.
///
/// The future will complete successfully once both directions of communication have been shut down.
/// A direction is shut down when the reader reports EOF,
/// at which point [`shutdown()`] is called on the corresponding writer. When finished,
/// it will return a tuple of the number of bytes copied from encrypted to plain
/// and the number of bytes copied from plain to encrypted, in that order.
///
/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown
///
/// # Errors
///
/// The future will immediately return an error if any IO operation on `encrypted`
/// or `plain` returns an error. Some data read from either stream may be lost (not
/// written to the other stream) in this case.
///
/// # Return value
///
/// Returns a tuple of bytes copied `encrypted` to `plain` and bytes copied `plain` to `encrypted`.
pub async fn copy_encrypted_bidirectional<E, P>(
method: CipherKind,
encrypted: &mut E,
plain: &mut P,
) -> Result<(u64, u64), std::io::Error>
where
E: AsyncRead + AsyncWrite + Unpin + ?Sized,
P: AsyncRead + AsyncWrite + Unpin + ?Sized,
{
CopyBidirectional {
a: encrypted,
b: plain,
a_to_b: TransferState::Running(CopyBuffer::new(encrypted_read_buffer_size(method))),
b_to_a: TransferState::Running(CopyBuffer::new(plain_read_buffer_size(method))),
}
.await
}
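For reference, the behaviour being ported here matches tokio's own copy_bidirectional, which presumably also explains this commit's bump of the tokio dependency to 1.5, the release where that function appeared. A self-contained demo of the EOF/shutdown propagation and the ordering of the returned byte counts, using in-memory duplex pipes:

use tokio::io::{copy_bidirectional, duplex, AsyncWriteExt};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Two in-memory pipes standing in for the plain and encrypted sides.
    let (mut client, mut a) = duplex(1024);
    let (mut b, mut server) = duplex(1024);

    let relay = tokio::spawn(async move {
        // Resolves only after BOTH directions have hit EOF and been shut
        // down; returns (a_to_b_bytes, b_to_a_bytes).
        copy_bidirectional(&mut a, &mut b).await
    });

    client.write_all(b"ping").await?;
    client.shutdown().await?; // EOF on `a` -> relay shuts down `b`'s write side
    server.write_all(b"pong").await?;
    server.shutdown().await?; // EOF on `b` -> relay shuts down `a`'s write side

    let (a_to_b, b_to_a) = relay.await.unwrap()?;
    assert_eq!((a_to_b, b_to_a), (4, 4));
    Ok(())
}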