Migrating to futures and tokio

This commit is contained in:
Y. T. Chung
2016-10-21 19:25:03 +08:00
parent 05d3dc2f87
commit 5bf6560352
19 changed files with 1553 additions and 1916 deletions

View File

@@ -16,8 +16,6 @@ default = [
"cipher-chacha20",
"cipher-salsa20",
"enable-udp",
]
cipher-aes-cfb = []
@@ -58,13 +56,12 @@ qrcode = "^0.1.6"
env_logger = "^0.3.2"
rust-crypto = "^0.2.34"
ip = "1.0.0"
openssl = "^0.7.1"
openssl = "^0.8"
lru-cache = "0.0.7"
libc = "^0.2.7"
hyper = "0.9"
url = "^1.2"
httparse = "^1.1"
[dependencies.coio]
git = "https://github.com/zonyitoo/coio-rs.git"
#branch = "io-timeouts"
futures = "0.1"
futures-cpupool = "0.1"
tokio-core = "0.1"

View File

@@ -30,27 +30,21 @@ extern crate clap;
extern crate shadowsocks;
#[macro_use]
extern crate log;
extern crate time;
extern crate coio;
extern crate env_logger;
extern crate ip;
extern crate time;
use clap::{App, Arg};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::SocketAddr;
use std::env;
use std::time::Duration;
use coio::Scheduler;
use std::sync::Arc;
use env_logger::LogBuilder;
use log::{LogRecord, LogLevelFilter};
use ip::IpAddr;
use shadowsocks::config::{self, Config, ServerConfig};
use shadowsocks::config::DEFAULT_DNS_CACHE_CAPACITY;
use shadowsocks::relay::{RelayLocal, Relay};
use shadowsocks::relay::RelayLocal;
fn main() {
let matches = App::new("shadowsocks")
@@ -75,21 +69,11 @@ fn main() {
.long("server-addr")
.takes_value(true)
.help("Server address"))
.arg(Arg::with_name("SERVER_PORT")
.short("p")
.long("server-port")
.takes_value(true)
.help("Server port"))
.arg(Arg::with_name("LOCAL_ADDR")
.short("b")
.long("local-addr")
.takes_value(true)
.help("Local address, listen only to this address if specified"))
.arg(Arg::with_name("LOCAL_PORT")
.short("l")
.long("local-port")
.takes_value(true)
.help("Local port"))
.arg(Arg::with_name("PASSWORD")
.short("k")
.long("password")
@@ -203,26 +187,21 @@ fn main() {
let mut has_provided_server_config = false;
if matches.value_of("SERVER_ADDR").is_some() && matches.value_of("SERVER_PORT").is_some() &&
matches.value_of("PASSWORD").is_some() && matches.value_of("ENCRYPT_METHOD").is_some() {
let (svr_addr, svr_port, password, method) = matches.value_of("SERVER_ADDR")
if matches.value_of("SERVER_ADDR").is_some() && matches.value_of("PASSWORD").is_some() &&
matches.value_of("ENCRYPT_METHOD").is_some() {
let (svr_addr, password, method) = matches.value_of("SERVER_ADDR")
.and_then(|svr_addr| {
matches.value_of("SERVER_PORT")
.map(|svr_port| (svr_addr, svr_port))
})
.and_then(|(svr_addr, svr_port)| {
matches.value_of("PASSWORD")
.map(|pwd| (svr_addr, svr_port, pwd))
.map(|pwd| (svr_addr, pwd))
})
.and_then(|(svr_addr, svr_port, pwd)| {
.and_then(|(svr_addr, pwd)| {
matches.value_of("ENCRYPT_METHOD")
.map(|m| (svr_addr, svr_port, pwd, m))
.map(|m| (svr_addr, pwd, m))
})
.unwrap();
let sc = ServerConfig {
addr: svr_addr.to_owned(),
port: svr_port.parse().ok().expect("`port` should be an integer"),
addr: svr_addr.parse::<SocketAddr>().expect("Invalid server addr"),
password: password.to_owned(),
method: match method.parse() {
Ok(m) => m,
@@ -234,34 +213,26 @@ fn main() {
dns_cache_capacity: DEFAULT_DNS_CACHE_CAPACITY,
};
config.server.push(sc);
config.server.push(Arc::new(sc));
has_provided_server_config = true;
} else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("SERVER_PORT").is_none() &&
matches.value_of("PASSWORD").is_none() && matches.value_of("ENCRYPT_METHOD").is_none() {
} else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("PASSWORD").is_none() &&
matches.value_of("ENCRYPT_METHOD").is_none() {
// Does not provide server config
} else {
panic!("`server-addr`, `server-port`, `method` and `password` should be provided together");
panic!("`server-addr`, `method` and `password` should be provided together");
}
let mut has_provided_local_config = false;
if matches.value_of("LOCAL_ADDR").is_some() && matches.value_of("LOCAL_PORT").is_some() {
let (local_addr, local_port) = matches.value_of("LOCAL_ADDR")
.and_then(|local_addr| {
matches.value_of("LOCAL_PORT")
.map(|p| (local_addr, p))
})
if matches.value_of("LOCAL_ADDR").is_some() {
let local_addr = matches.value_of("LOCAL_ADDR")
.unwrap();
let local_addr: IpAddr = local_addr.parse()
let local_addr: SocketAddr = local_addr.parse()
.ok()
.expect("`local-addr` is not a valid IP address");
let local_port: u16 = local_port.parse().ok().expect("`local-port` is not a valid integer");
config.local = Some(match local_addr {
IpAddr::V4(v4) => SocketAddr::V4(SocketAddrV4::new(v4, local_port)),
IpAddr::V6(v6) => SocketAddr::V6(SocketAddrV6::new(v6, local_port, 0, 0)),
});
config.local = Some(Arc::new(local_addr));
has_provided_local_config = true;
}
@@ -277,40 +248,5 @@ fn main() {
debug!("Config: {:?}", config);
let threads = matches.value_of("THREADS")
.unwrap_or("1")
.parse::<usize>()
.ok()
.expect("`threads` should be an integer");
let stack_size = if matches.occurrences_of("VERBOSE") >= 1 {
1 * 1024 * 1024 // 1M stack for formatting!
} else {
128 * 1024
};
trace!("Coroutine stack size: {}, thread count {}",
stack_size,
threads);
Scheduler::new()
.with_workers(threads)
.default_stack_size(stack_size)
.run(move || {
if debug_level > 0 && cfg!(debug_assertions) {
// Statistic coroutine
Scheduler::spawn(|| {
loop {
debug!("STAT Coroutines: {}, TCP work: {}, HTTP work: {}",
Scheduler::instance().unwrap().work_count(),
RelayLocal::global_tcp_work_count(),
RelayLocal::global_http_work_count());
coio::sleep(Duration::from_secs(1));
}
});
}
RelayLocal::new(config).run();
})
.unwrap();
RelayLocal::new(Arc::new(config)).run().unwrap();
}

View File

@@ -32,24 +32,21 @@ extern crate clap;
extern crate shadowsocks;
#[macro_use]
extern crate log;
extern crate time;
extern crate coio;
extern crate env_logger;
extern crate time;
use std::env;
use std::time::Duration;
use std::sync::Arc;
use std::net::SocketAddr;
use clap::{App, Arg};
use coio::Scheduler;
use env_logger::LogBuilder;
use log::{LogRecord, LogLevelFilter};
use time::PreciseTime;
use shadowsocks::config::{self, Config, ServerConfig};
use shadowsocks::config::DEFAULT_DNS_CACHE_CAPACITY;
use shadowsocks::relay::{RelayServer, Relay};
use shadowsocks::relay::RelayServer;
fn main() {
let matches = App::new("shadowsocks")
@@ -74,11 +71,6 @@ fn main() {
.long("server-addr")
.takes_value(true)
.help("Server address"))
.arg(Arg::with_name("SERVER_PORT")
.short("p")
.long("server-port")
.takes_value(true)
.help("Server port"))
.arg(Arg::with_name("LOCAL_ADDR")
.short("b")
.long("local-addr")
@@ -201,26 +193,21 @@ fn main() {
let mut has_provided_server_config = false;
if matches.value_of("SERVER_ADDR").is_some() && matches.value_of("SERVER_PORT").is_some() &&
matches.value_of("PASSWORD").is_some() && matches.value_of("ENCRYPT_METHOD").is_some() {
let (svr_addr, svr_port, password, method) = matches.value_of("SERVER_ADDR")
if matches.value_of("SERVER_ADDR").is_some() && matches.value_of("PASSWORD").is_some() &&
matches.value_of("ENCRYPT_METHOD").is_some() {
let (svr_addr, password, method) = matches.value_of("SERVER_ADDR")
.and_then(|svr_addr| {
matches.value_of("SERVER_PORT")
.map(|svr_port| (svr_addr, svr_port))
})
.and_then(|(svr_addr, svr_port)| {
matches.value_of("PASSWORD")
.map(|pwd| (svr_addr, svr_port, pwd))
.map(|pwd| (svr_addr, pwd))
})
.and_then(|(svr_addr, svr_port, pwd)| {
.and_then(|(svr_addr, pwd)| {
matches.value_of("ENCRYPT_METHOD")
.map(|m| (svr_addr, svr_port, pwd, m))
.map(|m| (svr_addr, pwd, m))
})
.unwrap();
let sc = ServerConfig {
addr: svr_addr.to_owned(),
port: svr_port.parse().ok().expect("`port` should be an integer"),
addr: svr_addr.parse::<SocketAddr>().expect("`server-addr` invalid"),
password: password.to_owned(),
method: match method.parse() {
Ok(m) => m,
@@ -232,13 +219,13 @@ fn main() {
dns_cache_capacity: DEFAULT_DNS_CACHE_CAPACITY,
};
config.server.push(sc);
config.server.push(Arc::new(sc));
has_provided_server_config = true;
} else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("SERVER_PORT").is_none() &&
matches.value_of("PASSWORD").is_none() && matches.value_of("ENCRYPT_METHOD").is_none() {
} else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("PASSWORD").is_none() &&
matches.value_of("ENCRYPT_METHOD").is_none() {
// Does not provide server config
} else {
panic!("`server-addr`, `server-port`, `method` and `password` should be provided together");
panic!("`server-addr`, `method` and `password` should be provided together");
}
if !has_provided_config && !has_provided_server_config {
@@ -264,33 +251,5 @@ fn main() {
.ok()
.expect("`threads` should be an integer");
let stack_size = if matches.occurrences_of("VERBOSE") >= 1 {
1 * 1024 * 1024 // 1M stack for formatting!
} else {
128 * 1024
};
trace!("Coroutine stack size: {}, thread count {}",
stack_size,
threads);
let show_run_time = matches.occurrences_of("VERBOSE") >= 2;
Scheduler::new()
.with_workers(threads)
.default_stack_size(stack_size)
.run(move || {
if show_run_time {
let start_time = PreciseTime::now();
Scheduler::spawn(move || {
loop {
info!("SYSTEM System has already run {}",
start_time.to(PreciseTime::now()));
coio::sleep(Duration::from_secs(5));
}
});
}
RelayServer::new(config).run();
})
.unwrap();
RelayServer::new(Arc::new(config)).run(threads).unwrap();
}

View File

@@ -76,6 +76,7 @@ use std::fmt::{self, Debug, Formatter};
use std::path::Path;
use std::collections::HashSet;
use std::time::Duration;
use std::sync::Arc;
use ip::IpAddr;
@@ -87,8 +88,7 @@ pub const DEFAULT_DNS_CACHE_CAPACITY: usize = 65536;
/// Configuration for a server
#[derive(Clone, Debug)]
pub struct ServerConfig {
pub addr: String,
pub port: u16,
pub addr: SocketAddr,
pub password: String,
pub method: CipherType,
pub timeout: Option<Duration>,
@@ -96,10 +96,9 @@ pub struct ServerConfig {
}
impl ServerConfig {
pub fn basic(addr: String, port: u16, password: String, method: CipherType) -> ServerConfig {
pub fn basic(addr: SocketAddr, password: String, method: CipherType) -> ServerConfig {
ServerConfig {
addr: addr,
port: port,
password: password,
method: method,
timeout: None,
@@ -112,8 +111,18 @@ impl json::ToJson for ServerConfig {
fn to_json(&self) -> json::Json {
use serialize::json::Json;
let mut obj = json::Object::new();
obj.insert("address".to_owned(), Json::String(self.addr.clone()));
obj.insert("port".to_owned(), Json::U64(self.port as u64));
match self.addr {
SocketAddr::V4(ref v4) => {
obj.insert("address".to_owned(), Json::String(v4.ip().to_string()));
obj.insert("port".to_owned(), Json::U64(v4.port() as u64));
}
SocketAddr::V6(ref v6) => {
obj.insert("address".to_owned(), Json::String(v6.ip().to_string()));
obj.insert("port".to_owned(), Json::U64(v6.port() as u64));
}
}
obj.insert("password".to_owned(), Json::String(self.password.clone()));
obj.insert("method".to_owned(), Json::String(self.method.to_string()));
if let Some(t) = self.timeout {
@@ -138,12 +147,12 @@ pub enum ConfigType {
/// Configuration
#[derive(Clone, Debug)]
pub struct Config {
pub server: Vec<ServerConfig>,
pub local: Option<ClientConfig>,
pub http_proxy: Option<ClientConfig>,
pub server: Vec<Arc<ServerConfig>>,
pub local: Option<Arc<ClientConfig>>,
pub http_proxy: Option<Arc<ClientConfig>>,
pub enable_udp: bool,
pub timeout: Option<Duration>,
pub forbidden_ip: HashSet<IpAddr>,
pub forbidden_ip: Arc<HashSet<IpAddr>>,
}
impl Default for Config {
@@ -195,10 +204,110 @@ impl Config {
http_proxy: None,
enable_udp: false,
timeout: None,
forbidden_ip: HashSet::new(),
forbidden_ip: Arc::new(HashSet::new()),
}
}
fn parse_server(server: &json::Object) -> Result<ServerConfig, Error> {
let method = server.get("method")
.ok_or_else(|| Error::new(ErrorKind::MissingField, "need to specify a method", None))
.and_then(|method_o| {
method_o.as_string()
.ok_or_else(|| Error::new(ErrorKind::Malformed, "`method` should be a string", None))
})
.and_then(|method_str| {
method_str.parse::<CipherType>()
.map_err(|_| {
Error::new(ErrorKind::Invalid,
"not supported method",
Some(format!("`{}` is not a supported method", method_str)))
})
});
let method = try!(method);
let port = server.get("port")
.or_else(|| server.get("server_port"))
.ok_or_else(|| {
Error::new(ErrorKind::MissingField,
"need to specify a server port",
None)
})
.and_then(|port_o| {
port_o.as_u64()
.map(|u| u as u16)
.ok_or_else(|| Error::new(ErrorKind::Malformed, "`port` should be an integer", None))
});
let port = try!(port);
let addr = server.get("address")
.or_else(|| server.get("server"))
.ok_or_else(|| {
Error::new(ErrorKind::MissingField,
"need to specify a server address",
None)
})
.and_then(|addr_o| {
addr_o.as_string()
.ok_or_else(|| Error::new(ErrorKind::Malformed, "`address` should be a string", None))
})
.and_then(|addr_str| {
addr_str.parse::<Ipv4Addr>()
.map(|v4| SocketAddr::V4(SocketAddrV4::new(v4, port)))
.or_else(|_| {
addr_str.parse::<Ipv6Addr>()
.map(|v6| SocketAddr::V6(SocketAddrV6::new(v6, port, 0, 0)))
})
.map_err(|_| Error::new(ErrorKind::Malformed, "invalid server addr", None))
});
let mut addr = try!(addr);
// Merge address and port
match addr {
SocketAddr::V4(ref mut v4) => v4.set_port(port),
SocketAddr::V6(ref mut v6) => v6.set_port(port),
}
let password = server.get("password")
.ok_or_else(|| Error::new(ErrorKind::MissingField, "need to specify a password", None))
.and_then(|pwd_o| {
pwd_o.as_string()
.ok_or_else(|| Error::new(ErrorKind::Malformed, "`password` should be a string", None))
.map(|s| s.to_string())
});
let password = try!(password);
let timeout = match server.get("timeout") {
Some(t) => {
let val = try!(t.as_u64()
.ok_or(Error::new(ErrorKind::Malformed, "`timeout` should be an integer", None)));
Some(Duration::from_secs(val))
}
None => None,
};
let dns_cache_capacity = match server.get("dns_cache_capacity") {
Some(t) => {
try!(t.as_u64()
.ok_or(Error::new(ErrorKind::Malformed,
"`dns_cache_capacity` should be an integer",
None))) as usize
}
None => DEFAULT_DNS_CACHE_CAPACITY,
};
Ok(ServerConfig {
addr: addr,
password: password,
method: method,
timeout: timeout,
dns_cache_capacity: dns_cache_capacity,
})
}
fn parse_json_object(o: &json::Object, require_local_info: bool) -> Result<Config, Error> {
let mut config = Config::new();
@@ -218,119 +327,17 @@ impl Config {
.ok_or(Error::new(ErrorKind::Malformed, "`servers` should be a list", None)));
for server in server_list.iter() {
let method_o = try!(server.find("method")
.ok_or(Error::new(ErrorKind::MissingField, "need to specify a method", None)));
let method_str = try!(method_o.as_string()
.ok_or(Error::new(ErrorKind::Malformed, "`method` should be a string", None)));
let method = try!(method_str.parse::<CipherType>().map_err(|_| {
Error::new(ErrorKind::Invalid,
"not supported method",
Some(format!("`{}` is not a supported method", method_str)))
}));
let addr_o = try!(server.find("address")
.ok_or(Error::new(ErrorKind::MissingField,
"need to specify a server address",
None)));
let addr_str = try!(addr_o.as_string()
.ok_or(Error::new(ErrorKind::Malformed, "`address` should be a string", None)));
let cfg = ServerConfig {
addr: addr_str.to_string(),
port: try!(try!(server.find("port")
.ok_or(Error::new(ErrorKind::MissingField,
"need to specify a server port",
None)))
.as_u64()
.ok_or(Error::new(ErrorKind::Malformed,
"`port` should be an \
integer",
None))) as u16,
password: try!(try!(server.find("password")
.ok_or(Error::new(ErrorKind::MissingField, "need to specify a password", None)))
.as_string()
.ok_or(Error::new(ErrorKind::Malformed, "`password` should be a string", None)))
.to_string(),
method: method,
timeout: match server.find("timeout") {
Some(t) => {
let val = try!(t.as_u64()
.ok_or(Error::new(ErrorKind::Malformed, "`timeout` should be an integer", None)));
Some(Duration::from_secs(val))
}
None => None,
},
dns_cache_capacity: match server.find("dns_cache_capacity") {
Some(t) => {
try!(t.as_u64()
.ok_or(Error::new(ErrorKind::Malformed,
"`dns_cache_capacity` should be an integer",
None))) as usize
}
None => DEFAULT_DNS_CACHE_CAPACITY,
},
};
config.server.push(cfg);
if let Some(server) = server.as_object() {
let cfg = try!(Config::parse_server(server));
config.server.push(Arc::new(cfg));
}
}
} else if o.contains_key("server") && o.contains_key("server_port") && o.contains_key("password") &&
o.contains_key("method") {
// Traditional configuration file
let method_o = try!(o.get("method")
.ok_or(Error::new(ErrorKind::MissingField, "need to specify method", None)));
let method_str = try!(method_o.as_string()
.ok_or(Error::new(ErrorKind::Malformed, "`method` should be a string", None)));
let method = try!(method_str.parse::<CipherType>()
.map_err(|_| {
Error::new(ErrorKind::Invalid,
"not supported method",
Some(format!("`{}` is not a supported method", method_str)))
}));
let addr_o = try!(o.get("server")
.ok_or(Error::new(ErrorKind::MissingField,
"need to specify server address",
None)));
let addr_str = try!(addr_o.as_string()
.ok_or(Error::new(ErrorKind::Malformed, "`server` should be a string", None)));
let single_server = ServerConfig {
addr: addr_str.to_string(),
port: try!(try!(o.get("server_port")
.ok_or(Error::new(ErrorKind::MissingField,
"need to specify a server port",
None)))
.as_u64()
.ok_or(Error::new(ErrorKind::Malformed,
"`port` should be an \
integer",
None))) as u16,
password: try!(try!(o.get("password")
.ok_or(Error::new(ErrorKind::MissingField, "need to specify a password", None)))
.as_string()
.ok_or(Error::new(ErrorKind::Malformed, "`password` should be a string", None)))
.to_string(),
method: method,
timeout: match o.get("timeout") {
Some(t) => {
let val = try!(t.as_u64()
.ok_or(Error::new(ErrorKind::Malformed, "`timeout` should be an integer", None)));
Some(Duration::from_secs(val))
}
None => None,
},
dns_cache_capacity: match o.get("dns_cache_capacity") {
Some(t) => {
try!(t.as_u64()
.ok_or(Error::new(ErrorKind::Malformed,
"`dns_cache_capacity` should be an integer",
None))) as usize
}
None => DEFAULT_DNS_CACHE_CAPACITY,
},
};
config.server = vec![single_server];
let single_server = try!(Config::parse_server(o));
config.server = vec![Arc::new(single_server)];
}
if require_local_info {
@@ -353,10 +360,10 @@ impl Config {
None))) as u16;
match addr_str.parse::<Ipv4Addr>() {
Ok(ip) => Some(SocketAddr::V4(SocketAddrV4::new(ip, port))),
Ok(ip) => Some(Arc::new(SocketAddr::V4(SocketAddrV4::new(ip, port)))),
Err(..) => {
match addr_str.parse::<Ipv6Addr>() {
Ok(ip) => Some(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))),
Ok(ip) => Some(Arc::new(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)))),
Err(..) => {
return Err(Error::new(ErrorKind::Malformed,
"`local_address` is not a valid IP \
@@ -392,10 +399,10 @@ impl Config {
None))) as u16;
match addr_str.parse::<Ipv4Addr>() {
Ok(ip) => Some(SocketAddr::V4(SocketAddrV4::new(ip, port))),
Ok(ip) => Some(Arc::new(SocketAddr::V4(SocketAddrV4::new(ip, port)))),
Err(..) => {
match addr_str.parse::<Ipv6Addr>() {
Ok(ip) => Some(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))),
Ok(ip) => Some(Arc::new(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)))),
Err(..) => {
return Err(Error::new(ErrorKind::Malformed,
"`local_http_address` is not a valid IP \
@@ -418,7 +425,9 @@ impl Config {
.ok_or(Error::new(ErrorKind::Malformed,
"`forbidden_ip` should be a list",
None)));
config.forbidden_ip.extend(forbidden_ip_arr.into_iter().filter_map(|x| {
let mut forbidden_ip = HashSet::new();
forbidden_ip.extend(forbidden_ip_arr.into_iter().filter_map(|x| {
let x = match x.as_string() {
Some(x) => x,
None => {
@@ -436,6 +445,8 @@ impl Config {
}
}
}));
config.forbidden_ip = Arc::new(forbidden_ip);
}
Ok(config)
@@ -512,10 +523,20 @@ impl json::ToJson for Config {
let mut obj = json::Object::new();
if self.server.len() == 1 {
// Official format
obj.insert("server".to_owned(),
Json::String(self.server[0].addr.clone()));
obj.insert("server_port".to_owned(),
Json::U64(self.server[0].port as u64));
let server = &self.server[0];
match server.addr {
SocketAddr::V4(ref v4) => {
obj.insert("server".to_owned(), Json::String(v4.ip().to_string()));
obj.insert("server_port".to_owned(), Json::U64(v4.port() as u64));
}
SocketAddr::V6(ref v6) => {
obj.insert("server".to_owned(), Json::String(v6.ip().to_string()));
obj.insert("server_port".to_owned(), Json::U64(v6.port() as u64));
}
}
obj.insert("password".to_owned(),
Json::String(self.server[0].password.clone()));
obj.insert("method".to_owned(),
@@ -528,8 +549,8 @@ impl json::ToJson for Config {
obj.insert("servers".to_owned(), Json::Array(arr));
}
if let Some(l) = self.local {
let ip_str = match &l {
if let Some(ref l) = self.local {
let ip_str = match &**l {
&SocketAddr::V4(ref v4) => v4.ip().to_string(),
&SocketAddr::V6(ref v6) => v6.ip().to_string(),
};

View File

@@ -22,7 +22,8 @@
//! Ciphers
use std::str::FromStr;
use std::fmt::{Debug, Display, self};
use std::fmt::{self, Debug, Display};
use std::io;
use rand::{self, Rng};
use std::convert::From;
@@ -34,6 +35,8 @@ use crypto::crypto::CryptoCipher;
use crypto::digest::{self, DigestType};
use openssl::crypto::symm;
/// Basic operation of Cipher, which is a Symmetric Cipher.
///
/// The `update` method could be called multiple times, and the `finalize` method will
@@ -45,48 +48,48 @@ pub trait Cipher {
pub type CipherResult<T> = Result<T, Error>;
#[derive(Copy, Clone)]
pub enum ErrorKind {
pub enum Error {
UnknownCipherType,
OpenSSLError,
}
pub struct Error {
pub kind: ErrorKind,
pub desc: &'static str,
pub detail: Option<String>,
}
impl Error {
pub fn new(kind: ErrorKind, desc: &'static str, detail: Option<String>) -> Error {
Error {
kind: kind,
desc: desc,
detail: detail,
}
}
OpenSSLError(::openssl::error::ErrorStack),
IoError(io::Error),
}
impl Debug for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
try!(write!(f, "{}", self.desc));
match self.detail {
Some(ref d) => write!(f, " ({})", d),
None => Ok(())
match self {
&Error::UnknownCipherType => write!(f, "UnknownCipherType"),
&Error::OpenSSLError(ref err) => write!(f, "{:?}", err),
&Error::IoError(ref err) => write!(f, "{:?}", err),
}
}
}
impl Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
try!(write!(f, "{}", self.desc));
match self.detail {
Some(ref d) => write!(f, " ({})", d),
None => Ok(())
match self {
&Error::UnknownCipherType => write!(f, "UnknownCipherType"),
&Error::OpenSSLError(ref err) => write!(f, "{}", err),
&Error::IoError(ref err) => write!(f, "{}", err),
}
}
}
impl From<Error> for io::Error {
fn from(e: Error) -> io::Error {
match e {
Error::UnknownCipherType => io::Error::new(io::ErrorKind::Other, "Unknown Cipher type"),
Error::OpenSSLError(err) => From::from(err),
Error::IoError(err) => err,
}
}
}
impl From<::openssl::error::ErrorStack> for Error {
fn from(e: ::openssl::error::ErrorStack) -> Error {
Error::OpenSSLError(e)
}
}
#[cfg(feature = "cipher-aes-cfb")]
const CIPHER_AES_128_CFB: &'static str = "aes-128-cfb";
#[cfg(feature = "cipher-aes-cfb")]
@@ -121,21 +124,33 @@ const CIPHER_SALSA20: &'static str = "salsa20";
pub enum CipherType {
Table,
#[cfg(feature = "cipher-aes-cfb")] Aes128Cfb,
#[cfg(feature = "cipher-aes-cfb")] Aes128Cfb1,
#[cfg(feature = "cipher-aes-cfb")] Aes128Cfb8,
#[cfg(feature = "cipher-aes-cfb")] Aes128Cfb128,
#[cfg(feature = "cipher-aes-cfb")]
Aes128Cfb,
#[cfg(feature = "cipher-aes-cfb")]
Aes128Cfb1,
#[cfg(feature = "cipher-aes-cfb")]
Aes128Cfb8,
#[cfg(feature = "cipher-aes-cfb")]
Aes128Cfb128,
#[cfg(feature = "cipher-aes-cfb")] Aes256Cfb,
#[cfg(feature = "cipher-aes-cfb")] Aes256Cfb1,
#[cfg(feature = "cipher-aes-cfb")] Aes256Cfb8,
#[cfg(feature = "cipher-aes-cfb")] Aes256Cfb128,
#[cfg(feature = "cipher-aes-cfb")]
Aes256Cfb,
#[cfg(feature = "cipher-aes-cfb")]
Aes256Cfb1,
#[cfg(feature = "cipher-aes-cfb")]
Aes256Cfb8,
#[cfg(feature = "cipher-aes-cfb")]
Aes256Cfb128,
#[cfg(feature = "cipher-rc4")] Rc4,
#[cfg(feature = "cipher-rc4")] Rc4Md5,
#[cfg(feature = "cipher-rc4")]
Rc4,
#[cfg(feature = "cipher-rc4")]
Rc4Md5,
#[cfg(feature = "cipher-chacha20")] ChaCha20,
#[cfg(feature = "cipher-salsa20")] Salsa20,
#[cfg(feature = "cipher-chacha20")]
ChaCha20,
#[cfg(feature = "cipher-salsa20")]
Salsa20,
}
impl CipherType {
@@ -143,18 +158,25 @@ impl CipherType {
match *self {
CipherType::Table => 0,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb1 => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb8 => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb128 => 16,
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb => symm::Type::AES_128_CFB128.block_size(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb1 => symm::Type::AES_128_CFB1.block_size(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb8 => symm::Type::AES_128_CFB8.block_size(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb128 => symm::Type::AES_128_CFB128.block_size(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb => symm::Type::AES_256_CFB128.block_size(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb1 => symm::Type::AES_256_CFB1.block_size(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb8 => symm::Type::AES_256_CFB8.block_size(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb128 => symm::Type::AES_256_CFB128.block_size(),
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb1 => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb8 => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb128 => 16,
#[cfg(feature = "cipher-rc4")] CipherType::Rc4 => 0,
#[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => 16,
#[cfg(feature = "cipher-rc4")] CipherType::Rc4 => symm::Type::RC4_128.block_size(),
#[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => symm::Type::RC4_128.block_size(),
#[cfg(feature = "cipher-chacha20")] CipherType::ChaCha20 => 8,
#[cfg(feature = "cipher-salsa20")] CipherType::Salsa20 => 8,
@@ -165,18 +187,25 @@ impl CipherType {
match *self {
CipherType::Table => 0,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb1 => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb8 => 16,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb128 => 16,
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb => symm::Type::AES_128_CFB128.key_len(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb1 => symm::Type::AES_128_CFB1.key_len(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb8 => symm::Type::AES_128_CFB8.key_len(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb128 => symm::Type::AES_128_CFB128.key_len(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb => symm::Type::AES_256_CFB128.key_len(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb1 => symm::Type::AES_256_CFB1.key_len(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb8 => symm::Type::AES_256_CFB8.key_len(),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb128 => symm::Type::AES_256_CFB128.key_len(),
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb => 32,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb1 => 32,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb8 => 32,
#[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb128 => 32,
#[cfg(feature = "cipher-rc4")] CipherType::Rc4 => 16,
#[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => 16,
#[cfg(feature = "cipher-rc4")] CipherType::Rc4 => symm::Type::RC4_128.key_len(),
#[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => symm::Type::RC4_128.key_len(),
#[cfg(feature = "cipher-chacha20")] CipherType::ChaCha20 => 32,
#[cfg(feature = "cipher-salsa20")] CipherType::Salsa20 => 32,
@@ -203,15 +232,49 @@ impl CipherType {
i += 1
}
let whole = m.iter().fold(Vec::new(), |mut a, b| { a.extend_from_slice(b); a });
let whole = m.iter().fold(Vec::new(), |mut a, b| {
a.extend_from_slice(b);
a
});
let key = whole[0..key_len].to_vec();
key
}
pub fn iv_size(&self) -> usize {
match *self {
CipherType::Table => 0,
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb => symm::Type::AES_128_CFB128.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb1 => symm::Type::AES_128_CFB1.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb8 => symm::Type::AES_128_CFB8.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes128Cfb128 => symm::Type::AES_128_CFB128.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb => symm::Type::AES_256_CFB128.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb1 => symm::Type::AES_256_CFB1.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb8 => symm::Type::AES_256_CFB8.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-aes-cfb")]
CipherType::Aes256Cfb128 => symm::Type::AES_256_CFB128.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-rc4")] CipherType::Rc4 => symm::Type::RC4_128.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => symm::Type::RC4_128.iv_len().unwrap_or(0),
#[cfg(feature = "cipher-chacha20")] CipherType::ChaCha20 => 32,
#[cfg(feature = "cipher-salsa20")] CipherType::Salsa20 => 32,
}
}
pub fn gen_init_vec(&self) -> Vec<u8> {
let iv_len = self.block_size();
let iv_len = self.iv_size();
let mut iv = Vec::with_capacity(iv_len);
unsafe { iv.set_len(iv_len); }
unsafe {
iv.set_len(iv_len);
}
rand::thread_rng().fill_bytes(iv.as_mut_slice());
iv
@@ -224,46 +287,34 @@ impl FromStr for CipherType {
match s {
CIPHER_TABLE | "" => Ok(CipherType::Table),
#[cfg(feature = "cipher-aes-cfb")]
CIPHER_AES_128_CFB =>
Ok(CipherType::Aes128Cfb),
CIPHER_AES_128_CFB => Ok(CipherType::Aes128Cfb),
#[cfg(feature = "cipher-aes-cfb")]
CIPHER_AES_128_CFB_1 =>
Ok(CipherType::Aes128Cfb1),
CIPHER_AES_128_CFB_1 => Ok(CipherType::Aes128Cfb1),
#[cfg(feature = "cipher-aes-cfb")]
CIPHER_AES_128_CFB_8 =>
Ok(CipherType::Aes128Cfb8),
CIPHER_AES_128_CFB_8 => Ok(CipherType::Aes128Cfb8),
#[cfg(feature = "cipher-aes-cfb")]
CIPHER_AES_128_CFB_128 =>
Ok(CipherType::Aes128Cfb128),
CIPHER_AES_128_CFB_128 => Ok(CipherType::Aes128Cfb128),
#[cfg(feature = "cipher-aes-cfb")]
CIPHER_AES_256_CFB =>
Ok(CipherType::Aes256Cfb),
CIPHER_AES_256_CFB => Ok(CipherType::Aes256Cfb),
#[cfg(feature = "cipher-aes-cfb")]
CIPHER_AES_256_CFB_1 =>
Ok(CipherType::Aes256Cfb1),
CIPHER_AES_256_CFB_1 => Ok(CipherType::Aes256Cfb1),
#[cfg(feature = "cipher-aes-cfb")]
CIPHER_AES_256_CFB_8 =>
Ok(CipherType::Aes256Cfb8),
CIPHER_AES_256_CFB_8 => Ok(CipherType::Aes256Cfb8),
#[cfg(feature = "cipher-aes-cfb")]
CIPHER_AES_256_CFB_128 =>
Ok(CipherType::Aes256Cfb128),
CIPHER_AES_256_CFB_128 => Ok(CipherType::Aes256Cfb128),
#[cfg(feature = "cipher-rc4")]
CIPHER_RC4 =>
Ok(CipherType::Rc4),
CIPHER_RC4 => Ok(CipherType::Rc4),
#[cfg(feature = "cipher-rc4")]
CIPHER_RC4_MD5 =>
Ok(CipherType::Rc4Md5),
CIPHER_RC4_MD5 => Ok(CipherType::Rc4Md5),
#[cfg(feature = "cipher-chacha20")]
CIPHER_CHACHA20 =>
Ok(CipherType::ChaCha20),
CIPHER_CHACHA20 => Ok(CipherType::ChaCha20),
#[cfg(feature = "cipher-salsa20")]
CIPHER_SALSA20 =>
Ok(CipherType::Salsa20),
CIPHER_SALSA20 => Ok(CipherType::Salsa20),
_ => Err(Error::new(ErrorKind::UnknownCipherType, "Unknown cipher type", None))
_ => Err(Error::UnknownCipherType),
}
}
}
@@ -360,15 +411,12 @@ pub fn with_type(t: CipherType, key: &[u8], iv: &[u8], mode: CryptoMode) -> Ciph
CipherType::Table => CipherVariant::new(table::TableCipher::new(key, mode)),
#[cfg(feature = "cipher-chacha20")]
CipherType::ChaCha20 =>
CipherVariant::new(CryptoCipher::new(t, key, iv)),
CipherType::ChaCha20 => CipherVariant::new(CryptoCipher::new(t, key, iv)),
#[cfg(feature = "cipher-salsa20")]
CipherType::Salsa20 =>
CipherVariant::new(CryptoCipher::new(t, key, iv)),
CipherType::Salsa20 => CipherVariant::new(CryptoCipher::new(t, key, iv)),
#[cfg(feature = "cipher-rc4")]
CipherType::Rc4Md5 =>
CipherVariant::new(rc4_md5::Rc4Md5Cipher::new(key, iv, mode)),
CipherType::Rc4Md5 => CipherVariant::new(rc4_md5::Rc4Md5Cipher::new(key, iv, mode)),
_ => CipherVariant::new(openssl::OpenSSLCipher::new(t, key, iv, mode)),
}
@@ -383,8 +431,14 @@ mod test_cipher {
fn test_get_cipher() {
let key = CipherType::Aes128Cfb.bytes_to_key(b"PassWORD");
let iv = CipherType::Aes128Cfb.gen_init_vec();
let mut encryptor = with_type(CipherType::Aes128Cfb, &key[0..], &iv[0..], CryptoMode::Encrypt);
let mut decryptor = with_type(CipherType::Aes128Cfb, &key[0..], &iv[0..], CryptoMode::Decrypt);
let mut encryptor = with_type(CipherType::Aes128Cfb,
&key[0..],
&iv[0..],
CryptoMode::Encrypt);
let mut decryptor = with_type(CipherType::Aes128Cfb,
&key[0..],
&iv[0..],
CryptoMode::Decrypt);
let message = "HELLO WORLD";
let mut encrypted_msg = Vec::new();

View File

@@ -26,6 +26,8 @@ use std::convert::From;
use openssl::crypto::symm;
pub use self::cipher::{CipherType, Cipher, CipherVariant};
pub mod cipher;
pub mod openssl;
pub mod digest;
@@ -36,7 +38,7 @@ pub mod crypto;
#[derive(Clone, Copy)]
pub enum CryptoMode {
Encrypt,
Decrypt
Decrypt,
}
impl From<CryptoMode> for symm::Mode {

View File

@@ -35,6 +35,7 @@ use openssl::crypto::symm;
use openssl::crypto::hash;
pub struct OpenSSLCrypto {
cipher: symm::Type,
inner: symm::Crypter,
}
@@ -56,26 +57,38 @@ impl OpenSSLCrypto {
#[cfg(feature = "cipher-rc4")]
CipherType::Rc4 => symm::Type::RC4_128,
_ => panic!("Cipher type {:?} does not supported by OpenSSLCrypt yet", cipher_type),
_ => {
panic!("Cipher type {:?} does not supported by OpenSSLCrypt yet",
cipher_type)
}
};
let cipher = symm::Crypter::new(t);
cipher.init(From::from(mode), key, iv);
let key = cipher_type.bytes_to_key(key);
// Panic if error occurs
let cipher = symm::Crypter::new(t, From::from(mode), &key[..], Some(iv)).unwrap();
OpenSSLCrypto {
cipher: t,
inner: cipher,
}
}
pub fn update(&mut self, data: &[u8], out: &mut Vec<u8>) -> CipherResult<()> {
let output = self.inner.update(data);
out.extend_from_slice(&output);
let orig_length = out.len();
let least_reserved = data.len() + self.cipher.block_size();
out.resize(orig_length + least_reserved, 0);
let length = try!(self.inner.update(data, &mut out[orig_length..]));
out.resize(orig_length + length, 0);
Ok(())
}
pub fn finalize(&mut self, out: &mut Vec<u8>) -> CipherResult<()> {
let output = self.inner.finalize();
out.extend_from_slice(&output);
let orig_length = out.len();
let least_reserved = self.cipher.block_size();
out.resize(orig_length + least_reserved, 0);
let length = try!(self.inner.finalize(&mut out[orig_length..]));
out.resize(orig_length + length, 0);
Ok(())
}
}
@@ -113,9 +126,7 @@ pub struct OpenSSLCipher {
impl OpenSSLCipher {
pub fn new(cipher_type: cipher::CipherType, key: &[u8], iv: &[u8], mode: CryptoMode) -> OpenSSLCipher {
OpenSSLCipher {
worker: OpenSSLCrypto::new(cipher_type, &key[..], &iv[..], mode),
}
OpenSSLCipher { worker: OpenSSLCrypto::new(cipher_type, &key[..], &iv[..], mode) }
}
}
@@ -143,9 +154,7 @@ impl OpenSSLDigest {
digest::DigestType::Sha1 => hash::Type::SHA1,
};
OpenSSLDigest {
inner: hash::Hasher::new(t),
}
OpenSSLDigest { inner: hash::Hasher::new(t).unwrap() }
}
}
@@ -157,6 +166,7 @@ impl Digest for OpenSSLDigest {
}
fn digest(&mut self) -> Vec<u8> {
self.inner.finish()
// TODO: Check error
self.inner.finish().unwrap()
}
}

View File

@@ -22,8 +22,6 @@
#![crate_type = "lib"]
#![crate_name = "shadowsocks"]
#![feature(lookup_host)]
extern crate rustc_serialize as serialize;
#[macro_use]
extern crate log;
@@ -32,8 +30,6 @@ extern crate lru_cache;
extern crate byteorder;
extern crate rand;
extern crate coio;
extern crate crypto as rust_crypto;
extern crate ip;
extern crate openssl;
@@ -41,6 +37,11 @@ extern crate hyper;
extern crate url;
extern crate httparse;
extern crate futures;
extern crate futures_cpupool;
#[macro_use]
extern crate tokio_core;
extern crate libc;
pub const VERSION: &'static str = env!("CARGO_PKG_VERSION");

View File

@@ -19,13 +19,15 @@
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
use std::sync::Arc;
pub use self::roundrobin::RoundRobin;
use config::ServerConfig;
pub mod roundrobin;
pub trait LoadBalancer {
fn pick_server<'a>(&'a mut self) -> &'a ServerConfig;
pub trait LoadBalancer: Send + 'static {
fn pick_server(&mut self) -> Arc<ServerConfig>;
fn total(&self) -> usize;
}

View File

@@ -19,36 +19,40 @@
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
use std::sync::Arc;
use relay::loadbalancing::server::LoadBalancer;
use config::ServerConfig;
use config::{Config, ServerConfig};
#[derive(Clone)]
pub struct RoundRobin {
server: Vec<ServerConfig>,
config: Arc<Config>,
index: usize,
}
impl RoundRobin {
pub fn new(config: Vec<ServerConfig>) -> RoundRobin {
pub fn new(config: Arc<Config>) -> RoundRobin {
RoundRobin {
server: config,
config: config,
index: 0usize,
}
}
}
impl LoadBalancer for RoundRobin {
fn pick_server<'a>(&'a mut self) -> &'a ServerConfig {
if self.server.is_empty() {
fn pick_server(&mut self) -> Arc<ServerConfig> {
let server = &self.config.server;
if server.is_empty() {
panic!("No server");
}
let ref s = self.server[self.index];
self.index = (self.index + 1) % self.server.len();
s
let ref s = server[self.index];
self.index = (self.index + 1) % server.len();
s.clone()
}
fn total(&self) -> usize {
self.server.len()
self.config.server.len()
}
}

View File

@@ -21,9 +21,11 @@
//! Local side
use coio::Scheduler;
use std::sync::Arc;
use std::io;
use tokio_core::reactor::Core;
use relay::Relay;
use relay::tcprelay::local::TcpRelayLocal;
#[cfg(feature = "enable-udp")]
use relay::udprelay::local::UdpRelayLocal;
@@ -55,103 +57,18 @@ use config::Config;
/// ```
#[derive(Clone)]
pub struct RelayLocal {
enable_udp: bool,
enable_http: bool,
tcprelay: TcpRelayLocal,
#[cfg(feature = "enable-udp")]
udprelay: UdpRelayLocal,
config: Arc<Config>,
}
impl RelayLocal {
#[cfg(feature = "enable-udp")]
pub fn new(config: Config) -> RelayLocal {
let tcprelay = TcpRelayLocal::new(config.clone());
let udprelay = UdpRelayLocal::new(config.clone());
RelayLocal {
tcprelay: tcprelay,
udprelay: udprelay,
enable_udp: config.enable_udp,
enable_http: config.http_proxy.is_some(),
}
pub fn new(config: Arc<Config>) -> RelayLocal {
RelayLocal { config: config }
}
#[cfg(not(feature = "enable-udp"))]
pub fn new(config: Config) -> RelayLocal {
let tcprelay = TcpRelayLocal::new(config.clone());
RelayLocal {
tcprelay: tcprelay,
enable_udp: config.enable_udp,
enable_http: config.http_proxy.is_some(),
}
}
/// Global TCP work count
pub fn global_tcp_work_count() -> usize {
super::tcprelay::global_tcp_work_count()
}
/// Global HTTP work count
pub fn global_http_work_count() -> usize {
super::tcprelay::global_http_work_count()
}
}
impl Relay for RelayLocal {
#[cfg(not(feature = "enable-udp"))]
fn run(&self) {
if self.enable_udp {
warn!("UDP relay feature is disabled, recompile with feature=\"enable-udp\" to enable this feature");
}
let mut futs = Vec::new();
let tcprelay = self.tcprelay.clone();
let tcp_fut = Scheduler::spawn(move || {
info!("Enabled TCP relay");
tcprelay.run_tcp()
});
futs.push(tcp_fut);
let tcprelay = self.tcprelay.clone();
let tcp_fut = Scheduler::spawn(move || {
info!("Enabled HTTP relay");
tcprelay.run_http()
});
futs.push(tcp_fut);
for fut in futs {
fut.join().unwrap();
}
}
fn run(&self) {
let mut futs = Vec::new();
let tcprelay = self.tcprelay.clone();
let tcp_fut = Scheduler::spawn(move || {
info!("Enabled TCP relay");
tcprelay.run_tcp()
});
futs.push(tcp_fut);
if self.enable_udp {
let udprelay = self.udprelay.clone();
let udp_fut = Scheduler::spawn(move || {
info!("Enabled UDP relay");
udprelay.run()
});
futs.push(udp_fut);
}
let tcprelay = self.tcprelay.clone();
let tcp_fut = Scheduler::spawn(move || {
info!("Enabled HTTP relay");
tcprelay.run_http()
});
futs.push(tcp_fut);
for fut in futs {
fut.join().unwrap();
}
pub fn run(self) -> io::Result<()> {
let mut lp = try!(Core::new());
let handle = lp.handle();
let tcp_fut = TcpRelayLocal::new(self.config.clone()).run(handle.clone());
lp.run(tcp_fut)
}
}

View File

@@ -21,15 +21,9 @@
//! Relay server in local and server side implementations.
use std::io::{self, Read, Write};
use std::net::SocketAddr;
use std::mem;
pub use self::local::RelayLocal;
pub use self::server::RelayServer;
use ip::IpAddr;
mod tcprelay;
#[cfg(feature = "enable-udp")]
mod udprelay;
@@ -37,53 +31,3 @@ pub mod local;
pub mod server;
mod loadbalancing;
pub mod socks5;
pub trait Relay {
fn run(&self);
}
fn copy_once<R: Read, W: Write>(r: &mut R, w: &mut W) -> io::Result<usize> {
let mut buf: [u8; 4096] = unsafe { mem::uninitialized() };
let len = match r.read(&mut buf) {
Ok(0) => return Ok(0),
Ok(len) => len,
Err(e) => return Err(e),
};
w.write_all(&buf[..len]).and_then(|_| w.flush()).map(|_| len)
}
fn copy_exact<R: Read, W: Write>(r: &mut R, w: &mut W, len: usize) -> io::Result<()> {
let mut buf: [u8; 4096] = unsafe { mem::uninitialized() };
let mut remain = len;
while remain > 0 {
let bufl = if remain > buf.len() {
buf.len()
} else {
remain
};
let len = match r.read(&mut buf[..bufl]) {
Ok(0) => break,
Ok(len) => {
remain -= len;
len
}
Err(e) => return Err(e),
};
try!(w.write_all(&buf[..len]).and_then(|_| w.flush()));
}
if remain != 0 {
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "Unexpected Eof");
Err(err)
} else {
Ok(())
}
}
fn take_ip_addr(sockaddr: &SocketAddr) -> IpAddr {
match sockaddr {
&SocketAddr::V4(ref v4) => IpAddr::V4(*v4.ip()),
&SocketAddr::V6(ref v6) => IpAddr::V6(*v6.ip()),
}
}

View File

@@ -21,12 +21,14 @@
//! Server side
use coio::Scheduler;
use std::sync::Arc;
use std::io;
#[cfg(feature = "enable-udp")]
use relay::udprelay::server::UdpRelayServer;
use tokio_core::reactor::Core;
// #[cfg(feature = "enable-udp")]
// use relay::udprelay::server::UdpRelayServer;
use relay::tcprelay::server::TcpRelayServer;
use relay::Relay;
use config::Config;
/// Relay server running on server side.
@@ -55,66 +57,18 @@ use config::Config;
///
#[derive(Clone)]
pub struct RelayServer {
enable_udp: bool,
tcprelay: TcpRelayServer,
#[cfg(feature = "enable-udp")]
udprelay: UdpRelayServer,
config: Arc<Config>,
}
impl RelayServer {
#[cfg(feature = "enable-udp")]
pub fn new(config: Config) -> RelayServer {
let tcprelay = TcpRelayServer::new(config.clone());
let udprelay = UdpRelayServer::new(config.clone());
RelayServer {
tcprelay: tcprelay,
udprelay: udprelay,
enable_udp: config.enable_udp,
}
pub fn new(config: Arc<Config>) -> RelayServer {
RelayServer { config: config }
}
#[cfg(not(feature = "enable-udp"))]
pub fn new(config: Config) -> RelayServer {
let tcprelay = TcpRelayServer::new(config.clone());
RelayServer {
tcprelay: tcprelay,
enable_udp: config.enable_udp,
}
}
}
impl Relay for RelayServer {
#[cfg(feature = "enable-udp")]
fn run(&self) {
let mut futs = Vec::new();
let tcprelay = self.tcprelay.clone();
let tcp_fut = Scheduler::spawn(move || tcprelay.run());
info!("Enabled TCP relay");
futs.push(tcp_fut);
if self.enable_udp {
let udprelay = self.udprelay.clone();
let udp_fut = Scheduler::spawn(move || udprelay.run());
info!("Enabled UDP relay");
futs.push(udp_fut);
}
for fut in futs {
fut.join().unwrap();
}
}
#[cfg(not(feature = "enable-udp"))]
fn run(&self) {
if self.enable_udp {
warn!("UDP relay feature is disabled, recompile with feature=\"enable-udp\" to enable this feature");
}
let tcprelay = self.tcprelay.clone();
let fut = Scheduler::spawn(move || tcprelay.run());
info!("Enabled TCP relay");
fut.join().unwrap();
pub fn run(self, threads: usize) -> io::Result<()> {
let mut lp = try!(Core::new());
let handle = lp.handle();
let tcp_fut = TcpRelayServer::new(self.config.clone(), threads).run(handle.clone());
lp.run(tcp_fut)
}
}

View File

@@ -23,13 +23,17 @@
use std::fmt::{self, Debug, Formatter};
use std::net::{Ipv4Addr, Ipv6Addr, ToSocketAddrs, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::io::{self, Read, Write};
use std::io::{self, Cursor, Read, Write};
use std::vec;
use std::error;
use std::convert::From;
use byteorder::{ReadBytesExt, WriteBytesExt, BigEndian};
use futures::{self, Future, BoxFuture};
use tokio_core::io::{read_exact, write_all};
const SOCKS5_VERSION: u8 = 0x05;
pub const SOCKS5_AUTH_METHOD_NONE: u8 = 0x00;
@@ -179,6 +183,12 @@ impl From<io::Error> for Error {
}
}
impl From<Error> for io::Error {
fn from(err: Error) -> io::Error {
io::Error::new(io::ErrorKind::Other, err.message)
}
}
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum Address {
SocketAddress(SocketAddr),
@@ -186,16 +196,16 @@ pub enum Address {
}
impl Address {
#[inline]
pub fn read_from<R: Read + Sized>(reader: &mut R) -> Result<Address, Error> {
match parse_request_header(reader) {
Ok((_, addr)) => Ok(addr),
Err(err) => Err(err),
}
pub fn read_from<R>(stream: R) -> BoxFuture<(R, Address), Error>
where R: Read + Send + 'static
{
parse_request_header(stream)
}
#[inline]
pub fn write_to<W: Write + Sized>(&self, writer: &mut W) -> io::Result<()> {
pub fn write_to<W>(self, writer: W) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
write_addr(self, writer)
}
@@ -249,31 +259,46 @@ impl TcpRequestHeader {
}
}
#[inline]
pub fn read_from<R: Read>(stream: &mut R) -> Result<TcpRequestHeader, Error> {
let ver = try!(stream.read_u8());
if ver != SOCKS5_VERSION {
return Err(Error::new(Reply::ConnectionRefused, "Unsupported Socks version"));
}
pub fn read_from<R>(r: R) -> BoxFuture<(R, TcpRequestHeader), Error>
where R: Read + Send + 'static
{
read_exact(r, [0u8; 3])
.map_err(From::from)
.and_then(|(r, buf)| {
let ver = buf[0];
if ver != SOCKS5_VERSION {
return Err(Error::new(Reply::ConnectionRefused, "Unsupported Socks version"));
}
let cmd = try!(stream.read_u8());
let _ = try!(stream.read_u8());
let cmd = buf[1];
let command = match Command::from_u8(cmd) {
Some(c) => c,
None => return Err(Error::new(Reply::CommandNotSupported, "Unsupported command")),
};
Ok(TcpRequestHeader {
command: match Command::from_u8(cmd) {
Some(c) => c,
None => return Err(Error::new(Reply::CommandNotSupported, "Unsupported command")),
},
address: try!(Address::read_from(stream)),
})
Ok((r, command))
})
.and_then(|(r, command)| {
Address::read_from(r).map(move |(conn, address)| {
let header = TcpRequestHeader {
command: command,
address: address,
};
(conn, header)
})
})
.boxed()
}
#[inline]
pub fn write_to<W: Write + Sized>(&self, stream: &mut W) -> io::Result<()> {
try!(stream.write_all(&[SOCKS5_VERSION, self.command.as_u8(), 0x00]));
try!(self.address.write_to(stream));
pub fn write_to<W>(&self, w: W) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
let addr = self.address.clone();
Ok(())
write_all(w, [SOCKS5_VERSION, self.command.as_u8(), 0x00])
.and_then(move |(conn, _)| addr.write_to(conn))
.boxed()
}
#[inline]
@@ -296,28 +321,42 @@ impl TcpResponseHeader {
}
}
#[inline]
pub fn read_from<R: Read>(stream: &mut R) -> Result<TcpResponseHeader, Error> {
let ver = try!(stream.read_u8());
if ver != SOCKS5_VERSION {
return Err(Error::new(Reply::ConnectionRefused, "Unsupported Socks version"));
}
pub fn read_from<R>(r: R) -> BoxFuture<(R, TcpResponseHeader), Error>
where R: Read + Send + 'static
{
read_exact(r, [0u8; 3])
.map_err(From::from)
.and_then(|(r, buf)| {
let ver = buf[0];
let reply_code = buf[1];
let reply_code = try!(stream.read_u8());
let _ = try!(stream.read_u8());
if ver != SOCKS5_VERSION {
return Err(Error::new(Reply::ConnectionRefused, "Unsupported Socks version"));
}
Ok(TcpResponseHeader {
reply: Reply::from_u8(reply_code),
address: try!(Address::read_from(stream)),
})
Ok((r, reply_code))
})
.and_then(|(r, reply_code)| {
Address::read_from(r).map(move |(r, address)| {
let rep = TcpResponseHeader {
reply: Reply::from_u8(reply_code),
address: address,
};
(r, rep)
})
})
.boxed()
}
#[inline]
pub fn write_to<W: Write + Sized>(&self, stream: &mut W) -> io::Result<()> {
try!(stream.write_all(&[SOCKS5_VERSION, self.reply.as_u8(), 0x00]));
try!(self.address.write_to(stream));
Ok(())
pub fn write_to<W>(&self, w: W) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
let addr = self.address.clone();
write_all(w, [SOCKS5_VERSION, self.reply.as_u8(), 0x00])
.map(From::from)
.and_then(move |(w, _)| addr.write_to(w))
.boxed()
}
#[inline]
@@ -326,82 +365,134 @@ impl TcpResponseHeader {
}
}
#[inline]
fn parse_request_header<R: Read>(stream: &mut R) -> Result<(usize, Address), Error> {
let atyp = match stream.read_u8() {
Ok(atyp) => atyp,
Err(_) => return Err(Error::new(Reply::GeneralFailure, "Error while reading address type")),
};
fn parse_request_header<R>(stream: R) -> BoxFuture<(R, Address), Error>
where R: Read + Send + 'static
{
read_exact(stream, [0u8])
.map_err(|_| Error::new(Reply::GeneralFailure, "Error while reading address type"))
.and_then(|(conn, atyp)| {
match atyp[0] {
SOCKS5_ADDR_TYPE_IPV4 => {
let v4addr = read_exact(conn, [0u8; 6]).map_err(From::from);
v4addr.and_then(|(conn, v4addr)| {
let mut stream = Cursor::new(v4addr);
let v4addr = Ipv4Addr::new(try!(stream.read_u8()),
try!(stream.read_u8()),
try!(stream.read_u8()),
try!(stream.read_u8()));
let port = try!(stream.read_u16::<BigEndian>());
match atyp {
SOCKS5_ADDR_TYPE_IPV4 => {
let v4addr = Ipv4Addr::new(try!(stream.read_u8()),
try!(stream.read_u8()),
try!(stream.read_u8()),
try!(stream.read_u8()));
let port = try!(stream.read_u16::<BigEndian>());
Ok((7usize, Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(v4addr, port)))))
}
SOCKS5_ADDR_TYPE_IPV6 => {
let v6addr = Ipv6Addr::new(try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()));
let port = try!(stream.read_u16::<BigEndian>());
Ok((19usize, Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new(v6addr, port, 0, 0)))))
}
SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
let addr_len = try!(stream.read_u8()) as usize;
let mut raw_addr = Vec::with_capacity(addr_len);
try!(stream.take(addr_len as u64).read_to_end(&mut raw_addr));
let port = try!(stream.read_u16::<BigEndian>());
let addr = match String::from_utf8(raw_addr) {
Ok(addr) => addr,
Err(..) => return Err(Error::new(Reply::GeneralFailure, "Invalid address encoding")),
};
Ok((4 + addr_len, Address::DomainNameAddress(addr, port)))
}
_ => {
// Address type not supported
Err(Error::new(Reply::AddressTypeNotSupported, "Not supported address type"))
}
}
}
#[inline]
fn write_addr<W: Write + Sized>(addr: &Address, buf: &mut W) -> io::Result<()> {
match addr {
&Address::SocketAddress(addr) => {
match addr {
SocketAddr::V4(addr) => {
try!(buf.write_all(&[SOCKS5_ADDR_TYPE_IPV4]));
try!(buf.write_all(&addr.ip().octets()));
Ok((conn, Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(v4addr, port)))))
})
.boxed()
}
SocketAddr::V6(addr) => {
try!(buf.write_u8(SOCKS5_ADDR_TYPE_IPV6));
for seg in &addr.ip().segments() {
try!(buf.write_u16::<BigEndian>(*seg));
}
SOCKS5_ADDR_TYPE_IPV6 => {
let v6addr = read_exact(conn, [0u8; 18]).map_err(From::from);
v6addr.and_then(|(conn, v6addr)| {
let mut stream = Cursor::new(v6addr);
let v6addr = Ipv6Addr::new(try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()),
try!(stream.read_u16::<BigEndian>()));
let port = try!(stream.read_u16::<BigEndian>());
Ok((conn, Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new(v6addr, port, 0, 0)))))
})
.boxed()
}
SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
let addr_len = read_exact(conn, [0u8]).map_err(From::from);
addr_len.and_then(|(conn, addr_len)| {
let addr_len = addr_len[0] as usize;
let raw_addr = read_exact(conn, vec![0u8; addr_len]).map_err(From::from);
raw_addr.and_then(|(conn, raw_addr)| {
let port = read_exact(conn, [0u8; 2]).map_err(From::from);
port.and_then(|(conn, port)| {
let mut stream = Cursor::new(port);
let port = try!(stream.read_u16::<BigEndian>());
let addr = match String::from_utf8(raw_addr) {
Ok(addr) => addr,
Err(..) => {
return Err(Error::new(Reply::GeneralFailure, "Invalid address encoding"))
}
};
Ok((conn, Address::DomainNameAddress(addr, port)))
})
})
})
.boxed()
}
_ => {
// Address type not supported
futures::failed(Error::new(Reply::AddressTypeNotSupported, "Not supported address type")).boxed()
}
}
try!(buf.write_u16::<BigEndian>(addr.port()));
})
.boxed()
}
fn write_addr<W>(addr: Address, w: W) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
match addr {
Address::SocketAddress(addr) => {
let write_addr = match addr {
SocketAddr::V4(addr) => {
write_all(w, [SOCKS5_ADDR_TYPE_IPV4])
.and_then(move |(w, _)| write_all(w, addr.ip().octets()))
.map(|(conn, _)| conn)
.boxed()
}
SocketAddr::V6(addr) => {
write_all(w, [SOCKS5_ADDR_TYPE_IPV6])
.and_then(move |(w, _)| {
let mut rbuf = [0u8; 16];
{
let mut buf = Cursor::new(&mut rbuf[..]);
for seg in &addr.ip().segments() {
try!(buf.write_u16::<BigEndian>(*seg));
}
}
Ok((w, rbuf))
})
.and_then(|(w, rbuf)| write_all(w, rbuf))
.map(|(conn, _)| conn)
.boxed()
}
};
write_addr.and_then(move |w| {
let mut rbuf = [0u8; 2];
{
let mut buf = Cursor::new(&mut rbuf[..]);
try!(buf.write_u16::<BigEndian>(addr.port()));
}
Ok((w, rbuf))
})
.and_then(|(w, rbuf)| write_all(w, rbuf))
.map(|(conn, _)| conn)
.boxed()
}
&Address::DomainNameAddress(ref dnaddr, port) => {
try!(buf.write_u8(SOCKS5_ADDR_TYPE_DOMAIN_NAME));
try!(buf.write_u8(dnaddr.len() as u8));
try!(buf.write_all(dnaddr[..].as_bytes()));
try!(buf.write_u16::<BigEndian>(port));
Address::DomainNameAddress(dnaddr, port) => {
futures::lazy(move || {
let mut buf = Vec::with_capacity(dnaddr.len() + 4);
try!(buf.write_u8(SOCKS5_ADDR_TYPE_DOMAIN_NAME));
try!(buf.write_u8(dnaddr.len() as u8));
try!(buf.write_all(dnaddr[..].as_bytes()));
try!(buf.write_u16::<BigEndian>(port));
Ok(buf)
})
.and_then(|buf| write_all(w, buf))
.map(|(conn, _)| conn)
.boxed()
}
}
Ok(())
}
#[inline]
@@ -432,25 +523,34 @@ impl HandshakeRequest {
HandshakeRequest { methods: methods }
}
pub fn read_from<R: Read>(stream: &mut R) -> io::Result<HandshakeRequest> {
let ver = try!(stream.read_u8());
if ver != SOCKS5_VERSION {
return Err(io::Error::new(io::ErrorKind::Other, "Invalid Socks5 version"));
}
pub fn read_from<R>(r: R) -> BoxFuture<(R, HandshakeRequest), io::Error>
where R: Read + Send + 'static
{
read_exact(r, [0u8, 0u8])
.and_then(|(r, buf)| {
let ver = buf[0];
let nmet = buf[1];
let nmet = try!(stream.read_u8());
if ver != SOCKS5_VERSION {
return Err(io::Error::new(io::ErrorKind::Other, "Invalid Socks5 version"));
}
let mut methods = Vec::new();
try!(stream.take(nmet as u64).read_to_end(&mut methods));
Ok(HandshakeRequest { methods: methods })
Ok((r, nmet))
})
.and_then(|(r, nmet)| {
read_exact(r, vec![0u8; nmet as usize])
.and_then(|(r, methods)| Ok((r, HandshakeRequest { methods: methods })))
})
.boxed()
}
pub fn write_to(&self, stream: &mut Write) -> io::Result<()> {
try!(stream.write_all(&[SOCKS5_VERSION, self.methods.len() as u8]));
try!(stream.write_all(&self.methods[..]));
Ok(())
pub fn write_to<W>(self, w: W) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
write_all(w, [SOCKS5_VERSION, self.methods.len() as u8])
.and_then(move |(w, _)| write_all(w, self.methods))
.map(|(w, _)| w)
.boxed()
}
}
@@ -469,19 +569,27 @@ impl HandshakeResponse {
HandshakeResponse { chosen_method: cm }
}
pub fn read_from<R: Read>(stream: &mut R) -> io::Result<HandshakeResponse> {
let ver = try!(stream.read_u8());
if ver != SOCKS5_VERSION {
return Err(io::Error::new(io::ErrorKind::Other, "Invalid Socks5 version"));
}
pub fn read_from<R>(r: R) -> BoxFuture<(R, HandshakeResponse), io::Error>
where R: Read + Send + 'static
{
read_exact(r, [0u8, 0u8])
.and_then(|(r, buf)| {
let ver = buf[0];
let met = buf[1];
let met = try!(stream.read_u8());
Ok(HandshakeResponse { chosen_method: met })
if ver != SOCKS5_VERSION {
Err(io::Error::new(io::ErrorKind::Other, "Invalid Socks5 version"))
} else {
Ok((r, HandshakeResponse { chosen_method: met }))
}
})
.boxed()
}
pub fn write_to(&self, stream: &mut Write) -> io::Result<()> {
stream.write_all(&[SOCKS5_VERSION, self.chosen_method])
pub fn write_to<W>(self, w: W) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
write_all(w, [SOCKS5_VERSION, self.chosen_method]).map(|(w, _)| w).boxed()
}
}
@@ -499,19 +607,29 @@ impl UdpAssociateHeader {
}
}
pub fn read_from<R: Read>(reader: &mut R) -> Result<UdpAssociateHeader, Error> {
let _ = try!(reader.read_u8());
let _ = try!(reader.read_u8());
let frag = try!(reader.read_u8());
Ok(UdpAssociateHeader::new(frag, try!(Address::read_from(reader))))
pub fn read_from<R>(r: R) -> BoxFuture<(R, UdpAssociateHeader), Error>
where R: Read + Send + 'static
{
read_exact(r, [0u8; 3])
.map_err(From::from)
.and_then(|(r, buf)| {
let frag = buf[2];
Address::read_from(r).map(move |(r, address)| {
let h = UdpAssociateHeader::new(frag, address);
(r, h)
})
})
.boxed()
}
pub fn write_to<W: Write + Sized>(&self, writer: &mut W) -> io::Result<()> {
try!(writer.write_all(&[0x00, 0x00, self.frag]));
try!(self.address.write_to(writer));
Ok(())
pub fn write_to<W>(&self, w: W) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
let addr = self.address.clone();
write_all(w, [0x00, 0x00, self.frag])
.map_err(From::from)
.and_then(move |(w, _)| addr.write_to(w))
.boxed()
}
pub fn len(&self) -> usize {

View File

@@ -21,10 +21,10 @@
/// Http Proxy
use std::io::{self, Write};
use std::io::{self, Read, Write};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::mem;
use hyper::server::response::Response;
use hyper::uri::RequestUri;
use hyper::header::Headers;
use hyper::status::StatusCode;
@@ -36,8 +36,13 @@ use httparse::{self, Request};
use url::Host;
use futures::{self, Future, BoxFuture, Poll};
use tokio_core::io::write_all;
use relay::socks5::Address;
#[derive(Debug)]
pub struct HttpRequest {
pub version: HttpVersion,
pub method: Method,
@@ -85,20 +90,27 @@ impl HttpRequest {
}
}
pub fn write_to<W: Write>(&self, w: &mut W) -> io::Result<()> {
try!(write!(w,
"{} {} {}\r\n",
self.method,
self.request_uri,
self.version));
pub fn write_to<W>(self, w: W) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
futures::lazy(move || {
let mut w = Vec::new();
try!(write!(w,
"{} {} {}\r\n",
self.method,
self.request_uri,
self.version));
for header in self.headers.iter() {
try!(write!(w, "{}: {}\r\n", header.name(), header.value_string()));
}
for header in self.headers.iter() {
try!(write!(w, "{}: {}\r\n", header.name(), header.value_string()));
}
try!(write!(w, "\r\n"));
Ok(())
try!(write!(w, "\r\n"));
Ok(w)
})
.and_then(|buf| write_all(w, buf))
.map(|(w, _)| w)
.boxed()
}
}
@@ -152,11 +164,93 @@ pub fn get_address(uri: &RequestUri) -> Result<Address, StatusCode> {
}
}
pub fn write_response(stream: &mut Write, status: StatusCode) -> io::Result<()> {
let mut headers = Headers::new();
let mut resp = Response::new(stream, &mut headers);
*resp.status_mut() = status;
try!(resp.start().and_then(|r| r.end()));
Ok(())
pub fn write_response<W>(w: W, version: HttpVersion, status: StatusCode) -> BoxFuture<W, io::Error>
where W: Write + Send + 'static
{
let buf = format!("{} {}\r\n\r\n", version, status);
write_all(w, buf.into_bytes()).map(|(w, _)| w).boxed()
}
/// HTTP Client
pub enum RequestReader<R>
where R: Read
{
Pending { r: R, buf: Vec<u8> },
Empty,
}
impl<R> RequestReader<R>
where R: Read
{
pub fn new(r: R) -> RequestReader<R> {
RequestReader::with_buf(r, Vec::new())
}
pub fn with_buf(r: R, buf: Vec<u8>) -> RequestReader<R> {
RequestReader::Pending { r: r, buf: buf }
}
}
impl<R> Future for RequestReader<R>
where R: Read
{
type Item = (R, HttpRequest, Vec<u8>);
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut lbuf = [0u8; 4096];
let (req, len) = match self {
&mut RequestReader::Pending { ref mut r, ref mut buf } => {
let mut http_req = None;
let mut total_len = 0;
loop {
let n = try_nb!(r.read(&mut lbuf));
buf.extend_from_slice(&lbuf[..n]);
// Maximum 128 headers
let mut headers = [httparse::EMPTY_HEADER; 128];
let headers_ptr = &headers as *const _;
let mut req = Request::new(&mut headers);
match req.parse(&mut buf[..]) {
Ok(httparse::Status::Partial) => {
if n == 0 {
// Already EOF!
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "Unexpected Eof");
return Err(err);
}
}
Ok(httparse::Status::Complete(len)) => {
total_len = len;
// Make borrow checker happy
let headers_ref = unsafe { &*headers_ptr };
let hreq = match HttpRequest::from_raw(&req, headers_ref) {
Ok(r) => r,
Err(err) => {
error!("HttpRequest::from_raw: {}", err);
let err = io::Error::new(io::ErrorKind::Other, "Hyper error");
return Err(err);
}
};
http_req = Some(hreq);
break;
}
Err(err) => {
error!("Request parse: {:?}", err);
let err = io::Error::new(io::ErrorKind::Other, "Hyper error");
return Err(err);
}
}
}
(http_req.unwrap(), total_len)
}
&mut RequestReader::Empty => panic!("poll a RequestReader after it's done"),
};
match mem::replace(self, RequestReader::Empty) {
RequestReader::Pending { r, buf } => Ok((r, req, buf[len..].to_vec()).into()),
RequestReader::Empty => unreachable!(),
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -21,184 +21,193 @@
//! TcpRelay implementation
use std::net::SocketAddr;
use std::io::{self, Read, Write};
use std::sync::Arc;
use std::mem;
use crypto::cipher::{self, CipherType};
use crypto::cipher;
use crypto::CryptoMode;
use relay::socks5::Address;
use config::ServerConfig;
use coio::net::TcpStream;
use tokio_core::net::TcpStream;
use tokio_core::reactor::Handle;
use tokio_core::io::{read_exact, write_all, flush};
use tokio_core::io::{ReadHalf, WriteHalf};
use tokio_core::io::Io;
use self::stream::{DecryptedReader, EncryptedWriter};
use futures::{Future, BoxFuture, Poll};
mod cached_dns;
use self::stream::{EncryptedWriter, DecryptedReader};
// use coio::net::TcpStream;
// use self::stream::{DecryptedReader, EncryptedWriter};
// mod cached_dns;
pub mod local;
pub mod server;
mod stream;
mod http;
fn connect_proxy_server(server_addr: &SocketAddr,
encrypt_method: CipherType,
pwd: &[u8],
relay_addr: &Address)
-> io::Result<(DecryptedReader<TcpStream>, EncryptedWriter<TcpStream>)> {
let mut remote_stream = try!(TcpStream::connect(&server_addr));
type DecryptedHalf = DecryptedReader<ReadHalf<TcpStream>>;
type EncryptedHalf = EncryptedWriter<WriteHalf<TcpStream>>;
// Encrypt data to remote server
fn connect_proxy_server(handle: &Handle,
svr_cfg: Arc<ServerConfig>,
relay_addr: Address)
-> BoxFuture<(DecryptedHalf, EncryptedHalf), io::Error> {
TcpStream::connect(&svr_cfg.addr, handle)
.and_then(move |remote_stream| {
let (r, w) = remote_stream.split();
// Send initialize vector to remote and create encryptor
let mut encrypt_stream = {
let local_iv = encrypt_method.gen_init_vec();
trace!("Going to send initialize vector: {:?}", local_iv);
let encryptor = cipher::with_type(encrypt_method, pwd, &local_iv[..], CryptoMode::Encrypt);
if let Err(err) = remote_stream.write_all(&local_iv[..]) {
error!("Error occurs while writing initialize vector: {}", err);
return Err(err);
// Encrypt data to remote server
// Send initialize vector to remote and create encryptor
let local_iv = svr_cfg.method.gen_init_vec();
trace!("Going to send initialize vector: {:?}", local_iv);
write_all(w, local_iv)
.and_then(move |(w, local_iv)| {
let encryptor = cipher::with_type(svr_cfg.method,
svr_cfg.password.as_bytes(),
&local_iv[..],
CryptoMode::Encrypt);
Ok((svr_cfg, r, EncryptedWriter::new(w, encryptor)))
})
.and_then(|(svr_cfg, r, enc_w)| {
trace!("Got encrypt stream and going to send addr: {:?}",
relay_addr);
// Send relay address to remote
relay_addr.write_to(Vec::new())
.and_then(|addr_buf| {
write_all(enc_w, addr_buf)
.and_then(|(enc_w, _)| flush(enc_w))
.and_then(|enc_w| Ok((svr_cfg, r, enc_w)))
})
})
})
.and_then(|(svr_cfg, r, enc_w)| {
// Decrypt data from remote server
let iv_len = svr_cfg.method.iv_size();
read_exact(r, vec![0u8; iv_len]).and_then(move |(r, remote_iv)| {
trace!("Got initialize vector {:?}", remote_iv);
let decryptor = cipher::with_type(svr_cfg.method,
svr_cfg.password.as_bytes(),
&remote_iv[..],
CryptoMode::Decrypt);
let decrypt_stream = DecryptedReader::new(r, decryptor);
trace!("Finished creating remote encrypt stream pair");
Ok((decrypt_stream, enc_w))
})
})
.boxed()
}
/// Copy exactly N bytes
pub enum CopyExact<R, W>
where R: Read,
W: Write
{
Pending {
reader: R,
writer: W,
buf: [u8; 4096],
remain: usize,
pos: usize,
cap: usize,
},
Empty,
}
impl<R, W> CopyExact<R, W>
where R: Read,
W: Write
{
pub fn new(r: R, w: W, amt: usize) -> CopyExact<R, W> {
CopyExact::Pending {
reader: r,
writer: w,
buf: [0u8; 4096],
remain: amt,
pos: 0,
cap: 0,
}
let remote_writer = match remote_stream.try_clone() {
Ok(s) => s,
Err(err) => {
error!("Error occurs while cloning remote stream: {}", err);
return Err(err);
}
};
EncryptedWriter::new(remote_writer, encryptor)
};
trace!("Got encrypt stream and going to send addr: {:?}",
relay_addr);
// Send relay address to remote
let mut addr_buf = Vec::new();
try!(relay_addr.write_to(&mut addr_buf));
if let Err(err) = encrypt_stream.write_all(&addr_buf).and_then(|_| encrypt_stream.flush()) {
error!("Error occurs while writing address: {}", err);
return Err(err);
}
}
// Decrypt data from remote server
impl<R, W> Future for CopyExact<R, W>
where R: Read,
W: Write
{
type Item = (R, W);
type Error = io::Error;
let remote_iv = {
let mut iv = Vec::with_capacity(encrypt_method.block_size());
unsafe {
iv.set_len(encrypt_method.block_size());
}
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self {
&mut CopyExact::Empty => panic!("poll after CopyExact is finished"),
&mut CopyExact::Pending { ref mut reader,
ref mut writer,
ref mut buf,
ref mut remain,
ref mut pos,
ref mut cap } => {
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if *pos == *cap && *remain != 0 {
let buf_len = if *remain > buf.len() {
buf.len()
} else {
*remain
};
let n = try_nb!(reader.read(&mut buf[..buf_len]));
if n == 0 {
// Unexpected EOF!
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "Unexpected Eof");
return Err(err);
} else {
*pos = 0;
*cap = n;
*remain -= n;
}
}
let mut total_len = 0;
while total_len < encrypt_method.block_size() {
match remote_stream.read(&mut iv[total_len..]) {
Ok(0) => {
error!("Unexpected EOF while reading initialize vector");
debug!("Already read: {:?}", &iv[..total_len]);
// If our buffer has some data, let's write it out!
while *pos < *cap {
let i = try_nb!(writer.write(&buf[*pos..*cap]));
*pos += i;
}
let err = io::Error::new(io::ErrorKind::UnexpectedEof,
"Unexpected EOF while reading initialize vector");
return Err(err);
}
Ok(n) => total_len += n,
Err(err) => {
error!("Error while reading initialize vector: {}", err);
return Err(err);
// If we've written al the data and we've seen EOF, flush out the
// data and finish the transfer.
// done with the entire transfer.
if *pos == *cap && *remain == 0 {
try_nb!(writer.flush());
break; // The only path to execute the following logic
}
}
}
}
iv
};
trace!("Got initialize vector {:?}", remote_iv);
let decryptor = cipher::with_type(encrypt_method, pwd, &remote_iv[..], CryptoMode::Decrypt);
let decrypt_stream = DecryptedReader::new(remote_stream, decryptor);
trace!("Finished creating remote encrypt stream pair");
Ok((decrypt_stream, encrypt_stream))
}
#[cfg(debug_assertions)]
mod stat {
use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT};
static GLOBAL_TCP_WORK_COUNT: AtomicUsize = ATOMIC_USIZE_INIT;
static GLOBAL_HTTP_WORK_COUNT: AtomicUsize = ATOMIC_USIZE_INIT;
pub fn global_tcp_work_count_add() {
GLOBAL_TCP_WORK_COUNT.fetch_add(1, Ordering::Relaxed);
}
pub fn global_tcp_work_count_sub() {
GLOBAL_TCP_WORK_COUNT.fetch_sub(1, Ordering::Relaxed);
}
pub fn global_tcp_work_count_get() -> usize {
GLOBAL_TCP_WORK_COUNT.load(Ordering::Relaxed)
}
pub fn global_http_work_count_add() {
GLOBAL_HTTP_WORK_COUNT.fetch_add(1, Ordering::Relaxed);
}
pub fn global_http_work_count_sub() {
GLOBAL_HTTP_WORK_COUNT.fetch_sub(1, Ordering::Relaxed);
}
pub fn global_http_work_count_get() -> usize {
GLOBAL_HTTP_WORK_COUNT.load(Ordering::Relaxed)
match mem::replace(self, CopyExact::Empty) {
CopyExact::Pending { reader, writer, .. } => Ok((reader, writer).into()),
CopyExact::Empty => unreachable!(),
}
}
}
#[cfg(not(debug_assertions))]
mod stat {
pub fn global_tcp_work_count_add() {}
pub fn global_tcp_work_count_sub() {}
pub fn global_tcp_work_count_get() -> usize {
0
}
pub fn global_http_work_count_add() {}
pub fn global_http_work_count_sub() {}
pub fn global_http_work_count_get() -> usize {
0
}
}
struct TcpWorkCounter;
impl TcpWorkCounter {
fn new() -> TcpWorkCounter {
stat::global_tcp_work_count_add();
TcpWorkCounter
}
}
impl Drop for TcpWorkCounter {
fn drop(&mut self) {
stat::global_tcp_work_count_sub();
}
}
struct HttpWorkCounter;
impl HttpWorkCounter {
fn new() -> HttpWorkCounter {
stat::global_http_work_count_add();
HttpWorkCounter
}
}
impl Drop for HttpWorkCounter {
fn drop(&mut self) {
stat::global_http_work_count_sub();
}
}
/// Get total TCP relay work count
pub fn global_tcp_work_count() -> usize {
stat::global_tcp_work_count_get()
}
/// Get total HTTP relay work count
pub fn global_http_work_count() -> usize {
stat::global_http_work_count_get()
pub fn copy_exact<R, W>(r: R, w: W, amt: usize) -> CopyExact<R, W>
where R: Read,
W: Write
{
CopyExact::new(r, w, amt)
}

View File

@@ -21,314 +21,213 @@
//! TcpRelay server that running on the server side
use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
use std::io::{self, Read, Write, BufReader};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::collections::HashSet;
use config::{Config, ServerConfig};
use crypto::CryptoMode;
use crypto::cipher;
use super::stream::{EncryptedWriter, DecryptedReader};
use relay::socks5::Address;
use futures::{self, Future, BoxFuture};
use futures::stream::Stream;
use futures_cpupool::CpuPool;
use tokio_core::reactor::Handle;
use tokio_core::net::{TcpStream, TcpListener};
use tokio_core::io::Io;
use tokio_core::io::{ReadHalf, WriteHalf};
use tokio_core::io::{read_exact, write_all, copy};
use ip::IpAddr;
use coio::Scheduler;
use coio::net::{TcpListener, TcpStream, Shutdown};
type ClientRead = ReadHalf<TcpStream>;
type ClientWrite = WriteHalf<TcpStream>;
use config::{Config, ServerConfig};
use relay::socks5;
use relay::tcprelay::cached_dns::CachedDns;
use relay::tcprelay::stream::{DecryptedReader, EncryptedWriter};
use crypto::cipher;
use crypto::CryptoMode;
type EncryptedHalf = EncryptedWriter<ClientWrite>;
type DecryptedHalf = DecryptedReader<ClientRead>;
#[derive(Clone)]
/// TCP Relay backend
pub struct TcpRelayServer {
config: Config,
config: Arc<Config>,
cpu_pool: CpuPool,
}
impl TcpRelayServer {
pub fn new(c: Config) -> TcpRelayServer {
if c.server.is_empty() {
panic!("You have to provide a server configuration");
/// Creates an instance
pub fn new(config: Arc<Config>, threads: usize) -> TcpRelayServer {
TcpRelayServer {
config: config,
cpu_pool: CpuPool::new(threads),
}
TcpRelayServer { config: c }
}
fn accept_loop(s: ServerConfig, forbidden_ip: Arc<HashSet<IpAddr>>) {
let acceptor = TcpListener::bind(&(&s.addr[..], s.port))
.unwrap_or_else(|err| panic!("Failed to bind a TCP socket: {}", err));
fn handshake((r, w): (ClientRead, ClientWrite),
svr_cfg: Arc<ServerConfig>)
-> BoxFuture<(DecryptedHalf, EncryptedHalf), io::Error> {
let iv_len = svr_cfg.method.iv_size();
read_exact(r, vec![0u8; iv_len])
.and_then(move |(r, iv)| {
trace!("Got handshake iv: {:?}", iv);
let decryptor = cipher::with_type(svr_cfg.method,
svr_cfg.password.as_bytes(),
&iv[..],
CryptoMode::Decrypt);
let decrypt_stream = DecryptedReader::new(r, decryptor);
info!("Shadowsocks listening on {}:{}", s.addr, s.port);
Ok((svr_cfg, decrypt_stream))
})
.and_then(|(svr_cfg, enc_r)| {
let iv = svr_cfg.method.gen_init_vec();
trace!("Going to send handshake iv: {:?}", iv);
write_all(w, iv).and_then(move |(w, iv)| {
let encryptor = cipher::with_type(svr_cfg.method,
svr_cfg.password.as_bytes(),
&iv[..],
CryptoMode::Encrypt);
let encrypt_stream = EncryptedWriter::new(w, encryptor);
let dnscache_arc = Arc::new(CachedDns::with_capacity(s.dns_cache_capacity));
Ok((enc_r, encrypt_stream))
})
})
.boxed()
}
let pwd = s.method.bytes_to_key(s.password.as_bytes());
let timeout = s.timeout;
let method = s.method;
fn resolve_address(addr: Address, cpu_pool: CpuPool) -> BoxFuture<SocketAddr, io::Error> {
match addr {
Address::SocketAddress(addr) => futures::finished(addr).boxed(),
Address::DomainNameAddress(dname, port) => {
cpu_pool.spawn(futures::lazy(move || {
let dname = format!("{}:{}", dname, port);
let mut addrs = try!(dname.to_socket_addrs());
addrs.next().ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to resolve domain"))
}))
.boxed()
}
}
}
info!("Method {}, Timeout: {:?}", method, timeout);
fn resolve_remote(cpu_pool: CpuPool,
addr: Address,
forbidden_ip: Arc<HashSet<IpAddr>>)
-> BoxFuture<SocketAddr, io::Error> {
TcpRelayServer::resolve_address(addr, cpu_pool)
.and_then(move |addr| {
trace!("Resolved address as {}", addr);
let ipaddr = match addr.clone() {
SocketAddr::V4(v4) => IpAddr::V4(v4.ip().clone()),
SocketAddr::V6(v6) => IpAddr::V6(v6.ip().clone()),
};
for s in acceptor.incoming() {
let mut stream = match s {
Ok((s, addr)) => {
debug!("Got connection from {:?}", addr);
s
if forbidden_ip.contains(&ipaddr) {
info!("{} has been forbidden", ipaddr);
let err = io::Error::new(io::ErrorKind::Other, "Forbidden IP");
Err(err)
} else {
Ok(addr)
}
})
.boxed()
}
fn connect_remote(cpu_pool: CpuPool,
handle: Handle,
addr: Address,
forbidden_ip: Arc<HashSet<IpAddr>>)
-> Box<Future<Item = TcpStream, Error = io::Error>> {
trace!("Connecting to remote {}", addr);
Box::new(TcpRelayServer::resolve_remote(cpu_pool, addr, forbidden_ip)
.and_then(move |addr| TcpStream::connect(&addr, &handle)))
}
pub fn handle_client(handle: &Handle,
cpu_pool: CpuPool,
s: TcpStream,
svr_cfg: Arc<ServerConfig>,
forbidden_ip: Arc<HashSet<IpAddr>>)
-> io::Result<()> {
let peer_addr = try!(s.peer_addr());
trace!("Got connection from {}", peer_addr);
let cloned_handle = handle.clone();
let fut = futures::lazy(|| Ok(s.split()))
.and_then(move |(r, w)| TcpRelayServer::handshake((r, w), svr_cfg))
.and_then(|(r, w)| Address::read_from(r).map(|(r, addr)| (r, w, addr)).map_err(From::from))
.and_then(move |(r, w, addr)| {
info!("Connecting {}", addr);
let cloned_addr = addr.clone();
TcpRelayServer::connect_remote(cpu_pool, cloned_handle, addr, forbidden_ip)
.map(|svr_s| (svr_s, r, w, cloned_addr))
})
.and_then(|(svr_s, r, w, addr)| {
let (svr_r, svr_w) = svr_s.split();
let c2s = copy(r, svr_w);
let s2c = copy(svr_r, w);
c2s.join(s2c)
.and_then(move |(c2s_amt, s2c_amt)| {
trace!("Relayed {} client -> remote {}bytes", addr, c2s_amt);
trace!("Relayed {} client <- remote {}bytes", addr, s2c_amt);
Ok(())
})
});
handle.spawn(fut.then(|res| {
match res {
Ok(..) => Ok(()),
Err(err) => {
panic!("Error occurs while accepting: {}", err);
error!("Failed to handle client: {}", err);
Err(())
}
}
}));
Ok(())
}
/// Runs the server
pub fn run(self, handle: Handle) -> Box<Future<Item = (), Error = io::Error>> {
let mut fut: Option<Box<Future<Item = (), Error = io::Error>>> = None;
for svr_cfg in &self.config.server {
let listener = {
let addr = &svr_cfg.addr;
let listener = TcpListener::bind(addr, &handle).unwrap();
trace!("ShadowSocks TCP Listening on {}", addr);
listener
};
if let Err(err) = stream.set_read_timeout(timeout) {
error!("Failed to set read timeout: {:?}", err);
continue;
}
let svr_cfg = svr_cfg.clone();
let handle = handle.clone();
let forbidden_ip = self.config.forbidden_ip.clone();
let cpu_pool = self.cpu_pool.clone();
let listening = listener.incoming()
.for_each(move |(socket, addr)| {
let server_cfg = svr_cfg.clone();
let forbidden_ip = forbidden_ip.clone();
let cpu_pool = cpu_pool.clone();
if let Err(err) = stream.set_nodelay(true) {
error!("Failed to set no delay: {}", err);
continue;
}
let pwd = pwd.clone();
let encrypt_method = method;
let dnscache = dnscache_arc.clone();
let forbidden_ip = forbidden_ip.clone();
Scheduler::spawn(move || {
let remote_iv = {
let mut iv = Vec::with_capacity(encrypt_method.block_size());
unsafe {
iv.set_len(encrypt_method.block_size());
}
let mut total_len = 0;
while total_len < encrypt_method.block_size() {
match stream.read(&mut iv[total_len..]) {
Ok(0) => {
error!("Unexpected EOF while reading initialize vector");
return;
}
Ok(n) => total_len += n,
Err(err) => {
error!("Error while reading initialize vector: {}", err);
return;
}
}
}
iv
};
let decryptor = cipher::with_type(encrypt_method,
&pwd[..],
&remote_iv[..],
CryptoMode::Decrypt);
let mut client_writer = match stream.try_clone() {
Ok(s) => s,
Err(err) => {
error!("Error occurs while cloning client stream: {}", err);
return;
}
};
let client_reader = stream;
let iv = encrypt_method.gen_init_vec();
let encryptor = cipher::with_type(encrypt_method, &pwd[..], &iv[..], CryptoMode::Encrypt);
if let Err(err) = client_writer.write_all(&iv[..]) {
error!("Error occurs while writing initialize vector: {}", err);
return;
}
let mut decrypt_stream = DecryptedReader::new(client_reader, decryptor);
let addr = match socks5::Address::read_from(&mut decrypt_stream) {
Ok(addr) => addr,
Err(err) => {
error!("Error occurs while parsing request header, maybe wrong crypto \
method or password: {}",
err);
return;
}
};
info!("Connecting to {}", addr);
let remote_stream = match &addr {
&socks5::Address::SocketAddress(ref addr) => {
if forbidden_ip.contains(&::relay::take_ip_addr(addr)) {
info!("{} has been blocked by `forbidden_ip`", addr);
return;
}
match TcpStream::connect(&addr) {
Ok(stream) => stream,
Err(err) => {
error!("Unable to connect {:?}: {}", addr, err);
return;
}
}
}
&socks5::Address::DomainNameAddress(ref dname, ref port) => {
let addrs = match dnscache.resolve(&dname) {
Some(addrs) => addrs,
None => return,
};
let processing = || {
let mut last_err: Option<io::Result<TcpStream>> = None;
for addr in addrs.into_iter() {
let addr = match addr {
SocketAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr.ip().clone(), *port)),
SocketAddr::V6(addr) => {
SocketAddr::V6(SocketAddrV6::new(addr.ip().clone(),
*port,
addr.flowinfo(),
addr.scope_id()))
}
};
if forbidden_ip.contains(&::relay::take_ip_addr(&addr)) {
info!("{} has been blocked by `forbidden_ip`", addr);
last_err = Some(Err(io::Error::new(io::ErrorKind::Other,
"Blocked by `forbidden_ip`")));
continue;
}
match TcpStream::connect(addr) {
Ok(stream) => return Ok(stream),
Err(err) => {
error!("Unable to connect {:?}: {}", addr, err);
last_err = Some(Err(err));
}
}
}
last_err.unwrap()
};
match processing() {
Ok(s) => s,
Err(_) => return,
}
}
};
let mut remote_writer = match remote_stream.try_clone() {
Ok(s) => s,
Err(err) => {
error!("Error occurs while cloning remote stream: {}", err);
return;
}
};
let addr_cloned = addr.clone();
Scheduler::spawn(move || {
let mut remote_reader = BufReader::new(remote_stream);
let mut encrypt_stream = EncryptedWriter::new(client_writer, encryptor);
match ::relay::copy_once(&mut remote_reader, &mut encrypt_stream) {
Ok(0) => {}
Ok(n) => {
let remote_addr = encrypt_stream.get_ref().peer_addr().unwrap();
let client_addr = remote_reader.get_ref().peer_addr().unwrap();
trace!("{} local <- remote: relayed {} bytes from {} to {}",
addr,
n,
remote_addr,
client_addr);
loop {
match ::relay::copy_once(&mut remote_reader, &mut encrypt_stream) {
Ok(0) => {
trace!("{} local <- remote: EOF", addr);
break;
}
Ok(n) => {
trace!("{} local <- remote: relayed {} bytes from {} to {}",
addr,
n,
remote_addr,
client_addr)
}
Err(err) => {
error!("{} local <- remote: {}", addr, err);
break;
}
}
}
}
Err(err) => {
error!("{} local <- remote: {}", addr, err);
}
}
debug!("{} local <- remote is closing", addr);
let _ = encrypt_stream.get_mut().shutdown(Shutdown::Both);
let _ = remote_reader.get_mut().shutdown(Shutdown::Both);
trace!("Got connection, addr: {}", addr);
trace!("Picked proxy server: {:?}", server_cfg);
TcpRelayServer::handle_client(&handle, cpu_pool, socket, server_cfg, forbidden_ip)
})
.map_err(|err| {
error!("Server run failed: {}", err);
err
});
Scheduler::spawn(move || {
match ::relay::copy_once(&mut decrypt_stream, &mut remote_writer) {
Ok(0) => {}
Ok(n) => {
let remote_addr = remote_writer.peer_addr().unwrap();
let client_addr = decrypt_stream.get_ref().peer_addr().unwrap();
debug!("{} local -> remote: relayed {} bytes from {} to {}",
addr_cloned,
n,
remote_addr,
client_addr);
loop {
match ::relay::copy_once(&mut decrypt_stream, &mut remote_writer) {
Ok(0) => {
trace!("{} local -> remote: EOF", addr_cloned);
break;
}
Ok(n) => {
debug!("{} local -> remote: relayed {} bytes from {} to {}",
addr_cloned,
n,
remote_addr,
client_addr);
}
Err(err) => {
error!("{} local -> remote: {}", addr_cloned, err);
break;
}
}
}
}
Err(err) => {
error!("{} local -> remote: {}", addr_cloned, err);
}
}
debug!("{} local -> remote is closing", addr_cloned);
let _ = remote_writer.shutdown(Shutdown::Both);
let _ = decrypt_stream.get_mut().shutdown(Shutdown::Both);
});
});
}
}
}
impl TcpRelayServer {
pub fn run(&self) {
let mut futs = Vec::with_capacity(self.config.server.len());
let forbidden_ip = Arc::new(self.config.forbidden_ip.clone());
for s in &self.config.server {
let s = s.clone();
let forbidden_ip = forbidden_ip.clone();
let fut = Scheduler::spawn(move || {
TcpRelayServer::accept_loop(s, forbidden_ip);
});
futs.push(fut);
}
for fut in futs {
fut.join().unwrap();
fut = Some(match fut.take() {
Some(fut) => Box::new(fut.join(listening).map(|_| ())) as Box<Future<Item = (), Error = io::Error>>,
None => Box::new(listening) as Box<Future<Item = (), Error = io::Error>>,
})
}
fut.expect("Must have at least one server")
}
}

View File

@@ -19,12 +19,17 @@
// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#![allow(dead_code)]
use std::io::{self, Read, BufRead, Write};
use std::cmp;
use crypto::cipher::{Cipher, CipherVariant};
use crypto::{Cipher, CipherVariant};
pub struct DecryptedReader<R: Read> {
/// Reader wrapper that will decrypt data automatically
pub struct DecryptedReader<R>
where R: Read + 'static
{
reader: R,
buffer: Vec<u8>,
cipher: CipherVariant,
@@ -34,7 +39,9 @@ pub struct DecryptedReader<R: Read> {
const BUFFER_SIZE: usize = 2048;
impl<R: Read> DecryptedReader<R> {
impl<R> DecryptedReader<R>
where R: Read + 'static
{
pub fn new(r: R, cipher: CipherVariant) -> DecryptedReader<R> {
DecryptedReader {
reader: r,
@@ -59,18 +66,20 @@ impl<R: Read> DecryptedReader<R> {
&mut self.reader
}
// /// Unwraps this `DecryptedReader`, returning the underlying reader.
// ///
// /// The internal buffer is flushed before returning the reader. Any leftover
// /// data in the read buffer is lost.
// pub fn into_inner(self) -> R {
// self.reader
// }
/// Unwraps this `DecryptedReader`, returning the underlying reader.
///
/// The internal buffer is flushed before returning the reader. Any leftover
/// data in the read buffer is lost.
pub fn into_inner(self) -> R {
self.reader
}
}
impl<R: Read> BufRead for DecryptedReader<R> {
impl<R> BufRead for DecryptedReader<R>
where R: Read + 'static
{
fn fill_buf<'b>(&'b mut self) -> io::Result<&'b [u8]> {
while self.pos == self.buffer.len() {
while self.pos >= self.buffer.len() {
if self.sent_final {
return Ok(&[]);
}
@@ -81,21 +90,17 @@ impl<R: Read> BufRead for DecryptedReader<R> {
Ok(0) => {
// EOF
try!(self.cipher
.finalize(&mut self.buffer)
.map_err(|err| io::Error::new(io::ErrorKind::Other,
err.desc)));
.finalize(&mut self.buffer));
self.sent_final = true;
},
}
Ok(l) => {
try!(self.cipher
.update(&incoming[..l], &mut self.buffer)
.map_err(|err| io::Error::new(io::ErrorKind::Other,
err.desc)));
},
.update(&incoming[..l], &mut self.buffer));
}
Err(err) => {
return Err(err);
}
};
}
self.pos = 0;
}
@@ -108,7 +113,9 @@ impl<R: Read> BufRead for DecryptedReader<R> {
}
}
impl<R: Read> Read for DecryptedReader<R> {
impl<R> Read for DecryptedReader<R>
where R: Read + 'static
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let nread = {
let mut available = try!(self.fill_buf());
@@ -119,13 +126,19 @@ impl<R: Read> Read for DecryptedReader<R> {
}
}
pub struct EncryptedWriter<W: Write> {
/// Writer wrapper that will encrypt data automatically
pub struct EncryptedWriter<W>
where W: Write + 'static
{
writer: W,
cipher: CipherVariant,
buffer: Vec<u8>,
}
impl<W: Write> EncryptedWriter<W> {
impl<W> EncryptedWriter<W>
where W: Write + 'static
{
/// Creates a new EncryptedWriter
pub fn new(w: W, cipher: CipherVariant) -> EncryptedWriter<W> {
EncryptedWriter {
writer: w,
@@ -134,21 +147,20 @@ impl<W: Write> EncryptedWriter<W> {
}
}
/// Finalize the cipher, which will writes the final block into buffer
pub fn finalize(&mut self) -> io::Result<()> {
self.buffer.clear();
match self.cipher.finalize(&mut self.buffer) {
Ok(..) => {
self.writer.write_all(&self.buffer[..])
self.writer
.write_all(&self.buffer[..])
.and_then(|_| self.writer.flush())
},
Err(err) => {
Err(io::Error::new(
io::ErrorKind::Other,
err.desc))
}
Err(err) => Err(From::from(err)),
}
}
/// Get reference to the inner writer
pub fn get_ref(&self) -> &W {
&self.writer
}
@@ -164,23 +176,19 @@ impl<W: Write> EncryptedWriter<W> {
}
}
impl<W: Write> Write for EncryptedWriter<W> {
impl<W> Write for EncryptedWriter<W>
where W: Write + 'static
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.buffer.clear();
match self.cipher.update(buf, &mut self.buffer) {
Ok(..) => {
match self.writer.write_all(&self.buffer[..]) {
Ok(..) => {
Ok(buf.len())
},
Ok(..) => Ok(buf.len()),
Err(err) => Err(err),
}
},
Err(err) => {
Err(io::Error::new(
io::ErrorKind::Other,
err.desc))
}
Err(err) => Err(From::from(err)),
}
}
@@ -189,7 +197,9 @@ impl<W: Write> Write for EncryptedWriter<W> {
}
}
impl<W: Write> Drop for EncryptedWriter<W> {
impl<W> Drop for EncryptedWriter<W>
where W: Write + 'static
{
fn drop(&mut self) {
let _ = self.finalize();
}