take arbitrary IO for udp proxy (#1641)

* udp relay

* refact

* reset

* Update crates/shadowsocks/src/relay/udprelay/proxy_socket.rs

* export trait
This commit is contained in:
Yuwei Ba
2024-09-21 02:02:55 +10:00
committed by GitHub
parent e691853f44
commit 4e295818e5
5 changed files with 154 additions and 41 deletions

View File

@@ -6,7 +6,6 @@ use shadowsocks::{
relay::{socks5::Address, udprelay::options::UdpSocketControlData},
ProxySocket,
};
use tokio::net::ToSocketAddrs;
use super::flow::FlowStat;
@@ -47,7 +46,7 @@ impl MonProxySocket {
/// Send a UDP packet to target from proxy
#[inline]
pub async fn send_to<A: ToSocketAddrs>(&self, target: A, addr: &Address, payload: &[u8]) -> io::Result<()> {
pub async fn send_to(&self, target: SocketAddr, addr: &Address, payload: &[u8]) -> io::Result<()> {
let n = self.socket.send_to(target, addr, payload).await?;
self.flow_stat.incr_tx(n as u64);
@@ -56,9 +55,9 @@ impl MonProxySocket {
/// Send a UDP packet to target from proxy
#[inline]
pub async fn send_to_with_ctrl<A: ToSocketAddrs>(
pub async fn send_to_with_ctrl(
&self,
target: A,
target: SocketAddr,
addr: &Address,
control: &UdpSocketControlData,
payload: &[u8],

View File

@@ -24,7 +24,7 @@ use std::{
))]
use futures::future;
use futures::ready;
use pin_project::pin_project;
#[cfg(any(
target_os = "linux",
target_os = "android",
@@ -86,9 +86,7 @@ fn make_mtu_error(packet_size: usize, mtu: usize) -> io::Error {
/// Wrappers for outbound `UdpSocket`
#[derive(Debug)]
#[pin_project]
pub struct UdpSocket {
#[pin]
socket: tokio::net::UdpSocket,
mtu: Option<usize>,
}

View File

@@ -0,0 +1,88 @@
use async_trait::async_trait;
use std::{
io::Result,
net::SocketAddr,
ops::Deref,
task::{Context, Poll},
};
use tokio::io::ReadBuf;
use crate::net::UdpSocket;
/// a trait for datagram transport that wraps around a tokio `UdpSocket`
#[async_trait]
pub trait DatagramTransport: Send + Sync + std::fmt::Debug {
async fn recv(&self, buf: &mut [u8]) -> Result<usize>;
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)>;
async fn send(&self, buf: &[u8]) -> Result<usize>;
async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize>;
fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>>;
fn poll_recv_from(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<SocketAddr>>;
fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>;
fn poll_send_to(&self, cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll<Result<usize>>;
fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>>;
fn local_addr(&self) -> Result<SocketAddr>;
#[cfg(unix)]
fn as_raw_fd(&self) -> std::os::fd::RawFd;
}
#[async_trait]
impl DatagramTransport for UdpSocket {
async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
UdpSocket::recv(self, buf).await
}
async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
UdpSocket::recv_from(self, buf).await
}
async fn send(&self, buf: &[u8]) -> Result<usize> {
UdpSocket::send(self, buf).await
}
async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize> {
UdpSocket::send_to(self, buf, target).await
}
fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<()>> {
UdpSocket::poll_recv(self, cx, buf)
}
fn poll_recv_from(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<Result<SocketAddr>> {
UdpSocket::poll_recv_from(self, cx, buf)
}
fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.deref().poll_recv_ready(cx)
}
fn poll_send(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
UdpSocket::poll_send(self, cx, buf)
}
fn poll_send_to(&self, cx: &mut Context<'_>, buf: &[u8], target: SocketAddr) -> Poll<Result<usize>> {
UdpSocket::poll_send_to(self, cx, buf, target)
}
fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.deref().poll_send_ready(cx)
}
fn local_addr(&self) -> Result<SocketAddr> {
self.deref().local_addr()
}
#[cfg(unix)]
fn as_raw_fd(&self) -> std::os::fd::RawFd {
use std::ops::Deref;
use std::os::fd::AsRawFd;
self.deref().as_raw_fd()
}
}

View File

@@ -50,10 +50,12 @@
use std::time::Duration;
pub use self::proxy_socket::ProxySocket;
pub use compat::DatagramTransport;
mod aead;
#[cfg(feature = "aead-cipher-2022")]
mod aead_2022;
mod compat;
pub mod crypto_io;
pub mod options;
pub mod proxy_socket;

View File

@@ -12,7 +12,7 @@ use byte_string::ByteStr;
use bytes::{Bytes, BytesMut};
use log::{info, trace, warn};
use once_cell::sync::Lazy;
use tokio::{io::ReadBuf, net::ToSocketAddrs, time};
use tokio::{io::ReadBuf, time};
use crate::{
config::{ServerAddr, ServerConfig, ServerUserManager},
@@ -22,9 +22,12 @@ use crate::{
relay::{socks5::Address, udprelay::options::UdpSocketControlData},
};
use super::crypto_io::{
decrypt_client_payload, decrypt_server_payload, encrypt_client_payload, encrypt_server_payload, ProtocolError,
ProtocolResult,
use super::{
compat::DatagramTransport,
crypto_io::{
decrypt_client_payload, decrypt_server_payload, encrypt_client_payload, encrypt_server_payload, ProtocolError,
ProtocolResult,
},
};
#[cfg(unix)]
@@ -72,7 +75,7 @@ pub type ProxySocketResult<T> = Result<T, ProxySocketError>;
#[derive(Debug)]
pub struct ProxySocket {
socket_type: UdpSocketType,
socket: ShadowUdpSocket,
io: Box<dyn DatagramTransport>,
method: CipherKind,
key: Box<[u8]>,
send_timeout: Option<Duration>,
@@ -128,11 +131,40 @@ impl ProxySocket {
let key = svr_cfg.key().to_vec().into_boxed_slice();
let method = svr_cfg.method();
// NOTE: svr_cfg.timeout() is not for this socket, but for associations.
ProxySocket {
socket_type,
io: Box::new(socket.into()),
method,
key,
send_timeout: None,
recv_timeout: None,
context,
identity_keys: match socket_type {
UdpSocketType::Client => svr_cfg.clone_identity_keys(),
UdpSocketType::Server => Arc::new(Vec::new()),
},
user_manager: match socket_type {
UdpSocketType::Client => None,
UdpSocketType::Server => svr_cfg.clone_user_manager(),
},
}
}
pub fn from_io(
socket_type: UdpSocketType,
context: SharedContext,
svr_cfg: &ServerConfig,
io: Box<dyn DatagramTransport>,
) -> ProxySocket {
let key = svr_cfg.key().to_vec().into_boxed_slice();
let method = svr_cfg.method();
// NOTE: svr_cfg.timeout() is not for this socket, but for associations.
ProxySocket {
socket_type,
socket: socket.into(),
io,
method,
key,
send_timeout: None,
@@ -241,8 +273,8 @@ impl ProxySocket {
);
let send_len = match self.send_timeout {
None => self.socket.send(&send_buf).await?,
Some(d) => match time::timeout(d, self.socket.send(&send_buf)).await {
None => self.io.send(&send_buf).await?,
Some(d) => match time::timeout(d, self.io.send(&send_buf)).await {
Ok(Ok(l)) => l,
Ok(Err(err)) => return Err(err.into()),
Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()),
@@ -295,7 +327,7 @@ impl ProxySocket {
let n_send_buf = send_buf.len();
match self.socket.poll_send(cx, &send_buf).map_err(|x| x.into()) {
match self.io.poll_send(cx, &send_buf).map_err(|x| x.into()) {
Poll::Ready(Ok(l)) => {
if l == n_send_buf {
Poll::Ready(Ok(payload.len()))
@@ -340,14 +372,14 @@ impl ProxySocket {
self.encrypt_send_buffer(addr, control, &self.identity_keys, payload, &mut send_buf)?;
info!(
"UDP server client send to {}, payload length {} bytes, packet length {} bytes",
"UDP server client poll_send_to to {}, payload length {} bytes, packet length {} bytes",
target,
payload.len(),
send_buf.len()
);
let n_send_buf = send_buf.len();
match self.socket.poll_send_to(cx, &send_buf, target).map_err(|x| x.into()) {
match self.io.poll_send_to(cx, &send_buf, target).map_err(|x| x.into()) {
Poll::Ready(Ok(l)) => {
if l == n_send_buf {
Poll::Ready(Ok(payload.len()))
@@ -363,25 +395,20 @@ impl ProxySocket {
///
/// Check if socket is ready to `send`, or writable.
pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<ProxySocketResult<()>> {
self.socket.poll_send_ready(cx).map_err(|x| x.into())
self.io.poll_send_ready(cx).map_err(|x| x.into())
}
/// Send a UDP packet to target through proxy `target`
pub async fn send_to<A: ToSocketAddrs>(
&self,
target: A,
addr: &Address,
payload: &[u8],
) -> ProxySocketResult<usize> {
pub async fn send_to(&self, target: SocketAddr, addr: &Address, payload: &[u8]) -> ProxySocketResult<usize> {
self.send_to_with_ctrl(target, addr, &DEFAULT_SOCKET_CONTROL, payload)
.await
.map_err(Into::into)
}
/// Send a UDP packet to target through proxy `target`
pub async fn send_to_with_ctrl<A: ToSocketAddrs>(
pub async fn send_to_with_ctrl(
&self,
target: A,
target: SocketAddr,
addr: &Address,
control: &UdpSocketControlData,
payload: &[u8],
@@ -390,7 +417,7 @@ impl ProxySocket {
self.encrypt_send_buffer(addr, control, &self.identity_keys, payload, &mut send_buf)?;
trace!(
"UDP server client send to, addr {}, control: {:?}, payload length {} bytes, packet length {} bytes",
"UDP server client send_to to, addr {}, control: {:?}, payload length {} bytes, packet length {} bytes",
addr,
control,
payload.len(),
@@ -398,8 +425,8 @@ impl ProxySocket {
);
let send_len = match self.send_timeout {
None => self.socket.send_to(&send_buf, target).await?,
Some(d) => match time::timeout(d, self.socket.send_to(&send_buf, target)).await {
None => self.io.send_to(&send_buf, target).await?,
Some(d) => match time::timeout(d, self.io.send_to(&send_buf, target)).await {
Ok(Ok(l)) => l,
Ok(Err(err)) => return Err(err.into()),
Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()),
@@ -408,7 +435,7 @@ impl ProxySocket {
if send_buf.len() != send_len {
warn!(
"UDP server client send {} bytes, but actually sent {} bytes",
"UDP server client send_to {} bytes, but actually sent {} bytes",
send_buf.len(),
send_len
);
@@ -448,10 +475,9 @@ impl ProxySocket {
&self,
recv_buf: &mut [u8],
) -> ProxySocketResult<(usize, Address, usize, Option<UdpSocketControlData>)> {
// Waiting for response from server SERVER -> CLIENT
let recv_n = match self.recv_timeout {
None => self.socket.recv(recv_buf).await?,
Some(d) => match time::timeout(d, self.socket.recv(recv_buf)).await {
None => self.io.recv(recv_buf).await?,
Some(d) => match time::timeout(d, self.io.recv(recv_buf)).await {
Ok(Ok(l)) => l,
Ok(Err(err)) => return Err(err.into()),
Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()),
@@ -498,8 +524,8 @@ impl ProxySocket {
) -> ProxySocketResult<(usize, SocketAddr, Address, usize, Option<UdpSocketControlData>)> {
// Waiting for response from server SERVER -> CLIENT
let (recv_n, target_addr) = match self.recv_timeout {
None => self.socket.recv_from(recv_buf).await?,
Some(d) => match time::timeout(d, self.socket.recv_from(recv_buf)).await {
None => self.io.recv_from(recv_buf).await?,
Some(d) => match time::timeout(d, self.io.recv_from(recv_buf)).await {
Ok(Ok(l)) => l,
Ok(Err(err)) => return Err(err.into()),
Err(..) => return Err(io::Error::from(ErrorKind::TimedOut).into()),
@@ -542,7 +568,7 @@ impl ProxySocket {
cx: &mut Context<'_>,
recv_buf: &mut ReadBuf,
) -> Poll<ProxySocketResult<(usize, Address, usize, Option<UdpSocketControlData>)>> {
ready!(self.socket.poll_recv(cx, recv_buf))?;
ready!(self.io.poll_recv(cx, recv_buf))?;
let n_recv = recv_buf.filled().len();
@@ -570,7 +596,7 @@ impl ProxySocket {
cx: &mut Context<'_>,
recv_buf: &mut ReadBuf,
) -> Poll<ProxySocketResult<(usize, SocketAddr, Address, usize, Option<UdpSocketControlData>)>> {
let src = ready!(self.socket.poll_recv_from(cx, recv_buf))?;
let src = ready!(self.io.poll_recv_from(cx, recv_buf))?;
let n_recv = recv_buf.filled().len();
match self.decrypt_recv_buffer(recv_buf.filled_mut(), self.user_manager.as_deref()) {
@@ -581,12 +607,12 @@ impl ProxySocket {
/// poll family functions
pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<ProxySocketResult<()>> {
self.socket.poll_recv_ready(cx).map_err(|x| x.into())
self.io.poll_recv_ready(cx).map_err(|x| x.into())
}
/// Get local addr of socket
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.socket.local_addr()
self.io.local_addr()
}
/// Set `send` timeout, `None` will clear timeout
@@ -604,6 +630,6 @@ impl ProxySocket {
impl AsRawFd for ProxySocket {
/// Retrieve raw fd of the outbound socket
fn as_raw_fd(&self) -> RawFd {
self.socket.as_raw_fd()
self.io.as_raw_fd()
}
}