Supports UDP tunnel

This commit is contained in:
zonyitoo
2019-12-01 20:30:55 +08:00
parent 6b4be3211f
commit 3e2cddb374
5 changed files with 626 additions and 285 deletions

View File

@@ -1,291 +1,14 @@
//! UDP relay local server
//! UDP local relay server
use std::io::{self, Cursor, ErrorKind, Read};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use std::io;
use bytes::BytesMut;
use log::{debug, error, info, trace};
use lru_time_cache::{Entry, LruCache};
use tokio;
use tokio::net::udp::{RecvHalf, SendHalf};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use crate::config::{ServerAddr, ServerConfig};
use super::{socks5_local, tunnel_local};
use crate::context::SharedContext;
use crate::relay::loadbalancing::server::{LoadBalancer, RoundRobin};
use crate::relay::socks5::{Address, UdpAssociateHeader};
use crate::relay::utils::try_timeout;
use super::crypto_io::{decrypt_payload, encrypt_payload};
use super::MAXIMUM_UDP_PAYLOAD_SIZE;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
async fn parse_packet(pkt: &[u8]) -> io::Result<(Address, Vec<u8>)> {
// PKT = UdpAssociateHeader + PAYLOAD
let mut cur = Cursor::new(pkt);
let header = UdpAssociateHeader::read_from(&mut cur).await?;
if header.frag != 0 {
error!("Received UDP associate with frag != 0, which is not supported by ShadowSocks");
let err = io::Error::new(ErrorKind::Other, "unsupported UDP fragmentation");
return Err(err);
}
let addr = header.address;
// The remaining is PAYLOAD
let mut payload = Vec::new();
cur.read_to_end(&mut payload)?;
Ok((addr, payload))
}
// Represent a UDP association
struct UdpAssociation {
tx: mpsc::Sender<Vec<u8>>,
}
impl UdpAssociation {
/// Create an association with addr
async fn associate(
svr_cfg: Arc<ServerConfig>,
src_addr: SocketAddr,
mut response_tx: mpsc::Sender<(SocketAddr, Vec<u8>)>,
) -> io::Result<UdpAssociation> {
debug!("Created UDP Association for {}", src_addr);
// Create a socket for receiving packets
let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
let remote_udp = UdpSocket::bind(&local_addr).await?;
// Create a channel for sending packets to remote
// FIXME: Channel size 1024?
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(1024);
// Splits socket into sender and receiver
let (mut receiver, mut sender) = remote_udp.split();
let timeout = svr_cfg.udp_timeout().unwrap_or(DEFAULT_TIMEOUT);
// local -> remote
let c_svr_cfg = svr_cfg.clone();
tokio::spawn(async move {
while let Some(pkt) = rx.recv().await {
// pkt is already a raw packet, so just send it
if let Err(err) = UdpAssociation::relay_l2r(src_addr, &mut sender, &pkt[..], timeout, &*c_svr_cfg).await
{
error!("Failed to send packet {} -> ..., error: {}", src_addr, err);
// FIXME: Ignore? Or how to deal with it?
}
}
debug!("UDP ASSOCIATE {} -> .. finished", src_addr);
});
// local <- remote
tokio::spawn(async move {
loop {
// Read and send back to source
if let Err(err) =
UdpAssociation::relay_r2l(src_addr, &mut receiver, timeout, &mut response_tx, &*svr_cfg).await
{
error!("Failed to receive packet, {} <- .., error: {}", src_addr, err);
break;
}
}
debug!("UDP ASSOCIATE {} <- .. finished", src_addr);
});
Ok(UdpAssociation { tx })
}
/// Relay packets from local to remote
async fn relay_l2r(
src: SocketAddr,
remote_udp: &mut SendHalf,
pkt: &[u8],
timeout: Duration,
svr_cfg: &ServerConfig,
) -> io::Result<()> {
let (addr, payload) = parse_packet(&pkt).await?;
debug!(
"UDP ASSOCIATE {} -> {}, payload length {} bytes",
src,
addr,
payload.len()
);
// CLIENT -> SERVER protocol: ADDRESS + PAYLOAD
let mut send_buf = Vec::new();
addr.write_to_buf(&mut send_buf);
send_buf.extend_from_slice(&payload);
let mut encrypt_buf = BytesMut::new();
encrypt_payload(svr_cfg.method(), svr_cfg.key(), &send_buf, &mut encrypt_buf)?;
let send_len = match svr_cfg.addr() {
ServerAddr::SocketAddr(ref remote_addr) => {
try_timeout(remote_udp.send_to(&encrypt_buf[..], remote_addr), Some(timeout)).await?
}
#[cfg(feature = "trust-dns")]
ServerAddr::DomainName(ref dname, port) => {
use crate::relay::dns_resolver::resolve;
let vec_ipaddr = resolve(context, dname, port, false).await?;
assert!(!vec_ipaddr.is_empty());
try_timeout(remote_udp.send_to(&encrypt_buf[..], &vec_ipaddr[0]), Some(timeout)).await?
}
#[cfg(not(feature = "trust-dns"))]
ServerAddr::DomainName(ref dname, port) => {
// try_timeout(remote_udp.send_to(&encrypt_buf[..], (dname.as_str(), port)), Some(timeout)).await?
unimplemented!(
"tokio's UdpSocket SendHalf doesn't support ToSocketAddrs, {}:{}",
dname,
port
);
}
};
assert_eq!(encrypt_buf.len(), send_len);
Ok(())
}
/// Relay packets from remote to local
async fn relay_r2l(
src_addr: SocketAddr,
remote_udp: &mut RecvHalf,
timeout: Duration,
response_tx: &mut mpsc::Sender<(SocketAddr, Vec<u8>)>,
svr_cfg: &ServerConfig,
) -> io::Result<()> {
// Waiting for response from server SERVER -> CLIENT
// Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded.
let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE];
let (recv_n, remote_addr) = try_timeout(remote_udp.recv_from(&mut recv_buf), Some(timeout)).await?;
let decrypt_buf = match decrypt_payload(svr_cfg.method(), svr_cfg.key(), &recv_buf[..recv_n])? {
None => {
error!("UDP packet too short, received length {}", recv_n);
let err = io::Error::new(io::ErrorKind::InvalidData, "packet too short");
return Err(err);
}
Some(b) => b,
};
// SERVER -> CLIENT protocol: ADDRESS + PAYLOAD
let mut cur = Cursor::new(decrypt_buf);
// FIXME: Address is ignored. Maybe useful in the future if we uses one common UdpSocket for communicate with remote server
let _ = Address::read_from(&mut cur).await?;
let mut payload = Vec::new();
let header = UdpAssociateHeader::new(0, Address::SocketAddress(src_addr));
header.write_to_buf(&mut payload);
cur.read_to_end(&mut payload)?;
debug!(
"UDP ASSOCIATE {} <- {}, payload length {} bytes",
src_addr,
remote_addr,
payload.len()
);
// Send back to src_addr
if let Err(err) = response_tx.send((src_addr, payload)).await {
error!("Failed to send packet into response channel, error: {}", err);
// FIXME: What to do? Ignore?
}
Ok(())
}
async fn send(&mut self, pkt: &[u8]) -> bool {
match self.tx.send(pkt.to_vec()).await {
Ok(..) => true,
Err(err) => {
error!("Failed to send packet, error: {}", err);
false
}
}
}
}
async fn listen(context: SharedContext, l: UdpSocket) -> io::Result<()> {
let mut balancer = RoundRobin::new(context.config());
let (mut r, mut w) = l.split();
let mut pkt_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE];
// FIXME: Channel size 1024?
let (tx, mut rx) = mpsc::channel::<(SocketAddr, Vec<u8>)>(1024);
tokio::spawn(async move {
while let Some((src, pkt)) = rx.recv().await {
if let Err(err) = w.send_to(&pkt, &src).await {
error!("UDP packet send failed, err: {:?}", err);
break;
}
}
// FIXME: How to stop the outer listener Future?
});
// let timeout = svr_cfg.udp_timeout().unwrap_or(DEFAULT_TIMEOUT);
let mut assoc_map = LruCache::with_expiry_duration(DEFAULT_TIMEOUT);
loop {
let (recv_len, src) = r.recv_from(&mut pkt_buf).await?;
// Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded.
// Copy bytes, because udp_associate runs in another tokio Task
let pkt = &pkt_buf[..recv_len];
trace!("Received UDP packet from {}, length {} bytes", src, recv_len);
// Pick a server
let svr_cfg = balancer.pick_server();
// Check or (re)create an association
loop {
let retry = {
let assoc = match assoc_map.entry(src.to_string()) {
Entry::Occupied(oc) => oc.into_mut(),
Entry::Vacant(vc) => vc.insert(
UdpAssociation::associate(svr_cfg.clone(), src, tx.clone())
.await
.expect("Failed to create udp association"),
),
};
!assoc.send(pkt).await
};
if retry {
assoc_map.remove(&src.to_string());
} else {
break;
}
}
}
}
/// Starts a UDP local server
pub async fn run(context: SharedContext) -> io::Result<()> {
let local_addr = *context.config().local.as_ref().unwrap();
let listener = UdpSocket::bind(&local_addr).await?;
info!("ShadowSocks UDP listening on {}", local_addr);
listen(context, listener).await
match context.config().forward {
Some(..) => tunnel_local::run(context).await,
None => socks5_local::run(context).await,
}
}

View File

@@ -39,6 +39,8 @@
pub mod local;
pub mod server;
pub(crate) mod socks5_local;
pub(crate) mod tunnel_local;
mod crypto_io;

View File

@@ -0,0 +1,291 @@
//! UDP relay local server
use std::io::{self, Cursor, ErrorKind, Read};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use log::{debug, error, info, trace};
use lru_time_cache::{Entry, LruCache};
use tokio;
use tokio::net::udp::{RecvHalf, SendHalf};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use crate::config::{ServerAddr, ServerConfig};
use crate::context::SharedContext;
use crate::relay::loadbalancing::server::{LoadBalancer, RoundRobin};
use crate::relay::socks5::{Address, UdpAssociateHeader};
use crate::relay::utils::try_timeout;
use super::crypto_io::{decrypt_payload, encrypt_payload};
use super::MAXIMUM_UDP_PAYLOAD_SIZE;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
async fn parse_packet(pkt: &[u8]) -> io::Result<(Address, Vec<u8>)> {
// PKT = UdpAssociateHeader + PAYLOAD
let mut cur = Cursor::new(pkt);
let header = UdpAssociateHeader::read_from(&mut cur).await?;
if header.frag != 0 {
error!("Received UDP associate with frag != 0, which is not supported by ShadowSocks");
let err = io::Error::new(ErrorKind::Other, "unsupported UDP fragmentation");
return Err(err);
}
let addr = header.address;
// The remaining is PAYLOAD
let mut payload = Vec::new();
cur.read_to_end(&mut payload)?;
Ok((addr, payload))
}
// Represent a UDP association
struct UdpAssociation {
tx: mpsc::Sender<Vec<u8>>,
}
impl UdpAssociation {
/// Create an association with addr
async fn associate(
svr_cfg: Arc<ServerConfig>,
src_addr: SocketAddr,
mut response_tx: mpsc::Sender<(SocketAddr, Vec<u8>)>,
) -> io::Result<UdpAssociation> {
debug!("Created UDP Association for {}", src_addr);
// Create a socket for receiving packets
let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
let remote_udp = UdpSocket::bind(&local_addr).await?;
// Create a channel for sending packets to remote
// FIXME: Channel size 1024?
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(1024);
// Splits socket into sender and receiver
let (mut receiver, mut sender) = remote_udp.split();
let timeout = svr_cfg.udp_timeout().unwrap_or(DEFAULT_TIMEOUT);
// local -> remote
let c_svr_cfg = svr_cfg.clone();
tokio::spawn(async move {
while let Some(pkt) = rx.recv().await {
// pkt is already a raw packet, so just send it
if let Err(err) = UdpAssociation::relay_l2r(src_addr, &mut sender, &pkt[..], timeout, &*c_svr_cfg).await
{
error!("Failed to send packet {} -> ..., error: {}", src_addr, err);
// FIXME: Ignore? Or how to deal with it?
}
}
debug!("UDP ASSOCIATE {} -> .. finished", src_addr);
});
// local <- remote
tokio::spawn(async move {
loop {
// Read and send back to source
if let Err(err) =
UdpAssociation::relay_r2l(src_addr, &mut receiver, timeout, &mut response_tx, &*svr_cfg).await
{
error!("Failed to receive packet, {} <- .., error: {}", src_addr, err);
break;
}
}
debug!("UDP ASSOCIATE {} <- .. finished", src_addr);
});
Ok(UdpAssociation { tx })
}
/// Relay packets from local to remote
async fn relay_l2r(
src: SocketAddr,
remote_udp: &mut SendHalf,
pkt: &[u8],
timeout: Duration,
svr_cfg: &ServerConfig,
) -> io::Result<()> {
let (addr, payload) = parse_packet(&pkt).await?;
debug!(
"UDP ASSOCIATE {} -> {}, payload length {} bytes",
src,
addr,
payload.len()
);
// CLIENT -> SERVER protocol: ADDRESS + PAYLOAD
let mut send_buf = Vec::new();
addr.write_to_buf(&mut send_buf);
send_buf.extend_from_slice(&payload);
let mut encrypt_buf = BytesMut::new();
encrypt_payload(svr_cfg.method(), svr_cfg.key(), &send_buf, &mut encrypt_buf)?;
let send_len = match svr_cfg.addr() {
ServerAddr::SocketAddr(ref remote_addr) => {
try_timeout(remote_udp.send_to(&encrypt_buf[..], remote_addr), Some(timeout)).await?
}
#[cfg(feature = "trust-dns")]
ServerAddr::DomainName(ref dname, port) => {
use crate::relay::dns_resolver::resolve;
let vec_ipaddr = resolve(context, dname, port, false).await?;
assert!(!vec_ipaddr.is_empty());
try_timeout(remote_udp.send_to(&encrypt_buf[..], &vec_ipaddr[0]), Some(timeout)).await?
}
#[cfg(not(feature = "trust-dns"))]
ServerAddr::DomainName(ref dname, port) => {
// try_timeout(remote_udp.send_to(&encrypt_buf[..], (dname.as_str(), port)), Some(timeout)).await?
unimplemented!(
"tokio's UdpSocket SendHalf doesn't support ToSocketAddrs, {}:{}",
dname,
port
);
}
};
assert_eq!(encrypt_buf.len(), send_len);
Ok(())
}
/// Relay packets from remote to local
async fn relay_r2l(
src_addr: SocketAddr,
remote_udp: &mut RecvHalf,
timeout: Duration,
response_tx: &mut mpsc::Sender<(SocketAddr, Vec<u8>)>,
svr_cfg: &ServerConfig,
) -> io::Result<()> {
// Waiting for response from server SERVER -> CLIENT
// Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded.
let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE];
let (recv_n, remote_addr) = try_timeout(remote_udp.recv_from(&mut recv_buf), Some(timeout)).await?;
let decrypt_buf = match decrypt_payload(svr_cfg.method(), svr_cfg.key(), &recv_buf[..recv_n])? {
None => {
error!("UDP packet too short, received length {}", recv_n);
let err = io::Error::new(io::ErrorKind::InvalidData, "packet too short");
return Err(err);
}
Some(b) => b,
};
// SERVER -> CLIENT protocol: ADDRESS + PAYLOAD
let mut cur = Cursor::new(decrypt_buf);
// FIXME: Address is ignored. Maybe useful in the future if we uses one common UdpSocket for communicate with remote server
let _ = Address::read_from(&mut cur).await?;
let mut payload = Vec::new();
let header = UdpAssociateHeader::new(0, Address::SocketAddress(src_addr));
header.write_to_buf(&mut payload);
cur.read_to_end(&mut payload)?;
debug!(
"UDP ASSOCIATE {} <- {}, payload length {} bytes",
src_addr,
remote_addr,
payload.len()
);
// Send back to src_addr
if let Err(err) = response_tx.send((src_addr, payload)).await {
error!("Failed to send packet into response channel, error: {}", err);
// FIXME: What to do? Ignore?
}
Ok(())
}
async fn send(&mut self, pkt: &[u8]) -> bool {
match self.tx.send(pkt.to_vec()).await {
Ok(..) => true,
Err(err) => {
error!("Failed to send packet, error: {}", err);
false
}
}
}
}
async fn listen(context: SharedContext, l: UdpSocket) -> io::Result<()> {
let mut balancer = RoundRobin::new(context.config());
let (mut r, mut w) = l.split();
let mut pkt_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE];
// FIXME: Channel size 1024?
let (tx, mut rx) = mpsc::channel::<(SocketAddr, Vec<u8>)>(1024);
tokio::spawn(async move {
while let Some((src, pkt)) = rx.recv().await {
if let Err(err) = w.send_to(&pkt, &src).await {
error!("UDP packet send failed, err: {:?}", err);
break;
}
}
// FIXME: How to stop the outer listener Future?
});
// let timeout = svr_cfg.udp_timeout().unwrap_or(DEFAULT_TIMEOUT);
let mut assoc_map = LruCache::with_expiry_duration(DEFAULT_TIMEOUT);
loop {
let (recv_len, src) = r.recv_from(&mut pkt_buf).await?;
// Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded.
// Copy bytes, because udp_associate runs in another tokio Task
let pkt = &pkt_buf[..recv_len];
trace!("Received UDP packet from {}, length {} bytes", src, recv_len);
// Pick a server
let svr_cfg = balancer.pick_server();
// Check or (re)create an association
loop {
let retry = {
let assoc = match assoc_map.entry(src.to_string()) {
Entry::Occupied(oc) => oc.into_mut(),
Entry::Vacant(vc) => vc.insert(
UdpAssociation::associate(svr_cfg.clone(), src, tx.clone())
.await
.expect("Failed to create udp association"),
),
};
!assoc.send(pkt).await
};
if retry {
assoc_map.remove(&src.to_string());
} else {
break;
}
}
}
}
/// Starts a UDP local server
pub async fn run(context: SharedContext) -> io::Result<()> {
let local_addr = *context.config().local.as_ref().unwrap();
let listener = UdpSocket::bind(&local_addr).await?;
info!("ShadowSocks UDP listening on {}", local_addr);
listen(context, listener).await
}

View File

@@ -0,0 +1,269 @@
//! UDP relay local server
use std::io::{self, Cursor, Read};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use log::{debug, error, info, trace};
use lru_time_cache::{Entry, LruCache};
use tokio;
use tokio::net::udp::{RecvHalf, SendHalf};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use crate::config::{ServerAddr, ServerConfig};
use crate::context::{Context, SharedContext};
use crate::relay::loadbalancing::server::{LoadBalancer, RoundRobin};
use crate::relay::socks5::Address;
use crate::relay::utils::try_timeout;
use super::crypto_io::{decrypt_payload, encrypt_payload};
use super::MAXIMUM_UDP_PAYLOAD_SIZE;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
// Represent a UDP association
struct UdpAssociation {
tx: mpsc::Sender<Vec<u8>>,
}
impl UdpAssociation {
/// Create an association with addr
async fn associate(
context: SharedContext,
svr_cfg: Arc<ServerConfig>,
src_addr: SocketAddr,
mut response_tx: mpsc::Sender<(SocketAddr, Vec<u8>)>,
) -> io::Result<UdpAssociation> {
debug!("Created UDP Association for {}", src_addr);
// Create a socket for receiving packets
let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0);
let remote_udp = UdpSocket::bind(&local_addr).await?;
// Create a channel for sending packets to remote
// FIXME: Channel size 1024?
let (tx, mut rx) = mpsc::channel::<Vec<u8>>(1024);
// Splits socket into sender and receiver
let (mut receiver, mut sender) = remote_udp.split();
let timeout = svr_cfg.udp_timeout().unwrap_or(DEFAULT_TIMEOUT);
// local -> remote
let c_svr_cfg = svr_cfg.clone();
tokio::spawn(async move {
while let Some(pkt) = rx.recv().await {
// pkt is already a raw packet, so just send it
if let Err(err) =
UdpAssociation::relay_l2r(&*context, src_addr, &mut sender, &pkt[..], timeout, &*c_svr_cfg).await
{
error!("Failed to send packet {} -> ..., error: {}", src_addr, err);
// FIXME: Ignore? Or how to deal with it?
}
}
debug!("UDP ASSOCIATE {} -> .. finished", src_addr);
});
// local <- remote
tokio::spawn(async move {
loop {
// Read and send back to source
if let Err(err) =
UdpAssociation::relay_r2l(src_addr, &mut receiver, timeout, &mut response_tx, &*svr_cfg).await
{
error!("Failed to receive packet, {} <- .., error: {}", src_addr, err);
break;
}
}
debug!("UDP ASSOCIATE {} <- .. finished", src_addr);
});
Ok(UdpAssociation { tx })
}
/// Relay packets from local to remote
async fn relay_l2r(
context: &Context,
src: SocketAddr,
remote_udp: &mut SendHalf,
payload: &[u8],
timeout: Duration,
svr_cfg: &ServerConfig,
) -> io::Result<()> {
let addr = context.config().forward.as_ref().unwrap();
debug!(
"UDP ASSOCIATE {} -> {}, payload length {} bytes",
src,
addr,
payload.len()
);
// CLIENT -> SERVER protocol: ADDRESS + PAYLOAD
let mut send_buf = Vec::new();
addr.write_to_buf(&mut send_buf);
send_buf.extend_from_slice(payload);
let mut encrypt_buf = BytesMut::new();
encrypt_payload(svr_cfg.method(), svr_cfg.key(), &send_buf, &mut encrypt_buf)?;
let send_len = match svr_cfg.addr() {
ServerAddr::SocketAddr(ref remote_addr) => {
try_timeout(remote_udp.send_to(&encrypt_buf[..], remote_addr), Some(timeout)).await?
}
#[cfg(feature = "trust-dns")]
ServerAddr::DomainName(ref dname, port) => {
use crate::relay::dns_resolver::resolve;
let vec_ipaddr = resolve(context, dname, port, false).await?;
assert!(!vec_ipaddr.is_empty());
try_timeout(remote_udp.send_to(&encrypt_buf[..], &vec_ipaddr[0]), Some(timeout)).await?
}
#[cfg(not(feature = "trust-dns"))]
ServerAddr::DomainName(ref dname, port) => {
// try_timeout(remote_udp.send_to(&encrypt_buf[..], (dname.as_str(), port)), Some(timeout)).await?
unimplemented!(
"tokio's UdpSocket SendHalf doesn't support ToSocketAddrs, {}:{}",
dname,
port
);
}
};
assert_eq!(encrypt_buf.len(), send_len);
Ok(())
}
/// Relay packets from remote to local
async fn relay_r2l(
src_addr: SocketAddr,
remote_udp: &mut RecvHalf,
timeout: Duration,
response_tx: &mut mpsc::Sender<(SocketAddr, Vec<u8>)>,
svr_cfg: &ServerConfig,
) -> io::Result<()> {
// Waiting for response from server SERVER -> CLIENT
// Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded.
let mut recv_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE];
let (recv_n, remote_addr) = try_timeout(remote_udp.recv_from(&mut recv_buf), Some(timeout)).await?;
let decrypt_buf = match decrypt_payload(svr_cfg.method(), svr_cfg.key(), &recv_buf[..recv_n])? {
None => {
error!("UDP packet too short, received length {}", recv_n);
let err = io::Error::new(io::ErrorKind::InvalidData, "packet too short");
return Err(err);
}
Some(b) => b,
};
// SERVER -> CLIENT protocol: ADDRESS + PAYLOAD
let mut cur = Cursor::new(decrypt_buf);
// FIXME: Address is ignored. Maybe useful in the future if we uses one common UdpSocket for communicate with remote server
let _ = Address::read_from(&mut cur).await?;
let mut payload = Vec::new();
cur.read_to_end(&mut payload)?;
debug!(
"UDP ASSOCIATE {} <- {}, payload length {} bytes",
src_addr,
remote_addr,
payload.len()
);
// Send back to src_addr
if let Err(err) = response_tx.send((src_addr, payload)).await {
error!("Failed to send packet into response channel, error: {}", err);
// FIXME: What to do? Ignore?
}
Ok(())
}
async fn send(&mut self, pkt: &[u8]) -> bool {
match self.tx.send(pkt.to_vec()).await {
Ok(..) => true,
Err(err) => {
error!("Failed to send packet, error: {}", err);
false
}
}
}
}
async fn listen(context: SharedContext, l: UdpSocket) -> io::Result<()> {
let mut balancer = RoundRobin::new(context.config());
let (mut r, mut w) = l.split();
let mut pkt_buf = [0u8; MAXIMUM_UDP_PAYLOAD_SIZE];
// FIXME: Channel size 1024?
let (tx, mut rx) = mpsc::channel::<(SocketAddr, Vec<u8>)>(1024);
tokio::spawn(async move {
while let Some((src, pkt)) = rx.recv().await {
if let Err(err) = w.send_to(&pkt, &src).await {
error!("UDP packet send failed, err: {:?}", err);
break;
}
}
// FIXME: How to stop the outer listener Future?
});
// let timeout = svr_cfg.udp_timeout().unwrap_or(DEFAULT_TIMEOUT);
let mut assoc_map = LruCache::with_expiry_duration(DEFAULT_TIMEOUT);
loop {
let (recv_len, src) = r.recv_from(&mut pkt_buf).await?;
// Packet length is limited by MAXIMUM_UDP_PAYLOAD_SIZE, excess bytes will be discarded.
// Copy bytes, because udp_associate runs in another tokio Task
let pkt = &pkt_buf[..recv_len];
trace!("Received UDP packet from {}, length {} bytes", src, recv_len);
// Pick a server
let svr_cfg = balancer.pick_server();
// Check or (re)create an association
loop {
let retry = {
let assoc = match assoc_map.entry(src.to_string()) {
Entry::Occupied(oc) => oc.into_mut(),
Entry::Vacant(vc) => vc.insert(
UdpAssociation::associate(context.clone(), svr_cfg.clone(), src, tx.clone())
.await
.expect("Failed to create udp association"),
),
};
!assoc.send(pkt).await
};
if retry {
assoc_map.remove(&src.to_string());
} else {
break;
}
}
}
}
/// Starts a UDP local server
pub async fn run(context: SharedContext) -> io::Result<()> {
let local_addr = *context.config().local.as_ref().unwrap();
let listener = UdpSocket::bind(&local_addr).await?;
info!("ShadowSocks UDP listening on {}", local_addr);
listen(context, listener).await
}

View File

@@ -1,6 +1,6 @@
use env_logger;
use tokio;
use tokio::net::TcpStream;
use tokio::net::{TcpStream, UdpSocket};
use tokio::prelude::*;
use tokio::time::{self, Duration};
@@ -55,3 +55,59 @@ async fn tcp_tunnel() {
println!("Got reply from server: {}", String::from_utf8(buf).unwrap());
}
#[tokio::test]
async fn udp_tunnel() {
let _ = env_logger::try_init();
let mut local_config = Config::load_from_str(
r#"{
"local_port": 9110,
"local_address": "127.0.0.1",
"server": "127.0.0.1",
"server_port": 9120,
"password": "password",
"method": "aes-256-gcm",
"mode": "tcp_and_udp"
}"#,
ConfigType::Local,
)
.unwrap();
local_config.forward = Some("127.0.0.1:9130".parse::<Address>().unwrap());
let server_config = Config::load_from_str(
r#"{
"server": "127.0.0.1",
"server_port": 9120,
"password": "password",
"method": "aes-256-gcm",
"mode": "udp_only"
}"#,
ConfigType::Server,
)
.unwrap();
tokio::spawn(run_local(local_config));
tokio::spawn(run_server(server_config));
// Start a UDP echo server
tokio::spawn(async {
let mut socket = UdpSocket::bind("127.0.0.1:9130").await.unwrap();
let mut buf = vec![0u8; 65536];
let (n, src) = socket.recv_from(&mut buf).await.unwrap();
socket.send_to(&buf[..n], src).await.unwrap();
});
time::delay_for(Duration::from_secs(1)).await;
let mut socket = UdpSocket::bind("0.0.0.0:0").await.unwrap();
socket.send_to(b"HELLO WORLD", "127.0.0.1:9110").await.unwrap();
let mut buf = vec![0u8; 65536];
let n = socket.recv(&mut buf).await.unwrap();
println!("Got reply from server: {}", ::std::str::from_utf8(&buf[..n]).unwrap());
}