From 5bf6560352798d02cc1f3fb591828c23e2485e69 Mon Sep 17 00:00:00 2001 From: "Y. T. Chung" Date: Fri, 21 Oct 2016 19:25:03 +0800 Subject: [PATCH] Migrating to futures and tokio --- Cargo.toml | 11 +- src/bin/local.rs | 104 +- src/bin/server.rs | 73 +- src/config.rs | 281 ++--- src/crypto/cipher.rs | 250 +++-- src/crypto/mod.rs | 4 +- src/crypto/openssl.rs | 38 +- src/lib.rs | 9 +- src/relay/loadbalancing/server/mod.rs | 6 +- src/relay/loadbalancing/server/roundrobin.rs | 24 +- src/relay/local.rs | 107 +- src/relay/mod.rs | 56 - src/relay/server.rs | 74 +- src/relay/socks5.rs | 420 +++++--- src/relay/tcprelay/http.rs | 136 ++- src/relay/tcprelay/local.rs | 1010 +++++++----------- src/relay/tcprelay/mod.rs | 315 +++--- src/relay/tcprelay/server.rs | 461 ++++---- src/relay/tcprelay/stream.rs | 90 +- 19 files changed, 1553 insertions(+), 1916 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a46a4d6e..8201016c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,8 +16,6 @@ default = [ "cipher-chacha20", "cipher-salsa20", - - "enable-udp", ] cipher-aes-cfb = [] @@ -58,13 +56,12 @@ qrcode = "^0.1.6" env_logger = "^0.3.2" rust-crypto = "^0.2.34" ip = "1.0.0" -openssl = "^0.7.1" +openssl = "^0.8" lru-cache = "0.0.7" libc = "^0.2.7" hyper = "0.9" url = "^1.2" httparse = "^1.1" - -[dependencies.coio] -git = "https://github.com/zonyitoo/coio-rs.git" -#branch = "io-timeouts" +futures = "0.1" +futures-cpupool = "0.1" +tokio-core = "0.1" diff --git a/src/bin/local.rs b/src/bin/local.rs index df1181ac..57ca3d9d 100644 --- a/src/bin/local.rs +++ b/src/bin/local.rs @@ -30,27 +30,21 @@ extern crate clap; extern crate shadowsocks; #[macro_use] extern crate log; -extern crate time; -extern crate coio; extern crate env_logger; -extern crate ip; +extern crate time; use clap::{App, Arg}; -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::SocketAddr; use std::env; -use std::time::Duration; - -use coio::Scheduler; +use std::sync::Arc; use env_logger::LogBuilder; use log::{LogRecord, LogLevelFilter}; -use ip::IpAddr; - use shadowsocks::config::{self, Config, ServerConfig}; use shadowsocks::config::DEFAULT_DNS_CACHE_CAPACITY; -use shadowsocks::relay::{RelayLocal, Relay}; +use shadowsocks::relay::RelayLocal; fn main() { let matches = App::new("shadowsocks") @@ -75,21 +69,11 @@ fn main() { .long("server-addr") .takes_value(true) .help("Server address")) - .arg(Arg::with_name("SERVER_PORT") - .short("p") - .long("server-port") - .takes_value(true) - .help("Server port")) .arg(Arg::with_name("LOCAL_ADDR") .short("b") .long("local-addr") .takes_value(true) .help("Local address, listen only to this address if specified")) - .arg(Arg::with_name("LOCAL_PORT") - .short("l") - .long("local-port") - .takes_value(true) - .help("Local port")) .arg(Arg::with_name("PASSWORD") .short("k") .long("password") @@ -203,26 +187,21 @@ fn main() { let mut has_provided_server_config = false; - if matches.value_of("SERVER_ADDR").is_some() && matches.value_of("SERVER_PORT").is_some() && - matches.value_of("PASSWORD").is_some() && matches.value_of("ENCRYPT_METHOD").is_some() { - let (svr_addr, svr_port, password, method) = matches.value_of("SERVER_ADDR") + if matches.value_of("SERVER_ADDR").is_some() && matches.value_of("PASSWORD").is_some() && + matches.value_of("ENCRYPT_METHOD").is_some() { + let (svr_addr, password, method) = matches.value_of("SERVER_ADDR") .and_then(|svr_addr| { - matches.value_of("SERVER_PORT") - .map(|svr_port| (svr_addr, svr_port)) - }) - .and_then(|(svr_addr, svr_port)| { matches.value_of("PASSWORD") - .map(|pwd| (svr_addr, svr_port, pwd)) + .map(|pwd| (svr_addr, pwd)) }) - .and_then(|(svr_addr, svr_port, pwd)| { + .and_then(|(svr_addr, pwd)| { matches.value_of("ENCRYPT_METHOD") - .map(|m| (svr_addr, svr_port, pwd, m)) + .map(|m| (svr_addr, pwd, m)) }) .unwrap(); let sc = ServerConfig { - addr: svr_addr.to_owned(), - port: svr_port.parse().ok().expect("`port` should be an integer"), + addr: svr_addr.parse::().expect("Invalid server addr"), password: password.to_owned(), method: match method.parse() { Ok(m) => m, @@ -234,34 +213,26 @@ fn main() { dns_cache_capacity: DEFAULT_DNS_CACHE_CAPACITY, }; - config.server.push(sc); + config.server.push(Arc::new(sc)); has_provided_server_config = true; - } else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("SERVER_PORT").is_none() && - matches.value_of("PASSWORD").is_none() && matches.value_of("ENCRYPT_METHOD").is_none() { + } else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("PASSWORD").is_none() && + matches.value_of("ENCRYPT_METHOD").is_none() { // Does not provide server config } else { - panic!("`server-addr`, `server-port`, `method` and `password` should be provided together"); + panic!("`server-addr`, `method` and `password` should be provided together"); } let mut has_provided_local_config = false; - if matches.value_of("LOCAL_ADDR").is_some() && matches.value_of("LOCAL_PORT").is_some() { - let (local_addr, local_port) = matches.value_of("LOCAL_ADDR") - .and_then(|local_addr| { - matches.value_of("LOCAL_PORT") - .map(|p| (local_addr, p)) - }) + if matches.value_of("LOCAL_ADDR").is_some() { + let local_addr = matches.value_of("LOCAL_ADDR") .unwrap(); - let local_addr: IpAddr = local_addr.parse() + let local_addr: SocketAddr = local_addr.parse() .ok() .expect("`local-addr` is not a valid IP address"); - let local_port: u16 = local_port.parse().ok().expect("`local-port` is not a valid integer"); - config.local = Some(match local_addr { - IpAddr::V4(v4) => SocketAddr::V4(SocketAddrV4::new(v4, local_port)), - IpAddr::V6(v6) => SocketAddr::V6(SocketAddrV6::new(v6, local_port, 0, 0)), - }); + config.local = Some(Arc::new(local_addr)); has_provided_local_config = true; } @@ -277,40 +248,5 @@ fn main() { debug!("Config: {:?}", config); - let threads = matches.value_of("THREADS") - .unwrap_or("1") - .parse::() - .ok() - .expect("`threads` should be an integer"); - - let stack_size = if matches.occurrences_of("VERBOSE") >= 1 { - 1 * 1024 * 1024 // 1M stack for formatting! - } else { - 128 * 1024 - }; - - trace!("Coroutine stack size: {}, thread count {}", - stack_size, - threads); - - Scheduler::new() - .with_workers(threads) - .default_stack_size(stack_size) - .run(move || { - if debug_level > 0 && cfg!(debug_assertions) { - // Statistic coroutine - Scheduler::spawn(|| { - loop { - debug!("STAT Coroutines: {}, TCP work: {}, HTTP work: {}", - Scheduler::instance().unwrap().work_count(), - RelayLocal::global_tcp_work_count(), - RelayLocal::global_http_work_count()); - coio::sleep(Duration::from_secs(1)); - } - }); - } - - RelayLocal::new(config).run(); - }) - .unwrap(); + RelayLocal::new(Arc::new(config)).run().unwrap(); } diff --git a/src/bin/server.rs b/src/bin/server.rs index 73098b19..c51b6aeb 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -32,24 +32,21 @@ extern crate clap; extern crate shadowsocks; #[macro_use] extern crate log; -extern crate time; -extern crate coio; extern crate env_logger; +extern crate time; use std::env; -use std::time::Duration; +use std::sync::Arc; +use std::net::SocketAddr; use clap::{App, Arg}; -use coio::Scheduler; - use env_logger::LogBuilder; use log::{LogRecord, LogLevelFilter}; -use time::PreciseTime; use shadowsocks::config::{self, Config, ServerConfig}; use shadowsocks::config::DEFAULT_DNS_CACHE_CAPACITY; -use shadowsocks::relay::{RelayServer, Relay}; +use shadowsocks::relay::RelayServer; fn main() { let matches = App::new("shadowsocks") @@ -74,11 +71,6 @@ fn main() { .long("server-addr") .takes_value(true) .help("Server address")) - .arg(Arg::with_name("SERVER_PORT") - .short("p") - .long("server-port") - .takes_value(true) - .help("Server port")) .arg(Arg::with_name("LOCAL_ADDR") .short("b") .long("local-addr") @@ -201,26 +193,21 @@ fn main() { let mut has_provided_server_config = false; - if matches.value_of("SERVER_ADDR").is_some() && matches.value_of("SERVER_PORT").is_some() && - matches.value_of("PASSWORD").is_some() && matches.value_of("ENCRYPT_METHOD").is_some() { - let (svr_addr, svr_port, password, method) = matches.value_of("SERVER_ADDR") + if matches.value_of("SERVER_ADDR").is_some() && matches.value_of("PASSWORD").is_some() && + matches.value_of("ENCRYPT_METHOD").is_some() { + let (svr_addr, password, method) = matches.value_of("SERVER_ADDR") .and_then(|svr_addr| { - matches.value_of("SERVER_PORT") - .map(|svr_port| (svr_addr, svr_port)) - }) - .and_then(|(svr_addr, svr_port)| { matches.value_of("PASSWORD") - .map(|pwd| (svr_addr, svr_port, pwd)) + .map(|pwd| (svr_addr, pwd)) }) - .and_then(|(svr_addr, svr_port, pwd)| { + .and_then(|(svr_addr, pwd)| { matches.value_of("ENCRYPT_METHOD") - .map(|m| (svr_addr, svr_port, pwd, m)) + .map(|m| (svr_addr, pwd, m)) }) .unwrap(); let sc = ServerConfig { - addr: svr_addr.to_owned(), - port: svr_port.parse().ok().expect("`port` should be an integer"), + addr: svr_addr.parse::().expect("`server-addr` invalid"), password: password.to_owned(), method: match method.parse() { Ok(m) => m, @@ -232,13 +219,13 @@ fn main() { dns_cache_capacity: DEFAULT_DNS_CACHE_CAPACITY, }; - config.server.push(sc); + config.server.push(Arc::new(sc)); has_provided_server_config = true; - } else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("SERVER_PORT").is_none() && - matches.value_of("PASSWORD").is_none() && matches.value_of("ENCRYPT_METHOD").is_none() { + } else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("PASSWORD").is_none() && + matches.value_of("ENCRYPT_METHOD").is_none() { // Does not provide server config } else { - panic!("`server-addr`, `server-port`, `method` and `password` should be provided together"); + panic!("`server-addr`, `method` and `password` should be provided together"); } if !has_provided_config && !has_provided_server_config { @@ -264,33 +251,5 @@ fn main() { .ok() .expect("`threads` should be an integer"); - let stack_size = if matches.occurrences_of("VERBOSE") >= 1 { - 1 * 1024 * 1024 // 1M stack for formatting! - } else { - 128 * 1024 - }; - - trace!("Coroutine stack size: {}, thread count {}", - stack_size, - threads); - - let show_run_time = matches.occurrences_of("VERBOSE") >= 2; - Scheduler::new() - .with_workers(threads) - .default_stack_size(stack_size) - .run(move || { - if show_run_time { - let start_time = PreciseTime::now(); - Scheduler::spawn(move || { - loop { - info!("SYSTEM System has already run {}", - start_time.to(PreciseTime::now())); - coio::sleep(Duration::from_secs(5)); - } - }); - } - - RelayServer::new(config).run(); - }) - .unwrap(); + RelayServer::new(Arc::new(config)).run(threads).unwrap(); } diff --git a/src/config.rs b/src/config.rs index 81601a9f..560daf32 100644 --- a/src/config.rs +++ b/src/config.rs @@ -76,6 +76,7 @@ use std::fmt::{self, Debug, Formatter}; use std::path::Path; use std::collections::HashSet; use std::time::Duration; +use std::sync::Arc; use ip::IpAddr; @@ -87,8 +88,7 @@ pub const DEFAULT_DNS_CACHE_CAPACITY: usize = 65536; /// Configuration for a server #[derive(Clone, Debug)] pub struct ServerConfig { - pub addr: String, - pub port: u16, + pub addr: SocketAddr, pub password: String, pub method: CipherType, pub timeout: Option, @@ -96,10 +96,9 @@ pub struct ServerConfig { } impl ServerConfig { - pub fn basic(addr: String, port: u16, password: String, method: CipherType) -> ServerConfig { + pub fn basic(addr: SocketAddr, password: String, method: CipherType) -> ServerConfig { ServerConfig { addr: addr, - port: port, password: password, method: method, timeout: None, @@ -112,8 +111,18 @@ impl json::ToJson for ServerConfig { fn to_json(&self) -> json::Json { use serialize::json::Json; let mut obj = json::Object::new(); - obj.insert("address".to_owned(), Json::String(self.addr.clone())); - obj.insert("port".to_owned(), Json::U64(self.port as u64)); + + match self.addr { + SocketAddr::V4(ref v4) => { + obj.insert("address".to_owned(), Json::String(v4.ip().to_string())); + obj.insert("port".to_owned(), Json::U64(v4.port() as u64)); + } + SocketAddr::V6(ref v6) => { + obj.insert("address".to_owned(), Json::String(v6.ip().to_string())); + obj.insert("port".to_owned(), Json::U64(v6.port() as u64)); + } + } + obj.insert("password".to_owned(), Json::String(self.password.clone())); obj.insert("method".to_owned(), Json::String(self.method.to_string())); if let Some(t) = self.timeout { @@ -138,12 +147,12 @@ pub enum ConfigType { /// Configuration #[derive(Clone, Debug)] pub struct Config { - pub server: Vec, - pub local: Option, - pub http_proxy: Option, + pub server: Vec>, + pub local: Option>, + pub http_proxy: Option>, pub enable_udp: bool, pub timeout: Option, - pub forbidden_ip: HashSet, + pub forbidden_ip: Arc>, } impl Default for Config { @@ -195,10 +204,110 @@ impl Config { http_proxy: None, enable_udp: false, timeout: None, - forbidden_ip: HashSet::new(), + forbidden_ip: Arc::new(HashSet::new()), } } + fn parse_server(server: &json::Object) -> Result { + let method = server.get("method") + .ok_or_else(|| Error::new(ErrorKind::MissingField, "need to specify a method", None)) + .and_then(|method_o| { + method_o.as_string() + .ok_or_else(|| Error::new(ErrorKind::Malformed, "`method` should be a string", None)) + }) + .and_then(|method_str| { + method_str.parse::() + .map_err(|_| { + Error::new(ErrorKind::Invalid, + "not supported method", + Some(format!("`{}` is not a supported method", method_str))) + }) + }); + + let method = try!(method); + + let port = server.get("port") + .or_else(|| server.get("server_port")) + .ok_or_else(|| { + Error::new(ErrorKind::MissingField, + "need to specify a server port", + None) + }) + .and_then(|port_o| { + port_o.as_u64() + .map(|u| u as u16) + .ok_or_else(|| Error::new(ErrorKind::Malformed, "`port` should be an integer", None)) + }); + + let port = try!(port); + + let addr = server.get("address") + .or_else(|| server.get("server")) + .ok_or_else(|| { + Error::new(ErrorKind::MissingField, + "need to specify a server address", + None) + }) + .and_then(|addr_o| { + addr_o.as_string() + .ok_or_else(|| Error::new(ErrorKind::Malformed, "`address` should be a string", None)) + }) + .and_then(|addr_str| { + addr_str.parse::() + .map(|v4| SocketAddr::V4(SocketAddrV4::new(v4, port))) + .or_else(|_| { + addr_str.parse::() + .map(|v6| SocketAddr::V6(SocketAddrV6::new(v6, port, 0, 0))) + }) + .map_err(|_| Error::new(ErrorKind::Malformed, "invalid server addr", None)) + }); + + let mut addr = try!(addr); + + // Merge address and port + match addr { + SocketAddr::V4(ref mut v4) => v4.set_port(port), + SocketAddr::V6(ref mut v6) => v6.set_port(port), + } + + let password = server.get("password") + .ok_or_else(|| Error::new(ErrorKind::MissingField, "need to specify a password", None)) + .and_then(|pwd_o| { + pwd_o.as_string() + .ok_or_else(|| Error::new(ErrorKind::Malformed, "`password` should be a string", None)) + .map(|s| s.to_string()) + }); + + let password = try!(password); + + let timeout = match server.get("timeout") { + Some(t) => { + let val = try!(t.as_u64() + .ok_or(Error::new(ErrorKind::Malformed, "`timeout` should be an integer", None))); + Some(Duration::from_secs(val)) + } + None => None, + }; + + let dns_cache_capacity = match server.get("dns_cache_capacity") { + Some(t) => { + try!(t.as_u64() + .ok_or(Error::new(ErrorKind::Malformed, + "`dns_cache_capacity` should be an integer", + None))) as usize + } + None => DEFAULT_DNS_CACHE_CAPACITY, + }; + + Ok(ServerConfig { + addr: addr, + password: password, + method: method, + timeout: timeout, + dns_cache_capacity: dns_cache_capacity, + }) + } + fn parse_json_object(o: &json::Object, require_local_info: bool) -> Result { let mut config = Config::new(); @@ -218,119 +327,17 @@ impl Config { .ok_or(Error::new(ErrorKind::Malformed, "`servers` should be a list", None))); for server in server_list.iter() { - let method_o = try!(server.find("method") - .ok_or(Error::new(ErrorKind::MissingField, "need to specify a method", None))); - let method_str = try!(method_o.as_string() - .ok_or(Error::new(ErrorKind::Malformed, "`method` should be a string", None))); - let method = try!(method_str.parse::().map_err(|_| { - Error::new(ErrorKind::Invalid, - "not supported method", - Some(format!("`{}` is not a supported method", method_str))) - })); - - let addr_o = try!(server.find("address") - .ok_or(Error::new(ErrorKind::MissingField, - "need to specify a server address", - None))); - let addr_str = try!(addr_o.as_string() - .ok_or(Error::new(ErrorKind::Malformed, "`address` should be a string", None))); - - let cfg = ServerConfig { - addr: addr_str.to_string(), - port: try!(try!(server.find("port") - .ok_or(Error::new(ErrorKind::MissingField, - "need to specify a server port", - None))) - .as_u64() - .ok_or(Error::new(ErrorKind::Malformed, - "`port` should be an \ - integer", - None))) as u16, - password: try!(try!(server.find("password") - .ok_or(Error::new(ErrorKind::MissingField, "need to specify a password", None))) - .as_string() - .ok_or(Error::new(ErrorKind::Malformed, "`password` should be a string", None))) - .to_string(), - method: method, - timeout: match server.find("timeout") { - Some(t) => { - let val = try!(t.as_u64() - .ok_or(Error::new(ErrorKind::Malformed, "`timeout` should be an integer", None))); - Some(Duration::from_secs(val)) - } - None => None, - }, - dns_cache_capacity: match server.find("dns_cache_capacity") { - Some(t) => { - try!(t.as_u64() - .ok_or(Error::new(ErrorKind::Malformed, - "`dns_cache_capacity` should be an integer", - None))) as usize - } - None => DEFAULT_DNS_CACHE_CAPACITY, - }, - }; - - config.server.push(cfg); + if let Some(server) = server.as_object() { + let cfg = try!(Config::parse_server(server)); + config.server.push(Arc::new(cfg)); + } } } else if o.contains_key("server") && o.contains_key("server_port") && o.contains_key("password") && o.contains_key("method") { // Traditional configuration file - let method_o = try!(o.get("method") - .ok_or(Error::new(ErrorKind::MissingField, "need to specify method", None))); - let method_str = try!(method_o.as_string() - .ok_or(Error::new(ErrorKind::Malformed, "`method` should be a string", None))); - let method = try!(method_str.parse::() - .map_err(|_| { - Error::new(ErrorKind::Invalid, - "not supported method", - Some(format!("`{}` is not a supported method", method_str))) - })); - let addr_o = try!(o.get("server") - .ok_or(Error::new(ErrorKind::MissingField, - "need to specify server address", - None))); - let addr_str = try!(addr_o.as_string() - .ok_or(Error::new(ErrorKind::Malformed, "`server` should be a string", None))); - - let single_server = ServerConfig { - addr: addr_str.to_string(), - port: try!(try!(o.get("server_port") - .ok_or(Error::new(ErrorKind::MissingField, - "need to specify a server port", - None))) - .as_u64() - .ok_or(Error::new(ErrorKind::Malformed, - "`port` should be an \ - integer", - None))) as u16, - password: try!(try!(o.get("password") - .ok_or(Error::new(ErrorKind::MissingField, "need to specify a password", None))) - .as_string() - .ok_or(Error::new(ErrorKind::Malformed, "`password` should be a string", None))) - .to_string(), - method: method, - timeout: match o.get("timeout") { - Some(t) => { - let val = try!(t.as_u64() - .ok_or(Error::new(ErrorKind::Malformed, "`timeout` should be an integer", None))); - Some(Duration::from_secs(val)) - } - None => None, - }, - dns_cache_capacity: match o.get("dns_cache_capacity") { - Some(t) => { - try!(t.as_u64() - .ok_or(Error::new(ErrorKind::Malformed, - "`dns_cache_capacity` should be an integer", - None))) as usize - } - None => DEFAULT_DNS_CACHE_CAPACITY, - }, - }; - - config.server = vec![single_server]; + let single_server = try!(Config::parse_server(o)); + config.server = vec![Arc::new(single_server)]; } if require_local_info { @@ -353,10 +360,10 @@ impl Config { None))) as u16; match addr_str.parse::() { - Ok(ip) => Some(SocketAddr::V4(SocketAddrV4::new(ip, port))), + Ok(ip) => Some(Arc::new(SocketAddr::V4(SocketAddrV4::new(ip, port)))), Err(..) => { match addr_str.parse::() { - Ok(ip) => Some(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))), + Ok(ip) => Some(Arc::new(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)))), Err(..) => { return Err(Error::new(ErrorKind::Malformed, "`local_address` is not a valid IP \ @@ -392,10 +399,10 @@ impl Config { None))) as u16; match addr_str.parse::() { - Ok(ip) => Some(SocketAddr::V4(SocketAddrV4::new(ip, port))), + Ok(ip) => Some(Arc::new(SocketAddr::V4(SocketAddrV4::new(ip, port)))), Err(..) => { match addr_str.parse::() { - Ok(ip) => Some(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0))), + Ok(ip) => Some(Arc::new(SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)))), Err(..) => { return Err(Error::new(ErrorKind::Malformed, "`local_http_address` is not a valid IP \ @@ -418,7 +425,9 @@ impl Config { .ok_or(Error::new(ErrorKind::Malformed, "`forbidden_ip` should be a list", None))); - config.forbidden_ip.extend(forbidden_ip_arr.into_iter().filter_map(|x| { + let mut forbidden_ip = HashSet::new(); + + forbidden_ip.extend(forbidden_ip_arr.into_iter().filter_map(|x| { let x = match x.as_string() { Some(x) => x, None => { @@ -436,6 +445,8 @@ impl Config { } } })); + + config.forbidden_ip = Arc::new(forbidden_ip); } Ok(config) @@ -512,10 +523,20 @@ impl json::ToJson for Config { let mut obj = json::Object::new(); if self.server.len() == 1 { // Official format - obj.insert("server".to_owned(), - Json::String(self.server[0].addr.clone())); - obj.insert("server_port".to_owned(), - Json::U64(self.server[0].port as u64)); + + let server = &self.server[0]; + + match server.addr { + SocketAddr::V4(ref v4) => { + obj.insert("server".to_owned(), Json::String(v4.ip().to_string())); + obj.insert("server_port".to_owned(), Json::U64(v4.port() as u64)); + } + SocketAddr::V6(ref v6) => { + obj.insert("server".to_owned(), Json::String(v6.ip().to_string())); + obj.insert("server_port".to_owned(), Json::U64(v6.port() as u64)); + } + } + obj.insert("password".to_owned(), Json::String(self.server[0].password.clone())); obj.insert("method".to_owned(), @@ -528,8 +549,8 @@ impl json::ToJson for Config { obj.insert("servers".to_owned(), Json::Array(arr)); } - if let Some(l) = self.local { - let ip_str = match &l { + if let Some(ref l) = self.local { + let ip_str = match &**l { &SocketAddr::V4(ref v4) => v4.ip().to_string(), &SocketAddr::V6(ref v6) => v6.ip().to_string(), }; diff --git a/src/crypto/cipher.rs b/src/crypto/cipher.rs index f959ad75..d2aca89d 100644 --- a/src/crypto/cipher.rs +++ b/src/crypto/cipher.rs @@ -22,7 +22,8 @@ //! Ciphers use std::str::FromStr; -use std::fmt::{Debug, Display, self}; +use std::fmt::{self, Debug, Display}; +use std::io; use rand::{self, Rng}; use std::convert::From; @@ -34,6 +35,8 @@ use crypto::crypto::CryptoCipher; use crypto::digest::{self, DigestType}; +use openssl::crypto::symm; + /// Basic operation of Cipher, which is a Symmetric Cipher. /// /// The `update` method could be called multiple times, and the `finalize` method will @@ -45,48 +48,48 @@ pub trait Cipher { pub type CipherResult = Result; -#[derive(Copy, Clone)] -pub enum ErrorKind { +pub enum Error { UnknownCipherType, - OpenSSLError, -} - -pub struct Error { - pub kind: ErrorKind, - pub desc: &'static str, - pub detail: Option, -} - -impl Error { - pub fn new(kind: ErrorKind, desc: &'static str, detail: Option) -> Error { - Error { - kind: kind, - desc: desc, - detail: detail, - } - } + OpenSSLError(::openssl::error::ErrorStack), + IoError(io::Error), } impl Debug for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - try!(write!(f, "{}", self.desc)); - match self.detail { - Some(ref d) => write!(f, " ({})", d), - None => Ok(()) + match self { + &Error::UnknownCipherType => write!(f, "UnknownCipherType"), + &Error::OpenSSLError(ref err) => write!(f, "{:?}", err), + &Error::IoError(ref err) => write!(f, "{:?}", err), } } } impl Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - try!(write!(f, "{}", self.desc)); - match self.detail { - Some(ref d) => write!(f, " ({})", d), - None => Ok(()) + match self { + &Error::UnknownCipherType => write!(f, "UnknownCipherType"), + &Error::OpenSSLError(ref err) => write!(f, "{}", err), + &Error::IoError(ref err) => write!(f, "{}", err), } } } +impl From for io::Error { + fn from(e: Error) -> io::Error { + match e { + Error::UnknownCipherType => io::Error::new(io::ErrorKind::Other, "Unknown Cipher type"), + Error::OpenSSLError(err) => From::from(err), + Error::IoError(err) => err, + } + } +} + +impl From<::openssl::error::ErrorStack> for Error { + fn from(e: ::openssl::error::ErrorStack) -> Error { + Error::OpenSSLError(e) + } +} + #[cfg(feature = "cipher-aes-cfb")] const CIPHER_AES_128_CFB: &'static str = "aes-128-cfb"; #[cfg(feature = "cipher-aes-cfb")] @@ -121,21 +124,33 @@ const CIPHER_SALSA20: &'static str = "salsa20"; pub enum CipherType { Table, - #[cfg(feature = "cipher-aes-cfb")] Aes128Cfb, - #[cfg(feature = "cipher-aes-cfb")] Aes128Cfb1, - #[cfg(feature = "cipher-aes-cfb")] Aes128Cfb8, - #[cfg(feature = "cipher-aes-cfb")] Aes128Cfb128, + #[cfg(feature = "cipher-aes-cfb")] + Aes128Cfb, + #[cfg(feature = "cipher-aes-cfb")] + Aes128Cfb1, + #[cfg(feature = "cipher-aes-cfb")] + Aes128Cfb8, + #[cfg(feature = "cipher-aes-cfb")] + Aes128Cfb128, - #[cfg(feature = "cipher-aes-cfb")] Aes256Cfb, - #[cfg(feature = "cipher-aes-cfb")] Aes256Cfb1, - #[cfg(feature = "cipher-aes-cfb")] Aes256Cfb8, - #[cfg(feature = "cipher-aes-cfb")] Aes256Cfb128, + #[cfg(feature = "cipher-aes-cfb")] + Aes256Cfb, + #[cfg(feature = "cipher-aes-cfb")] + Aes256Cfb1, + #[cfg(feature = "cipher-aes-cfb")] + Aes256Cfb8, + #[cfg(feature = "cipher-aes-cfb")] + Aes256Cfb128, - #[cfg(feature = "cipher-rc4")] Rc4, - #[cfg(feature = "cipher-rc4")] Rc4Md5, + #[cfg(feature = "cipher-rc4")] + Rc4, + #[cfg(feature = "cipher-rc4")] + Rc4Md5, - #[cfg(feature = "cipher-chacha20")] ChaCha20, - #[cfg(feature = "cipher-salsa20")] Salsa20, + #[cfg(feature = "cipher-chacha20")] + ChaCha20, + #[cfg(feature = "cipher-salsa20")] + Salsa20, } impl CipherType { @@ -143,18 +158,25 @@ impl CipherType { match *self { CipherType::Table => 0, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb1 => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb8 => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb128 => 16, + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb => symm::Type::AES_128_CFB128.block_size(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb1 => symm::Type::AES_128_CFB1.block_size(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb8 => symm::Type::AES_128_CFB8.block_size(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb128 => symm::Type::AES_128_CFB128.block_size(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb => symm::Type::AES_256_CFB128.block_size(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb1 => symm::Type::AES_256_CFB1.block_size(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb8 => symm::Type::AES_256_CFB8.block_size(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb128 => symm::Type::AES_256_CFB128.block_size(), - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb1 => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb8 => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb128 => 16, - - #[cfg(feature = "cipher-rc4")] CipherType::Rc4 => 0, - #[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => 16, + #[cfg(feature = "cipher-rc4")] CipherType::Rc4 => symm::Type::RC4_128.block_size(), + #[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => symm::Type::RC4_128.block_size(), #[cfg(feature = "cipher-chacha20")] CipherType::ChaCha20 => 8, #[cfg(feature = "cipher-salsa20")] CipherType::Salsa20 => 8, @@ -165,18 +187,25 @@ impl CipherType { match *self { CipherType::Table => 0, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb1 => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb8 => 16, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes128Cfb128 => 16, + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb => symm::Type::AES_128_CFB128.key_len(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb1 => symm::Type::AES_128_CFB1.key_len(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb8 => symm::Type::AES_128_CFB8.key_len(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb128 => symm::Type::AES_128_CFB128.key_len(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb => symm::Type::AES_256_CFB128.key_len(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb1 => symm::Type::AES_256_CFB1.key_len(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb8 => symm::Type::AES_256_CFB8.key_len(), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb128 => symm::Type::AES_256_CFB128.key_len(), - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb => 32, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb1 => 32, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb8 => 32, - #[cfg(feature = "cipher-aes-cfb")] CipherType::Aes256Cfb128 => 32, - - #[cfg(feature = "cipher-rc4")] CipherType::Rc4 => 16, - #[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => 16, + #[cfg(feature = "cipher-rc4")] CipherType::Rc4 => symm::Type::RC4_128.key_len(), + #[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => symm::Type::RC4_128.key_len(), #[cfg(feature = "cipher-chacha20")] CipherType::ChaCha20 => 32, #[cfg(feature = "cipher-salsa20")] CipherType::Salsa20 => 32, @@ -203,15 +232,49 @@ impl CipherType { i += 1 } - let whole = m.iter().fold(Vec::new(), |mut a, b| { a.extend_from_slice(b); a }); + let whole = m.iter().fold(Vec::new(), |mut a, b| { + a.extend_from_slice(b); + a + }); let key = whole[0..key_len].to_vec(); key } + pub fn iv_size(&self) -> usize { + match *self { + CipherType::Table => 0, + + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb => symm::Type::AES_128_CFB128.iv_len().unwrap_or(0), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb1 => symm::Type::AES_128_CFB1.iv_len().unwrap_or(0), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb8 => symm::Type::AES_128_CFB8.iv_len().unwrap_or(0), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes128Cfb128 => symm::Type::AES_128_CFB128.iv_len().unwrap_or(0), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb => symm::Type::AES_256_CFB128.iv_len().unwrap_or(0), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb1 => symm::Type::AES_256_CFB1.iv_len().unwrap_or(0), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb8 => symm::Type::AES_256_CFB8.iv_len().unwrap_or(0), + #[cfg(feature = "cipher-aes-cfb")] + CipherType::Aes256Cfb128 => symm::Type::AES_256_CFB128.iv_len().unwrap_or(0), + + #[cfg(feature = "cipher-rc4")] CipherType::Rc4 => symm::Type::RC4_128.iv_len().unwrap_or(0), + #[cfg(feature = "cipher-rc4")] CipherType::Rc4Md5 => symm::Type::RC4_128.iv_len().unwrap_or(0), + + #[cfg(feature = "cipher-chacha20")] CipherType::ChaCha20 => 32, + #[cfg(feature = "cipher-salsa20")] CipherType::Salsa20 => 32, + } + } + pub fn gen_init_vec(&self) -> Vec { - let iv_len = self.block_size(); + let iv_len = self.iv_size(); let mut iv = Vec::with_capacity(iv_len); - unsafe { iv.set_len(iv_len); } + unsafe { + iv.set_len(iv_len); + } rand::thread_rng().fill_bytes(iv.as_mut_slice()); iv @@ -224,46 +287,34 @@ impl FromStr for CipherType { match s { CIPHER_TABLE | "" => Ok(CipherType::Table), #[cfg(feature = "cipher-aes-cfb")] - CIPHER_AES_128_CFB => - Ok(CipherType::Aes128Cfb), + CIPHER_AES_128_CFB => Ok(CipherType::Aes128Cfb), #[cfg(feature = "cipher-aes-cfb")] - CIPHER_AES_128_CFB_1 => - Ok(CipherType::Aes128Cfb1), + CIPHER_AES_128_CFB_1 => Ok(CipherType::Aes128Cfb1), #[cfg(feature = "cipher-aes-cfb")] - CIPHER_AES_128_CFB_8 => - Ok(CipherType::Aes128Cfb8), + CIPHER_AES_128_CFB_8 => Ok(CipherType::Aes128Cfb8), #[cfg(feature = "cipher-aes-cfb")] - CIPHER_AES_128_CFB_128 => - Ok(CipherType::Aes128Cfb128), + CIPHER_AES_128_CFB_128 => Ok(CipherType::Aes128Cfb128), #[cfg(feature = "cipher-aes-cfb")] - CIPHER_AES_256_CFB => - Ok(CipherType::Aes256Cfb), + CIPHER_AES_256_CFB => Ok(CipherType::Aes256Cfb), #[cfg(feature = "cipher-aes-cfb")] - CIPHER_AES_256_CFB_1 => - Ok(CipherType::Aes256Cfb1), + CIPHER_AES_256_CFB_1 => Ok(CipherType::Aes256Cfb1), #[cfg(feature = "cipher-aes-cfb")] - CIPHER_AES_256_CFB_8 => - Ok(CipherType::Aes256Cfb8), + CIPHER_AES_256_CFB_8 => Ok(CipherType::Aes256Cfb8), #[cfg(feature = "cipher-aes-cfb")] - CIPHER_AES_256_CFB_128 => - Ok(CipherType::Aes256Cfb128), + CIPHER_AES_256_CFB_128 => Ok(CipherType::Aes256Cfb128), #[cfg(feature = "cipher-rc4")] - CIPHER_RC4 => - Ok(CipherType::Rc4), + CIPHER_RC4 => Ok(CipherType::Rc4), #[cfg(feature = "cipher-rc4")] - CIPHER_RC4_MD5 => - Ok(CipherType::Rc4Md5), + CIPHER_RC4_MD5 => Ok(CipherType::Rc4Md5), #[cfg(feature = "cipher-chacha20")] - CIPHER_CHACHA20 => - Ok(CipherType::ChaCha20), + CIPHER_CHACHA20 => Ok(CipherType::ChaCha20), #[cfg(feature = "cipher-salsa20")] - CIPHER_SALSA20 => - Ok(CipherType::Salsa20), + CIPHER_SALSA20 => Ok(CipherType::Salsa20), - _ => Err(Error::new(ErrorKind::UnknownCipherType, "Unknown cipher type", None)) + _ => Err(Error::UnknownCipherType), } } } @@ -360,15 +411,12 @@ pub fn with_type(t: CipherType, key: &[u8], iv: &[u8], mode: CryptoMode) -> Ciph CipherType::Table => CipherVariant::new(table::TableCipher::new(key, mode)), #[cfg(feature = "cipher-chacha20")] - CipherType::ChaCha20 => - CipherVariant::new(CryptoCipher::new(t, key, iv)), + CipherType::ChaCha20 => CipherVariant::new(CryptoCipher::new(t, key, iv)), #[cfg(feature = "cipher-salsa20")] - CipherType::Salsa20 => - CipherVariant::new(CryptoCipher::new(t, key, iv)), + CipherType::Salsa20 => CipherVariant::new(CryptoCipher::new(t, key, iv)), #[cfg(feature = "cipher-rc4")] - CipherType::Rc4Md5 => - CipherVariant::new(rc4_md5::Rc4Md5Cipher::new(key, iv, mode)), + CipherType::Rc4Md5 => CipherVariant::new(rc4_md5::Rc4Md5Cipher::new(key, iv, mode)), _ => CipherVariant::new(openssl::OpenSSLCipher::new(t, key, iv, mode)), } @@ -383,8 +431,14 @@ mod test_cipher { fn test_get_cipher() { let key = CipherType::Aes128Cfb.bytes_to_key(b"PassWORD"); let iv = CipherType::Aes128Cfb.gen_init_vec(); - let mut encryptor = with_type(CipherType::Aes128Cfb, &key[0..], &iv[0..], CryptoMode::Encrypt); - let mut decryptor = with_type(CipherType::Aes128Cfb, &key[0..], &iv[0..], CryptoMode::Decrypt); + let mut encryptor = with_type(CipherType::Aes128Cfb, + &key[0..], + &iv[0..], + CryptoMode::Encrypt); + let mut decryptor = with_type(CipherType::Aes128Cfb, + &key[0..], + &iv[0..], + CryptoMode::Decrypt); let message = "HELLO WORLD"; let mut encrypted_msg = Vec::new(); diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 2e5cd17c..1f501e3d 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -26,6 +26,8 @@ use std::convert::From; use openssl::crypto::symm; +pub use self::cipher::{CipherType, Cipher, CipherVariant}; + pub mod cipher; pub mod openssl; pub mod digest; @@ -36,7 +38,7 @@ pub mod crypto; #[derive(Clone, Copy)] pub enum CryptoMode { Encrypt, - Decrypt + Decrypt, } impl From for symm::Mode { diff --git a/src/crypto/openssl.rs b/src/crypto/openssl.rs index 3034de3e..3ce4188f 100644 --- a/src/crypto/openssl.rs +++ b/src/crypto/openssl.rs @@ -35,6 +35,7 @@ use openssl::crypto::symm; use openssl::crypto::hash; pub struct OpenSSLCrypto { + cipher: symm::Type, inner: symm::Crypter, } @@ -56,26 +57,38 @@ impl OpenSSLCrypto { #[cfg(feature = "cipher-rc4")] CipherType::Rc4 => symm::Type::RC4_128, - _ => panic!("Cipher type {:?} does not supported by OpenSSLCrypt yet", cipher_type), + _ => { + panic!("Cipher type {:?} does not supported by OpenSSLCrypt yet", + cipher_type) + } }; - let cipher = symm::Crypter::new(t); - cipher.init(From::from(mode), key, iv); + let key = cipher_type.bytes_to_key(key); + + // Panic if error occurs + let cipher = symm::Crypter::new(t, From::from(mode), &key[..], Some(iv)).unwrap(); OpenSSLCrypto { + cipher: t, inner: cipher, } } pub fn update(&mut self, data: &[u8], out: &mut Vec) -> CipherResult<()> { - let output = self.inner.update(data); - out.extend_from_slice(&output); + let orig_length = out.len(); + let least_reserved = data.len() + self.cipher.block_size(); + out.resize(orig_length + least_reserved, 0); + let length = try!(self.inner.update(data, &mut out[orig_length..])); + out.resize(orig_length + length, 0); Ok(()) } pub fn finalize(&mut self, out: &mut Vec) -> CipherResult<()> { - let output = self.inner.finalize(); - out.extend_from_slice(&output); + let orig_length = out.len(); + let least_reserved = self.cipher.block_size(); + out.resize(orig_length + least_reserved, 0); + let length = try!(self.inner.finalize(&mut out[orig_length..])); + out.resize(orig_length + length, 0); Ok(()) } } @@ -113,9 +126,7 @@ pub struct OpenSSLCipher { impl OpenSSLCipher { pub fn new(cipher_type: cipher::CipherType, key: &[u8], iv: &[u8], mode: CryptoMode) -> OpenSSLCipher { - OpenSSLCipher { - worker: OpenSSLCrypto::new(cipher_type, &key[..], &iv[..], mode), - } + OpenSSLCipher { worker: OpenSSLCrypto::new(cipher_type, &key[..], &iv[..], mode) } } } @@ -143,9 +154,7 @@ impl OpenSSLDigest { digest::DigestType::Sha1 => hash::Type::SHA1, }; - OpenSSLDigest { - inner: hash::Hasher::new(t), - } + OpenSSLDigest { inner: hash::Hasher::new(t).unwrap() } } } @@ -157,6 +166,7 @@ impl Digest for OpenSSLDigest { } fn digest(&mut self) -> Vec { - self.inner.finish() + // TODO: Check error + self.inner.finish().unwrap() } } diff --git a/src/lib.rs b/src/lib.rs index 0ef584b0..d2d6b36a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,8 +22,6 @@ #![crate_type = "lib"] #![crate_name = "shadowsocks"] -#![feature(lookup_host)] - extern crate rustc_serialize as serialize; #[macro_use] extern crate log; @@ -32,8 +30,6 @@ extern crate lru_cache; extern crate byteorder; extern crate rand; -extern crate coio; - extern crate crypto as rust_crypto; extern crate ip; extern crate openssl; @@ -41,6 +37,11 @@ extern crate hyper; extern crate url; extern crate httparse; +extern crate futures; +extern crate futures_cpupool; +#[macro_use] +extern crate tokio_core; + extern crate libc; pub const VERSION: &'static str = env!("CARGO_PKG_VERSION"); diff --git a/src/relay/loadbalancing/server/mod.rs b/src/relay/loadbalancing/server/mod.rs index c103e139..dbf9c0ce 100644 --- a/src/relay/loadbalancing/server/mod.rs +++ b/src/relay/loadbalancing/server/mod.rs @@ -19,13 +19,15 @@ // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +use std::sync::Arc; + pub use self::roundrobin::RoundRobin; use config::ServerConfig; pub mod roundrobin; -pub trait LoadBalancer { - fn pick_server<'a>(&'a mut self) -> &'a ServerConfig; +pub trait LoadBalancer: Send + 'static { + fn pick_server(&mut self) -> Arc; fn total(&self) -> usize; } diff --git a/src/relay/loadbalancing/server/roundrobin.rs b/src/relay/loadbalancing/server/roundrobin.rs index 9ed46a51..31fdd984 100644 --- a/src/relay/loadbalancing/server/roundrobin.rs +++ b/src/relay/loadbalancing/server/roundrobin.rs @@ -19,36 +19,40 @@ // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +use std::sync::Arc; + use relay::loadbalancing::server::LoadBalancer; -use config::ServerConfig; +use config::{Config, ServerConfig}; #[derive(Clone)] pub struct RoundRobin { - server: Vec, + config: Arc, index: usize, } impl RoundRobin { - pub fn new(config: Vec) -> RoundRobin { + pub fn new(config: Arc) -> RoundRobin { RoundRobin { - server: config, + config: config, index: 0usize, } } } impl LoadBalancer for RoundRobin { - fn pick_server<'a>(&'a mut self) -> &'a ServerConfig { - if self.server.is_empty() { + fn pick_server(&mut self) -> Arc { + let server = &self.config.server; + + if server.is_empty() { panic!("No server"); } - let ref s = self.server[self.index]; - self.index = (self.index + 1) % self.server.len(); - s + let ref s = server[self.index]; + self.index = (self.index + 1) % server.len(); + s.clone() } fn total(&self) -> usize { - self.server.len() + self.config.server.len() } } diff --git a/src/relay/local.rs b/src/relay/local.rs index c6fa05e4..e917efaf 100644 --- a/src/relay/local.rs +++ b/src/relay/local.rs @@ -21,9 +21,11 @@ //! Local side -use coio::Scheduler; +use std::sync::Arc; +use std::io; + +use tokio_core::reactor::Core; -use relay::Relay; use relay::tcprelay::local::TcpRelayLocal; #[cfg(feature = "enable-udp")] use relay::udprelay::local::UdpRelayLocal; @@ -55,103 +57,18 @@ use config::Config; /// ``` #[derive(Clone)] pub struct RelayLocal { - enable_udp: bool, - enable_http: bool, - tcprelay: TcpRelayLocal, - #[cfg(feature = "enable-udp")] - udprelay: UdpRelayLocal, + config: Arc, } impl RelayLocal { - #[cfg(feature = "enable-udp")] - pub fn new(config: Config) -> RelayLocal { - let tcprelay = TcpRelayLocal::new(config.clone()); - let udprelay = UdpRelayLocal::new(config.clone()); - RelayLocal { - tcprelay: tcprelay, - udprelay: udprelay, - enable_udp: config.enable_udp, - enable_http: config.http_proxy.is_some(), - } + pub fn new(config: Arc) -> RelayLocal { + RelayLocal { config: config } } - #[cfg(not(feature = "enable-udp"))] - pub fn new(config: Config) -> RelayLocal { - let tcprelay = TcpRelayLocal::new(config.clone()); - RelayLocal { - tcprelay: tcprelay, - enable_udp: config.enable_udp, - enable_http: config.http_proxy.is_some(), - } - } - - /// Global TCP work count - pub fn global_tcp_work_count() -> usize { - super::tcprelay::global_tcp_work_count() - } - - /// Global HTTP work count - pub fn global_http_work_count() -> usize { - super::tcprelay::global_http_work_count() - } -} - -impl Relay for RelayLocal { - #[cfg(not(feature = "enable-udp"))] - fn run(&self) { - if self.enable_udp { - warn!("UDP relay feature is disabled, recompile with feature=\"enable-udp\" to enable this feature"); - } - let mut futs = Vec::new(); - - let tcprelay = self.tcprelay.clone(); - let tcp_fut = Scheduler::spawn(move || { - info!("Enabled TCP relay"); - tcprelay.run_tcp() - }); - futs.push(tcp_fut); - - let tcprelay = self.tcprelay.clone(); - let tcp_fut = Scheduler::spawn(move || { - info!("Enabled HTTP relay"); - tcprelay.run_http() - }); - futs.push(tcp_fut); - - for fut in futs { - fut.join().unwrap(); - } - } - - fn run(&self) { - let mut futs = Vec::new(); - - let tcprelay = self.tcprelay.clone(); - let tcp_fut = Scheduler::spawn(move || { - info!("Enabled TCP relay"); - tcprelay.run_tcp() - }); - - futs.push(tcp_fut); - - if self.enable_udp { - let udprelay = self.udprelay.clone(); - let udp_fut = Scheduler::spawn(move || { - info!("Enabled UDP relay"); - udprelay.run() - }); - futs.push(udp_fut); - } - - let tcprelay = self.tcprelay.clone(); - let tcp_fut = Scheduler::spawn(move || { - info!("Enabled HTTP relay"); - tcprelay.run_http() - }); - futs.push(tcp_fut); - - for fut in futs { - fut.join().unwrap(); - } + pub fn run(self) -> io::Result<()> { + let mut lp = try!(Core::new()); + let handle = lp.handle(); + let tcp_fut = TcpRelayLocal::new(self.config.clone()).run(handle.clone()); + lp.run(tcp_fut) } } diff --git a/src/relay/mod.rs b/src/relay/mod.rs index de922f21..4cb680ed 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -21,15 +21,9 @@ //! Relay server in local and server side implementations. -use std::io::{self, Read, Write}; -use std::net::SocketAddr; -use std::mem; - pub use self::local::RelayLocal; pub use self::server::RelayServer; -use ip::IpAddr; - mod tcprelay; #[cfg(feature = "enable-udp")] mod udprelay; @@ -37,53 +31,3 @@ pub mod local; pub mod server; mod loadbalancing; pub mod socks5; - -pub trait Relay { - fn run(&self); -} - -fn copy_once(r: &mut R, w: &mut W) -> io::Result { - let mut buf: [u8; 4096] = unsafe { mem::uninitialized() }; - let len = match r.read(&mut buf) { - Ok(0) => return Ok(0), - Ok(len) => len, - Err(e) => return Err(e), - }; - w.write_all(&buf[..len]).and_then(|_| w.flush()).map(|_| len) -} - -fn copy_exact(r: &mut R, w: &mut W, len: usize) -> io::Result<()> { - let mut buf: [u8; 4096] = unsafe { mem::uninitialized() }; - let mut remain = len; - - while remain > 0 { - let bufl = if remain > buf.len() { - buf.len() - } else { - remain - }; - let len = match r.read(&mut buf[..bufl]) { - Ok(0) => break, - Ok(len) => { - remain -= len; - len - } - Err(e) => return Err(e), - }; - try!(w.write_all(&buf[..len]).and_then(|_| w.flush())); - } - - if remain != 0 { - let err = io::Error::new(io::ErrorKind::UnexpectedEof, "Unexpected Eof"); - Err(err) - } else { - Ok(()) - } -} - -fn take_ip_addr(sockaddr: &SocketAddr) -> IpAddr { - match sockaddr { - &SocketAddr::V4(ref v4) => IpAddr::V4(*v4.ip()), - &SocketAddr::V6(ref v6) => IpAddr::V6(*v6.ip()), - } -} diff --git a/src/relay/server.rs b/src/relay/server.rs index f067b5f0..2b17c6f4 100644 --- a/src/relay/server.rs +++ b/src/relay/server.rs @@ -21,12 +21,14 @@ //! Server side -use coio::Scheduler; +use std::sync::Arc; +use std::io; -#[cfg(feature = "enable-udp")] -use relay::udprelay::server::UdpRelayServer; +use tokio_core::reactor::Core; + +// #[cfg(feature = "enable-udp")] +// use relay::udprelay::server::UdpRelayServer; use relay::tcprelay::server::TcpRelayServer; -use relay::Relay; use config::Config; /// Relay server running on server side. @@ -55,66 +57,18 @@ use config::Config; /// #[derive(Clone)] pub struct RelayServer { - enable_udp: bool, - tcprelay: TcpRelayServer, - #[cfg(feature = "enable-udp")] - udprelay: UdpRelayServer, + config: Arc, } impl RelayServer { - #[cfg(feature = "enable-udp")] - pub fn new(config: Config) -> RelayServer { - let tcprelay = TcpRelayServer::new(config.clone()); - let udprelay = UdpRelayServer::new(config.clone()); - RelayServer { - tcprelay: tcprelay, - udprelay: udprelay, - enable_udp: config.enable_udp, - } + pub fn new(config: Arc) -> RelayServer { + RelayServer { config: config } } - #[cfg(not(feature = "enable-udp"))] - pub fn new(config: Config) -> RelayServer { - let tcprelay = TcpRelayServer::new(config.clone()); - RelayServer { - tcprelay: tcprelay, - enable_udp: config.enable_udp, - } - } -} - -impl Relay for RelayServer { - #[cfg(feature = "enable-udp")] - fn run(&self) { - let mut futs = Vec::new(); - - let tcprelay = self.tcprelay.clone(); - let tcp_fut = Scheduler::spawn(move || tcprelay.run()); - info!("Enabled TCP relay"); - futs.push(tcp_fut); - - if self.enable_udp { - let udprelay = self.udprelay.clone(); - let udp_fut = Scheduler::spawn(move || udprelay.run()); - info!("Enabled UDP relay"); - futs.push(udp_fut); - } - - for fut in futs { - fut.join().unwrap(); - } - } - - #[cfg(not(feature = "enable-udp"))] - fn run(&self) { - if self.enable_udp { - warn!("UDP relay feature is disabled, recompile with feature=\"enable-udp\" to enable this feature"); - } - - let tcprelay = self.tcprelay.clone(); - let fut = Scheduler::spawn(move || tcprelay.run()); - info!("Enabled TCP relay"); - - fut.join().unwrap(); + pub fn run(self, threads: usize) -> io::Result<()> { + let mut lp = try!(Core::new()); + let handle = lp.handle(); + let tcp_fut = TcpRelayServer::new(self.config.clone(), threads).run(handle.clone()); + lp.run(tcp_fut) } } diff --git a/src/relay/socks5.rs b/src/relay/socks5.rs index 4e5aea3c..1b14407f 100644 --- a/src/relay/socks5.rs +++ b/src/relay/socks5.rs @@ -23,13 +23,17 @@ use std::fmt::{self, Debug, Formatter}; use std::net::{Ipv4Addr, Ipv6Addr, ToSocketAddrs, SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::io::{self, Read, Write}; +use std::io::{self, Cursor, Read, Write}; use std::vec; use std::error; use std::convert::From; use byteorder::{ReadBytesExt, WriteBytesExt, BigEndian}; +use futures::{self, Future, BoxFuture}; + +use tokio_core::io::{read_exact, write_all}; + const SOCKS5_VERSION: u8 = 0x05; pub const SOCKS5_AUTH_METHOD_NONE: u8 = 0x00; @@ -179,6 +183,12 @@ impl From for Error { } } +impl From for io::Error { + fn from(err: Error) -> io::Error { + io::Error::new(io::ErrorKind::Other, err.message) + } +} + #[derive(Clone, PartialEq, Eq, Hash)] pub enum Address { SocketAddress(SocketAddr), @@ -186,16 +196,16 @@ pub enum Address { } impl Address { - #[inline] - pub fn read_from(reader: &mut R) -> Result { - match parse_request_header(reader) { - Ok((_, addr)) => Ok(addr), - Err(err) => Err(err), - } + pub fn read_from(stream: R) -> BoxFuture<(R, Address), Error> + where R: Read + Send + 'static + { + parse_request_header(stream) } #[inline] - pub fn write_to(&self, writer: &mut W) -> io::Result<()> { + pub fn write_to(self, writer: W) -> BoxFuture + where W: Write + Send + 'static + { write_addr(self, writer) } @@ -249,31 +259,46 @@ impl TcpRequestHeader { } } - #[inline] - pub fn read_from(stream: &mut R) -> Result { - let ver = try!(stream.read_u8()); - if ver != SOCKS5_VERSION { - return Err(Error::new(Reply::ConnectionRefused, "Unsupported Socks version")); - } + pub fn read_from(r: R) -> BoxFuture<(R, TcpRequestHeader), Error> + where R: Read + Send + 'static + { + read_exact(r, [0u8; 3]) + .map_err(From::from) + .and_then(|(r, buf)| { + let ver = buf[0]; + if ver != SOCKS5_VERSION { + return Err(Error::new(Reply::ConnectionRefused, "Unsupported Socks version")); + } - let cmd = try!(stream.read_u8()); - let _ = try!(stream.read_u8()); + let cmd = buf[1]; + let command = match Command::from_u8(cmd) { + Some(c) => c, + None => return Err(Error::new(Reply::CommandNotSupported, "Unsupported command")), + }; - Ok(TcpRequestHeader { - command: match Command::from_u8(cmd) { - Some(c) => c, - None => return Err(Error::new(Reply::CommandNotSupported, "Unsupported command")), - }, - address: try!(Address::read_from(stream)), - }) + Ok((r, command)) + }) + .and_then(|(r, command)| { + Address::read_from(r).map(move |(conn, address)| { + let header = TcpRequestHeader { + command: command, + address: address, + }; + + (conn, header) + }) + }) + .boxed() } - #[inline] - pub fn write_to(&self, stream: &mut W) -> io::Result<()> { - try!(stream.write_all(&[SOCKS5_VERSION, self.command.as_u8(), 0x00])); - try!(self.address.write_to(stream)); + pub fn write_to(&self, w: W) -> BoxFuture + where W: Write + Send + 'static + { + let addr = self.address.clone(); - Ok(()) + write_all(w, [SOCKS5_VERSION, self.command.as_u8(), 0x00]) + .and_then(move |(conn, _)| addr.write_to(conn)) + .boxed() } #[inline] @@ -296,28 +321,42 @@ impl TcpResponseHeader { } } - #[inline] - pub fn read_from(stream: &mut R) -> Result { - let ver = try!(stream.read_u8()); - if ver != SOCKS5_VERSION { - return Err(Error::new(Reply::ConnectionRefused, "Unsupported Socks version")); - } + pub fn read_from(r: R) -> BoxFuture<(R, TcpResponseHeader), Error> + where R: Read + Send + 'static + { + read_exact(r, [0u8; 3]) + .map_err(From::from) + .and_then(|(r, buf)| { + let ver = buf[0]; + let reply_code = buf[1]; - let reply_code = try!(stream.read_u8()); - let _ = try!(stream.read_u8()); + if ver != SOCKS5_VERSION { + return Err(Error::new(Reply::ConnectionRefused, "Unsupported Socks version")); + } - Ok(TcpResponseHeader { - reply: Reply::from_u8(reply_code), - address: try!(Address::read_from(stream)), - }) + Ok((r, reply_code)) + }) + .and_then(|(r, reply_code)| { + Address::read_from(r).map(move |(r, address)| { + let rep = TcpResponseHeader { + reply: Reply::from_u8(reply_code), + address: address, + }; + + (r, rep) + }) + }) + .boxed() } - #[inline] - pub fn write_to(&self, stream: &mut W) -> io::Result<()> { - try!(stream.write_all(&[SOCKS5_VERSION, self.reply.as_u8(), 0x00])); - try!(self.address.write_to(stream)); - - Ok(()) + pub fn write_to(&self, w: W) -> BoxFuture + where W: Write + Send + 'static + { + let addr = self.address.clone(); + write_all(w, [SOCKS5_VERSION, self.reply.as_u8(), 0x00]) + .map(From::from) + .and_then(move |(w, _)| addr.write_to(w)) + .boxed() } #[inline] @@ -326,82 +365,134 @@ impl TcpResponseHeader { } } -#[inline] -fn parse_request_header(stream: &mut R) -> Result<(usize, Address), Error> { - let atyp = match stream.read_u8() { - Ok(atyp) => atyp, - Err(_) => return Err(Error::new(Reply::GeneralFailure, "Error while reading address type")), - }; +fn parse_request_header(stream: R) -> BoxFuture<(R, Address), Error> + where R: Read + Send + 'static +{ + read_exact(stream, [0u8]) + .map_err(|_| Error::new(Reply::GeneralFailure, "Error while reading address type")) + .and_then(|(conn, atyp)| { + match atyp[0] { + SOCKS5_ADDR_TYPE_IPV4 => { + let v4addr = read_exact(conn, [0u8; 6]).map_err(From::from); + v4addr.and_then(|(conn, v4addr)| { + let mut stream = Cursor::new(v4addr); + let v4addr = Ipv4Addr::new(try!(stream.read_u8()), + try!(stream.read_u8()), + try!(stream.read_u8()), + try!(stream.read_u8())); + let port = try!(stream.read_u16::()); - match atyp { - SOCKS5_ADDR_TYPE_IPV4 => { - let v4addr = Ipv4Addr::new(try!(stream.read_u8()), - try!(stream.read_u8()), - try!(stream.read_u8()), - try!(stream.read_u8())); - let port = try!(stream.read_u16::()); - Ok((7usize, Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(v4addr, port))))) - } - SOCKS5_ADDR_TYPE_IPV6 => { - let v6addr = Ipv6Addr::new(try!(stream.read_u16::()), - try!(stream.read_u16::()), - try!(stream.read_u16::()), - try!(stream.read_u16::()), - try!(stream.read_u16::()), - try!(stream.read_u16::()), - try!(stream.read_u16::()), - try!(stream.read_u16::())); - let port = try!(stream.read_u16::()); - - Ok((19usize, Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new(v6addr, port, 0, 0))))) - } - SOCKS5_ADDR_TYPE_DOMAIN_NAME => { - let addr_len = try!(stream.read_u8()) as usize; - let mut raw_addr = Vec::with_capacity(addr_len); - try!(stream.take(addr_len as u64).read_to_end(&mut raw_addr)); - let port = try!(stream.read_u16::()); - - let addr = match String::from_utf8(raw_addr) { - Ok(addr) => addr, - Err(..) => return Err(Error::new(Reply::GeneralFailure, "Invalid address encoding")), - }; - - Ok((4 + addr_len, Address::DomainNameAddress(addr, port))) - } - _ => { - // Address type not supported - Err(Error::new(Reply::AddressTypeNotSupported, "Not supported address type")) - } - } -} - -#[inline] -fn write_addr(addr: &Address, buf: &mut W) -> io::Result<()> { - match addr { - &Address::SocketAddress(addr) => { - match addr { - SocketAddr::V4(addr) => { - try!(buf.write_all(&[SOCKS5_ADDR_TYPE_IPV4])); - try!(buf.write_all(&addr.ip().octets())); + Ok((conn, Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(v4addr, port))))) + }) + .boxed() } - SocketAddr::V6(addr) => { - try!(buf.write_u8(SOCKS5_ADDR_TYPE_IPV6)); - for seg in &addr.ip().segments() { - try!(buf.write_u16::(*seg)); - } + SOCKS5_ADDR_TYPE_IPV6 => { + let v6addr = read_exact(conn, [0u8; 18]).map_err(From::from); + v6addr.and_then(|(conn, v6addr)| { + let mut stream = Cursor::new(v6addr); + let v6addr = Ipv6Addr::new(try!(stream.read_u16::()), + try!(stream.read_u16::()), + try!(stream.read_u16::()), + try!(stream.read_u16::()), + try!(stream.read_u16::()), + try!(stream.read_u16::()), + try!(stream.read_u16::()), + try!(stream.read_u16::())); + let port = try!(stream.read_u16::()); + + Ok((conn, Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new(v6addr, port, 0, 0))))) + }) + .boxed() + } + SOCKS5_ADDR_TYPE_DOMAIN_NAME => { + let addr_len = read_exact(conn, [0u8]).map_err(From::from); + addr_len.and_then(|(conn, addr_len)| { + let addr_len = addr_len[0] as usize; + let raw_addr = read_exact(conn, vec![0u8; addr_len]).map_err(From::from); + raw_addr.and_then(|(conn, raw_addr)| { + let port = read_exact(conn, [0u8; 2]).map_err(From::from); + port.and_then(|(conn, port)| { + let mut stream = Cursor::new(port); + let port = try!(stream.read_u16::()); + + let addr = match String::from_utf8(raw_addr) { + Ok(addr) => addr, + Err(..) => { + return Err(Error::new(Reply::GeneralFailure, "Invalid address encoding")) + } + }; + + Ok((conn, Address::DomainNameAddress(addr, port))) + }) + }) + }) + .boxed() + } + _ => { + // Address type not supported + futures::failed(Error::new(Reply::AddressTypeNotSupported, "Not supported address type")).boxed() } } - try!(buf.write_u16::(addr.port())); + }) + .boxed() +} + +fn write_addr(addr: Address, w: W) -> BoxFuture + where W: Write + Send + 'static +{ + match addr { + Address::SocketAddress(addr) => { + let write_addr = match addr { + SocketAddr::V4(addr) => { + write_all(w, [SOCKS5_ADDR_TYPE_IPV4]) + .and_then(move |(w, _)| write_all(w, addr.ip().octets())) + .map(|(conn, _)| conn) + .boxed() + } + SocketAddr::V6(addr) => { + write_all(w, [SOCKS5_ADDR_TYPE_IPV6]) + .and_then(move |(w, _)| { + let mut rbuf = [0u8; 16]; + { + let mut buf = Cursor::new(&mut rbuf[..]); + for seg in &addr.ip().segments() { + try!(buf.write_u16::(*seg)); + } + } + Ok((w, rbuf)) + }) + .and_then(|(w, rbuf)| write_all(w, rbuf)) + .map(|(conn, _)| conn) + .boxed() + } + }; + + write_addr.and_then(move |w| { + let mut rbuf = [0u8; 2]; + { + let mut buf = Cursor::new(&mut rbuf[..]); + try!(buf.write_u16::(addr.port())); + } + Ok((w, rbuf)) + }) + .and_then(|(w, rbuf)| write_all(w, rbuf)) + .map(|(conn, _)| conn) + .boxed() } - &Address::DomainNameAddress(ref dnaddr, port) => { - try!(buf.write_u8(SOCKS5_ADDR_TYPE_DOMAIN_NAME)); - try!(buf.write_u8(dnaddr.len() as u8)); - try!(buf.write_all(dnaddr[..].as_bytes())); - try!(buf.write_u16::(port)); + Address::DomainNameAddress(dnaddr, port) => { + futures::lazy(move || { + let mut buf = Vec::with_capacity(dnaddr.len() + 4); + try!(buf.write_u8(SOCKS5_ADDR_TYPE_DOMAIN_NAME)); + try!(buf.write_u8(dnaddr.len() as u8)); + try!(buf.write_all(dnaddr[..].as_bytes())); + try!(buf.write_u16::(port)); + Ok(buf) + }) + .and_then(|buf| write_all(w, buf)) + .map(|(conn, _)| conn) + .boxed() } } - - Ok(()) } #[inline] @@ -432,25 +523,34 @@ impl HandshakeRequest { HandshakeRequest { methods: methods } } - pub fn read_from(stream: &mut R) -> io::Result { - let ver = try!(stream.read_u8()); - if ver != SOCKS5_VERSION { - return Err(io::Error::new(io::ErrorKind::Other, "Invalid Socks5 version")); - } + pub fn read_from(r: R) -> BoxFuture<(R, HandshakeRequest), io::Error> + where R: Read + Send + 'static + { + read_exact(r, [0u8, 0u8]) + .and_then(|(r, buf)| { + let ver = buf[0]; + let nmet = buf[1]; - let nmet = try!(stream.read_u8()); + if ver != SOCKS5_VERSION { + return Err(io::Error::new(io::ErrorKind::Other, "Invalid Socks5 version")); + } - let mut methods = Vec::new(); - try!(stream.take(nmet as u64).read_to_end(&mut methods)); - - Ok(HandshakeRequest { methods: methods }) + Ok((r, nmet)) + }) + .and_then(|(r, nmet)| { + read_exact(r, vec![0u8; nmet as usize]) + .and_then(|(r, methods)| Ok((r, HandshakeRequest { methods: methods }))) + }) + .boxed() } - pub fn write_to(&self, stream: &mut Write) -> io::Result<()> { - try!(stream.write_all(&[SOCKS5_VERSION, self.methods.len() as u8])); - try!(stream.write_all(&self.methods[..])); - - Ok(()) + pub fn write_to(self, w: W) -> BoxFuture + where W: Write + Send + 'static + { + write_all(w, [SOCKS5_VERSION, self.methods.len() as u8]) + .and_then(move |(w, _)| write_all(w, self.methods)) + .map(|(w, _)| w) + .boxed() } } @@ -469,19 +569,27 @@ impl HandshakeResponse { HandshakeResponse { chosen_method: cm } } - pub fn read_from(stream: &mut R) -> io::Result { - let ver = try!(stream.read_u8()); - if ver != SOCKS5_VERSION { - return Err(io::Error::new(io::ErrorKind::Other, "Invalid Socks5 version")); - } + pub fn read_from(r: R) -> BoxFuture<(R, HandshakeResponse), io::Error> + where R: Read + Send + 'static + { + read_exact(r, [0u8, 0u8]) + .and_then(|(r, buf)| { + let ver = buf[0]; + let met = buf[1]; - let met = try!(stream.read_u8()); - - Ok(HandshakeResponse { chosen_method: met }) + if ver != SOCKS5_VERSION { + Err(io::Error::new(io::ErrorKind::Other, "Invalid Socks5 version")) + } else { + Ok((r, HandshakeResponse { chosen_method: met })) + } + }) + .boxed() } - pub fn write_to(&self, stream: &mut Write) -> io::Result<()> { - stream.write_all(&[SOCKS5_VERSION, self.chosen_method]) + pub fn write_to(self, w: W) -> BoxFuture + where W: Write + Send + 'static + { + write_all(w, [SOCKS5_VERSION, self.chosen_method]).map(|(w, _)| w).boxed() } } @@ -499,19 +607,29 @@ impl UdpAssociateHeader { } } - pub fn read_from(reader: &mut R) -> Result { - let _ = try!(reader.read_u8()); - let _ = try!(reader.read_u8()); - let frag = try!(reader.read_u8()); - - Ok(UdpAssociateHeader::new(frag, try!(Address::read_from(reader)))) + pub fn read_from(r: R) -> BoxFuture<(R, UdpAssociateHeader), Error> + where R: Read + Send + 'static + { + read_exact(r, [0u8; 3]) + .map_err(From::from) + .and_then(|(r, buf)| { + let frag = buf[2]; + Address::read_from(r).map(move |(r, address)| { + let h = UdpAssociateHeader::new(frag, address); + (r, h) + }) + }) + .boxed() } - pub fn write_to(&self, writer: &mut W) -> io::Result<()> { - try!(writer.write_all(&[0x00, 0x00, self.frag])); - try!(self.address.write_to(writer)); - - Ok(()) + pub fn write_to(&self, w: W) -> BoxFuture + where W: Write + Send + 'static + { + let addr = self.address.clone(); + write_all(w, [0x00, 0x00, self.frag]) + .map_err(From::from) + .and_then(move |(w, _)| addr.write_to(w)) + .boxed() } pub fn len(&self) -> usize { diff --git a/src/relay/tcprelay/http.rs b/src/relay/tcprelay/http.rs index ff11719a..fd26978f 100644 --- a/src/relay/tcprelay/http.rs +++ b/src/relay/tcprelay/http.rs @@ -21,10 +21,10 @@ /// Http Proxy -use std::io::{self, Write}; +use std::io::{self, Read, Write}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::mem; -use hyper::server::response::Response; use hyper::uri::RequestUri; use hyper::header::Headers; use hyper::status::StatusCode; @@ -36,8 +36,13 @@ use httparse::{self, Request}; use url::Host; +use futures::{self, Future, BoxFuture, Poll}; + +use tokio_core::io::write_all; + use relay::socks5::Address; +#[derive(Debug)] pub struct HttpRequest { pub version: HttpVersion, pub method: Method, @@ -85,20 +90,27 @@ impl HttpRequest { } } - pub fn write_to(&self, w: &mut W) -> io::Result<()> { - try!(write!(w, - "{} {} {}\r\n", - self.method, - self.request_uri, - self.version)); + pub fn write_to(self, w: W) -> BoxFuture + where W: Write + Send + 'static + { + futures::lazy(move || { + let mut w = Vec::new(); + try!(write!(w, + "{} {} {}\r\n", + self.method, + self.request_uri, + self.version)); - for header in self.headers.iter() { - try!(write!(w, "{}: {}\r\n", header.name(), header.value_string())); - } + for header in self.headers.iter() { + try!(write!(w, "{}: {}\r\n", header.name(), header.value_string())); + } - try!(write!(w, "\r\n")); - - Ok(()) + try!(write!(w, "\r\n")); + Ok(w) + }) + .and_then(|buf| write_all(w, buf)) + .map(|(w, _)| w) + .boxed() } } @@ -152,11 +164,93 @@ pub fn get_address(uri: &RequestUri) -> Result { } } -pub fn write_response(stream: &mut Write, status: StatusCode) -> io::Result<()> { - let mut headers = Headers::new(); - let mut resp = Response::new(stream, &mut headers); - *resp.status_mut() = status; - try!(resp.start().and_then(|r| r.end())); - - Ok(()) +pub fn write_response(w: W, version: HttpVersion, status: StatusCode) -> BoxFuture + where W: Write + Send + 'static +{ + let buf = format!("{} {}\r\n\r\n", version, status); + write_all(w, buf.into_bytes()).map(|(w, _)| w).boxed() } + +/// HTTP Client +pub enum RequestReader + where R: Read +{ + Pending { r: R, buf: Vec }, + Empty, +} + +impl RequestReader + where R: Read +{ + pub fn new(r: R) -> RequestReader { + RequestReader::with_buf(r, Vec::new()) + } + + pub fn with_buf(r: R, buf: Vec) -> RequestReader { + RequestReader::Pending { r: r, buf: buf } + } +} + +impl Future for RequestReader + where R: Read +{ + type Item = (R, HttpRequest, Vec); + type Error = io::Error; + + fn poll(&mut self) -> Poll { + let mut lbuf = [0u8; 4096]; + let (req, len) = match self { + &mut RequestReader::Pending { ref mut r, ref mut buf } => { + let mut http_req = None; + let mut total_len = 0; + loop { + let n = try_nb!(r.read(&mut lbuf)); + buf.extend_from_slice(&lbuf[..n]); + + // Maximum 128 headers + let mut headers = [httparse::EMPTY_HEADER; 128]; + let headers_ptr = &headers as *const _; + let mut req = Request::new(&mut headers); + match req.parse(&mut buf[..]) { + Ok(httparse::Status::Partial) => { + if n == 0 { + // Already EOF! + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "Unexpected Eof"); + return Err(err); + } + } + Ok(httparse::Status::Complete(len)) => { + total_len = len; + + // Make borrow checker happy + let headers_ref = unsafe { &*headers_ptr }; + let hreq = match HttpRequest::from_raw(&req, headers_ref) { + Ok(r) => r, + Err(err) => { + error!("HttpRequest::from_raw: {}", err); + let err = io::Error::new(io::ErrorKind::Other, "Hyper error"); + return Err(err); + } + }; + http_req = Some(hreq); + break; + } + Err(err) => { + error!("Request parse: {:?}", err); + let err = io::Error::new(io::ErrorKind::Other, "Hyper error"); + return Err(err); + } + } + } + + (http_req.unwrap(), total_len) + } + &mut RequestReader::Empty => panic!("poll a RequestReader after it's done"), + }; + + match mem::replace(self, RequestReader::Empty) { + RequestReader::Pending { r, buf } => Ok((r, req, buf[len..].to_vec()).into()), + RequestReader::Empty => unreachable!(), + } + } +} \ No newline at end of file diff --git a/src/relay/tcprelay/local.rs b/src/relay/tcprelay/local.rs index f96c2f5d..cc167729 100644 --- a/src/relay/tcprelay/local.rs +++ b/src/relay/tcprelay/local.rs @@ -21,693 +21,399 @@ //! TcpRelay server that running on local environment -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; -use std::net::lookup_host; -use std::io::{self, BufWriter, BufReader, Read, Write}; -use std::collections::BTreeMap; +use std::io; +use std::net::SocketAddr; use std::sync::Arc; -use coio::Scheduler; -use coio::net::{TcpListener, TcpStream, Shutdown}; +use futures::{self, Future, BoxFuture}; +use futures::stream::Stream; +use tokio_core::net::{TcpStream, TcpListener}; +use tokio_core::reactor::Handle; +use tokio_core::io::Io; +use tokio_core::io::{ReadHalf, WriteHalf}; +use tokio_core::io::{flush, copy, write_all}; + +use hyper::header::ContentLength; use hyper::method::Method; -use hyper::header; -use httparse::{self, Request}; +use config::{Config, ServerConfig}; -use config::{Config, ClientConfig}; +use relay::socks5::{self, HandshakeRequest, HandshakeResponse, Address}; +use relay::socks5::{TcpRequestHeader, TcpResponseHeader}; +use relay::loadbalancing::server::RoundRobin; +use relay::loadbalancing::server::LoadBalancer; -use relay::socks5::{self, Address}; -use relay::loadbalancing::server::{LoadBalancer, RoundRobin}; +use super::http::{self, RequestReader}; -use super::http::HttpRequest; - -use crypto::cipher::CipherType; - -#[derive(Clone)] +/// TCP relay local server pub struct TcpRelayLocal { config: Arc, } impl TcpRelayLocal { - pub fn new(c: Config) -> TcpRelayLocal { - if c.server.is_empty() || c.local.is_none() { - panic!("You have to provide configuration for server and local"); - } - - TcpRelayLocal { config: Arc::new(c) } + pub fn new(config: Arc) -> TcpRelayLocal { + TcpRelayLocal { config: config } } - fn do_handshake(reader: &mut R, writer: &mut W) -> io::Result<()> { - // Read the handshake header - let req = try!(socks5::HandshakeRequest::read_from(reader)); - trace!("Got handshake {:?}", req); - - if !req.methods.contains(&socks5::SOCKS5_AUTH_METHOD_NONE) { - let resp = socks5::HandshakeResponse::new(socks5::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE); - try!(resp.write_to(writer)); - warn!("Currently shadowsocks-rust does not support authentication"); - return Err(io::Error::new(io::ErrorKind::Other, - "Currently shadowsocks-rust does not support \ - authentication")); - } - - // Reply to client - let resp = socks5::HandshakeResponse::new(socks5::SOCKS5_AUTH_METHOD_NONE); - trace!("Reply handshake {:?}", resp); - resp.write_to(writer) - } - - fn handle_udp_associate_local(stream: &mut W, - _addr: SocketAddr, - _dest_addr: &socks5::Address, - local_conf: ClientConfig) - -> io::Result<()> { - let reply = socks5::TcpResponseHeader::new(socks5::Reply::Succeeded, - socks5::Address::SocketAddress(local_conf)); - trace!("Replying Header for UDP ASSOCIATE, {:?}", reply); - try!(reply.write_to(stream)); - - // TODO: record this client's information for udprelay local server to validate - // whether the client has already authenticated - - Ok(()) - } - - fn handle_tcp_client(stream: TcpStream, - server_addr: SocketAddr, - password: Vec, - encrypt_method: CipherType, - conf: Arc) { - let sockname = match stream.peer_addr() { - Ok(sockname) => sockname, - Err(err) => { - error!("Failed to get peer addr: {}", err); - return; + pub fn run(self, handle: Handle) -> Box> { + let tcp_fut = Socks5RelayLocal::new(self.config.clone()).run(handle.clone()); + match &self.config.http_proxy { + &Some(..) => { + let http_fut = HttpRelayServer::new(self.config.clone()).run(handle); + Box::new(tcp_fut.join(http_fut) + .map(|_| ())) } - }; - - let stream_writer = match stream.try_clone() { - Ok(s) => s, - Err(err) => { - error!("Failed to clone local stream: {}", err); - return; - } - }; - let mut local_reader = BufReader::new(stream); - let mut local_writer = BufWriter::new(stream_writer); - - if let Err(err) = TcpRelayLocal::do_handshake(&mut local_reader, &mut local_writer) { - error!("Error occurs while doing handshake: {}", err); - return; - } - - if let Err(err) = local_writer.flush() { - error!("Error occurs while flushing local writer: {}", err); - return; - } - - let header = match socks5::TcpRequestHeader::read_from(&mut local_reader) { - Ok(h) => h, - Err(err) => { - let header = socks5::TcpResponseHeader::new(err.reply, socks5::Address::SocketAddress(sockname)); - error!("Failed to read request header: {}", err); - if let Err(err) = header.write_to(&mut local_writer) { - error!("Failed to write response header to local stream: {}", err); - } - return; - } - }; - - trace!("Got header {:?}", header); - - let addr = header.address; - - match header.command { - socks5::Command::TcpConnect => { - info!("CONNECT {}", addr); - - let (mut decrypt_stream, mut encrypt_stream) = - match super::connect_proxy_server(&server_addr, encrypt_method, &password[..], &addr) { - Ok(x) => x, - Err(err) => { - error!("Failed to connect to proxy server: {:?}", err); - return; - } - }; - - // Send header to client - { - let header = socks5::TcpResponseHeader::new(socks5::Reply::Succeeded, - socks5::Address::SocketAddress(sockname)); - trace!("Send header to client {:?}", header); - if let Err(err) = header.write_to(&mut local_writer) - .and_then(|_| local_writer.flush()) { - error!("Error occurs while writing header to local stream: {}", err); - return; - } - } - - let addr_cloned = addr.clone(); - - Scheduler::spawn(move || { - let _guard = super::TcpWorkCounter::new(); - - loop { - match ::relay::copy_once(&mut local_reader, &mut encrypt_stream) { - Ok(0) => { - trace!("{} local -> remote: EOF", addr_cloned); - break; - } - Ok(n) => { - trace!("{} local -> remote: relayed {} bytes", addr_cloned, n); - } - Err(err) => { - error!("SYSTEM Connect {} local -> remote: {}", addr_cloned, err); - break; - } - } - } - - debug!("SYSTEM Connect {} local -> remote is closing", addr_cloned); - - let _ = encrypt_stream.get_ref().shutdown(Shutdown::Both); - let _ = local_reader.get_ref().shutdown(Shutdown::Both); - }); - - Scheduler::spawn(move || { - let mut local_writer = match local_writer.into_inner() { - Ok(writer) => writer, - Err(err) => { - error!("Error occurs while taking out local writer: {}", err); - return; - } - }; - - loop { - match ::relay::copy_once(&mut decrypt_stream, &mut local_writer) { - Ok(0) => { - trace!("{} local <- remote: EOF", addr); - break; - } - Ok(n) => { - trace!("{} local <- remote: relayed {} bytes", addr, n); - } - Err(err) => { - error!("SYSTEM Connect {} local <- remote: {}", addr, err); - break; - } - } - } - - let _ = local_writer.flush(); - - debug!("SYSTEM Connect {} local <- remote is closing", addr); - - let _ = decrypt_stream.get_mut().shutdown(Shutdown::Both); - let _ = local_writer.shutdown(Shutdown::Both); - }); - } - socks5::Command::TcpBind => { - warn!("BIND is not supported"); - socks5::TcpResponseHeader::new(socks5::Reply::CommandNotSupported, addr) - .write_to(&mut local_writer) - .unwrap_or_else(|err| error!("Failed to write BIND response: {}", err)); - } - socks5::Command::UdpAssociate => { - info!("{} requests for UDP ASSOCIATE", sockname); - if cfg!(feature = "enable-udp") && conf.enable_udp { - TcpRelayLocal::handle_udp_associate_local(&mut local_writer, sockname, &addr, conf.local.unwrap()) - .unwrap_or_else(|err| error!("Failed to write UDP ASSOCIATE response: {}", err)); - } else { - warn!("UDP ASSOCIATE is disabled"); - socks5::TcpResponseHeader::new(socks5::Reply::CommandNotSupported, addr) - .write_to(&mut local_writer) - .unwrap_or_else(|err| error!("Failed to write UDP ASSOCIATE response: {}", err)); - } - } - } - } - - fn handle_http_connect(stream: TcpStream, - stream_writer: TcpStream, - addr: Address, - server_addr: SocketAddr, - password: Vec, - encrypt_method: CipherType, - remain: &[u8]) - -> io::Result<()> { - info!("CONNECT (HTTP) {}", addr); - - let mut local_reader = BufReader::new(stream); - let mut local_writer = stream_writer; - - const HANDSHAKE: &'static [u8] = b"HTTP/1.1 200 Connection Established\r\n\r\n"; - - if let Err(err) = local_writer.write_all(HANDSHAKE).and_then(|_| local_writer.flush()) { - error!("Failed to send handshake: {:?}", err); - return Err(err); - } - - trace!("HTTP Connect: Sent HTTP tunnel handshakes"); - - let (mut decrypt_stream, mut encrypt_stream) = - match super::connect_proxy_server(&server_addr, encrypt_method, &password[..], &addr) { - Ok(x) => x, - Err(err) => { - error!("Failed to connect to proxy server: {}", err); - return Err(err); - } - }; - - trace!("HTTP Connect: Connected remote server"); - - try!(encrypt_stream.write_all(remain).and_then(|_| encrypt_stream.flush())); - - let addr_cloned = addr.clone(); - - Scheduler::spawn(move || { - let _guard = super::HttpWorkCounter::new(); - - loop { - match ::relay::copy_once(&mut local_reader, &mut encrypt_stream) { - Ok(0) => { - trace!("HTTP Connect: {} local -> remote: EOF", addr_cloned); - break; - } - Ok(n) => { - trace!("HTTP Connect: {} local -> remote: relayed {} bytes", - addr_cloned, - n); - } - Err(err) => { - error!("SYSTEM HTTP Connect {} local -> remote: {}", - addr_cloned, - err); - break; - } - } - } - - debug!("SYSTEM HTTP Connect {} local -> remote is closing", - addr_cloned); - - let _ = encrypt_stream.get_ref().shutdown(Shutdown::Both); - let _ = local_reader.get_ref().shutdown(Shutdown::Both); - }); - - Scheduler::spawn(move || { - loop { - match ::relay::copy_once(&mut decrypt_stream, &mut local_writer) { - Ok(0) => { - trace!("HTTP Connect: {} local <- remote: EOF", addr); - break; - } - Ok(n) => { - trace!("HTTP Connect: {} local <- remote: relayed {} bytes", - addr, - n); - } - Err(err) => { - error!("SYSTEM HTTP Connect {} local <- remote: {}", addr, err); - break; - } - } - } - - let _ = local_writer.flush(); - - debug!("SYSTEM HTTP Connect {} local <- remote is closing", addr); - - let _ = decrypt_stream.get_mut().shutdown(Shutdown::Both); - let _ = local_writer.shutdown(Shutdown::Both); - }); - - Ok(()) - } - - fn handle_http_others(mut req: HttpRequest, - stream: TcpStream, - stream_writer: TcpStream, - addr: Address, - server_addr: SocketAddr, - password: Vec, - encrypt_method: CipherType, - remain: &[u8]) - -> io::Result<()> { - info!("{} (HTTP) {}", req.method, addr); - - let mut local_reader = BufReader::new(stream); - let mut local_writer = stream_writer; - - let (mut decrypt_stream, mut encrypt_stream) = - match super::connect_proxy_server(&server_addr, encrypt_method, &password[..], &addr) { - Ok(x) => x, - Err(err) => { - error!("Failed to connect to proxy server: {}", err); - return Err(err); - } - }; - - trace!("HTTP Proxy: Connected remote server"); - trace!("HTTP Proxy: {} Target url {}", req.method, req.request_uri); - - req.clear_request_uri_host(); - - try!(req.write_to(&mut encrypt_stream)); - try!(encrypt_stream.write_all(remain)); - - let addr_cloned = addr.clone(); - - let content_len = req.headers.get::().unwrap_or(&header::ContentLength(0)).0 as usize; - let mut remain_len = content_len.saturating_sub(remain.len()); - - Scheduler::spawn(move || { - let _guard = super::HttpWorkCounter::new(); - - let mut buf = [0u8; 1024]; - - let mut content_len = content_len; - - 'outer: loop { - // 1. Send body - match ::relay::copy_exact(&mut local_reader, &mut encrypt_stream, remain_len) { - Ok(..) => {} - Err(err) => { - error!("Failed to relay body: {:?}", err); - break; - } - } - - trace!("HTTP Proxy: Written body {} bytes", content_len); - - if let Err(err) = encrypt_stream.flush() { - error!("Failed to flush: {}", err); - return; - } - - // 2. Read another header - let mut req_buf = Vec::with_capacity(8192); - let mut headers = [httparse::EMPTY_HEADER; 100]; - - while let Ok(n) = local_reader.read(&mut buf) { - use httparse::Status; - - let is_eof = n == 0; - - if is_eof && req_buf.is_empty() { - break 'outer; - } - - req_buf.extend_from_slice(&buf[..n]); - let mut req = Request::new(&mut headers); - match req.parse(&req_buf[..]) { - Ok(Status::Complete(reqlen)) => { - let mut request = match HttpRequest::from_raw(&req, req.headers) { - Ok(r) => r, - Err(err) => { - error!("Failed to parse HttpRequest: {}", err); - break; - } - }; - - trace!("HTTP Proxy: {} Target url {}", - request.method, - request.request_uri); - - request.clear_request_uri_host(); - if let Err(err) = request.write_to(&mut encrypt_stream) { - error!("Failed to write HttpRequest: {}", err); - break; - } - - if let Err(err) = encrypt_stream.write_all(&req_buf[reqlen..]) { - error!("Failed to write to remote: {}", err); - break; - } - - content_len = request.headers - .get::() - .unwrap_or(&header::ContentLength(0)) - .0 as usize; - remain_len = content_len.saturating_sub(req_buf[reqlen..].len()); - - break; - } - _ => { - if is_eof { - error!("Unexpected Eof"); - break; - } - } - } - } - } - - debug!("SYSTEM Connect {} local -> remote is closing", addr_cloned); - - let _ = encrypt_stream.get_ref().shutdown(Shutdown::Both); - let _ = local_reader.get_ref().shutdown(Shutdown::Both); - }); - - Scheduler::spawn(move || { - loop { - match ::relay::copy_once(&mut decrypt_stream, &mut local_writer) { - Ok(0) => { - trace!("{} local <- remote: EOF", addr); - break; - } - Ok(n) => { - trace!("{} local <- remote: relayed {} bytes", addr, n); - } - Err(err) => { - error!("SYSTEM Connect {} local <- remote: {}", addr, err); - break; - } - } - } - - let _ = local_writer.flush(); - - debug!("SYSTEM Connect {} local <- remote is closing", addr); - - let _ = decrypt_stream.get_mut().shutdown(Shutdown::Both); - let _ = local_writer.shutdown(Shutdown::Both); - }); - - Ok(()) - } - - fn handle_http_client(mut stream: TcpStream, - server_addr: SocketAddr, - password: Vec, - encrypt_method: CipherType) { - use super::http::{get_address, write_response}; - - let mut stream_writer = match stream.try_clone() { - Ok(s) => s, - Err(err) => { - error!("Failed to clone stream: {:?}", err); - return; - } - }; - - let mut req_buf = Vec::with_capacity(8192); - let mut got_header = false; - - let mut headers = [httparse::EMPTY_HEADER; 100]; - - let mut buf = [0u8; 1024]; - while let Ok(n) = stream.read(&mut buf) { - use httparse::Status; - - if n == 0 && req_buf.is_empty() { - // EOF - got_header = true; - break; - } - - req_buf.extend_from_slice(&buf[..n]); - let mut req = Request::new(&mut headers); - match req.parse(&req_buf[..]) { - Ok(Status::Complete(reqlen)) => { - got_header = true; - - let request = match HttpRequest::from_raw(&req, req.headers) { - Ok(r) => r, - Err(err) => { - error!("Failed to create HttpRequest: {:?}", err); - return; - } - }; - - let addr = match get_address(&request.request_uri) { - Ok(addr) => addr, - Err(status) => { - let _ = write_response(&mut stream_writer, status); - return; - } - }; - - match request.method.clone() { - Method::Connect => { - let _ = TcpRelayLocal::handle_http_connect(stream, - stream_writer, - addr, - server_addr, - password, - encrypt_method, - &req_buf[reqlen..]); - } - _ => { - let _ = TcpRelayLocal::handle_http_others(request, - stream, - stream_writer, - addr, - server_addr, - password, - encrypt_method, - &req_buf[reqlen..]); - } - } - - break; - } - Ok(Status::Partial) => {} - Err(err) => { - error!("Failed to parse HTTP request: {:?}", err); - return; - } - } - } - - if !got_header { - error!("Failed to get full HTTP Request"); + &None => tcp_fut, } } } +/// Socks5 local server +pub struct Socks5RelayLocal { + config: Arc, +} +impl Socks5RelayLocal { + pub fn new(config: Arc) -> Socks5RelayLocal { + Socks5RelayLocal { config: config } + } -impl TcpRelayLocal { - fn run_server(&self, local_conf: SocketAddr, handler: F) - where F: Fn(TcpStream, SocketAddr, Vec, CipherType, Arc) - { - let mut server_load_balancer = RoundRobin::new(self.config.server.clone()); + fn handle_socks5_connect(handle: &Handle, + (r, w): (ReadHalf, WriteHalf), + client_addr: SocketAddr, + addr: Address, + svr_cfg: Arc) + -> BoxFuture<(), io::Error> { + let cloned_addr = addr.clone(); + super::connect_proxy_server(handle, svr_cfg, addr) + .and_then(move |(svr_r, svr_w)| { + let header = TcpResponseHeader::new(socks5::Reply::Succeeded, + Address::SocketAddress(client_addr)); + trace!("Send header: {:?}", header); - let acceptor = match TcpListener::bind(&local_conf) { - Ok(acpt) => acpt, - Err(e) => { - panic!("Error occurs while listening local address: {}", - e.to_string()); - } - }; + header.write_to(w) + .and_then(|w| flush(w)) + .and_then(|w| Ok((svr_r, svr_w, w))) + }) + .and_then(move |(svr_r, svr_w, w)| { + let c2s = copy(r, svr_w); + let s2c = copy(svr_r, w); + c2s.join(s2c) + .and_then(move |(c2s_amt, s2c_amt)| { + trace!("Relayed {} client -> remote {}bytes", cloned_addr, c2s_amt); + trace!("Relayed {} client <- remote {}bytes", cloned_addr, s2c_amt); + Ok(()) + }) + }) + .boxed() + } - info!("Shadowsocks listening on {}", local_conf); + fn handle_client(handle: &Handle, s: TcpStream, _: SocketAddr, conf: Arc) -> io::Result<()> { + let cloned_handle = handle.clone(); + let client_addr = try!(s.peer_addr()); + let cloned_client_addr = client_addr.clone(); + let fut = futures::lazy(|| Ok(s.split())) + .and_then(|(r, w)| { + // Socks5 handshakes + HandshakeRequest::read_from(r).and_then(move |(r, req)| { + trace!("Socks5 {:?}", req); - let mut cached_proxy: BTreeMap = BTreeMap::new(); - - for s in acceptor.incoming() { - let stream = match s { - Ok((s, addr)) => { - debug!("Got connection from client {:?}", addr); - s - } - Err(err) => { - panic!("Error occurs while accepting: {:?}", err); - } - }; - - if let Err(err) = stream.set_read_timeout(self.config.timeout) { - error!("Failed to set read timeout: {:?}", err); - continue; - } - - if let Err(err) = stream.set_nodelay(true) { - error!("Failed to set no delay: {:?}", err); - continue; - } - - let mut succeed = false; - for _ in 0..server_load_balancer.total() { - let ref server_cfg = server_load_balancer.pick_server(); - let addr = { - match cached_proxy.get(&server_cfg.addr[..]).map(|x| x.clone()) { - Some(addr) => addr, - None => { - match lookup_host(&server_cfg.addr[..]) { - Ok(mut addr_itr) => { - match addr_itr.next() { - None => { - error!("cannot resolve proxy server `{}`", server_cfg.addr); - continue; - } - Some(addr) => { - let addr = addr.clone(); - cached_proxy.insert(server_cfg.addr.clone(), addr.clone()); - addr - } - } - } - Err(err) => { - error!("cannot resolve proxy server `{}`: {}", server_cfg.addr, err); - continue; - } - } + if !req.methods.contains(&socks5::SOCKS5_AUTH_METHOD_NONE) { + let resp = HandshakeResponse::new(socks5::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE); + resp.write_to(w) + .then(|_| { + warn!("Currently shadowsocks-rust does not support authentication"); + Err(io::Error::new(io::ErrorKind::Other, + "Currently shadowsocks-rust does not support authentication")) + }) + .boxed() + } else { + // Reply to client + let resp = HandshakeResponse::new(socks5::SOCKS5_AUTH_METHOD_NONE); + trace!("Reply handshake {:?}", resp); + resp.write_to(w).and_then(|w| Ok((r, w))).boxed() + } + }) + }) + .and_then(move |(r, w)| { + // Fetch headers + TcpRequestHeader::read_from(r).then(move |res| { + match res { + Ok((r, h)) => futures::finished((r, w, h)).boxed(), + Err(err) => { + error!("Failed to get TcpRequestHeader: {}", err); + TcpResponseHeader::new(err.reply, Address::SocketAddress(client_addr)) + .write_to(w) + .then(|_| Err(From::from(err))) + .boxed() } } - }; + }) + }) + .and_then(move |(r, w, header)| { + trace!("Socks5 {:?}", header); - let server_addr = match addr { - SocketAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr.ip().clone(), server_cfg.port)), - SocketAddr::V6(addr) => { - SocketAddr::V6(SocketAddrV6::new(addr.ip().clone(), - server_cfg.port, - addr.flowinfo(), - addr.scope_id())) + let addr = header.address; + match header.command { + socks5::Command::TcpConnect => { + info!("CONNECT {}", addr); + Socks5RelayLocal::handle_socks5_connect(&cloned_handle, (r, w), cloned_client_addr, addr, conf) } - }; - - if self.config.forbidden_ip.contains(&::relay::take_ip_addr(&server_addr)) { - info!("{} is in `forbidden_ip` list, skipping", server_addr); - continue; + socks5::Command::TcpBind => { + warn!("BIND is not supported"); + TcpResponseHeader::new(socks5::Reply::CommandNotSupported, addr) + .write_to(w) + .map(|_| ()) + .boxed() + } + socks5::Command::UdpAssociate => unimplemented!(), } + }); - debug!("Using proxy `{}:{}` (`{}`)", - server_cfg.addr, - server_cfg.port, - server_addr); - let encrypt_method = server_cfg.method.clone(); - let pwd = encrypt_method.bytes_to_key(server_cfg.password.as_bytes()); - - let conf = self.config.clone(); - handler(stream, server_addr, pwd, encrypt_method, conf); - - succeed = true; - break; + // Runs in Tokio + handle.spawn(fut.then(|res| { + match res { + Ok(..) => Ok(()), + Err(err) => { + error!("Failed to handle client: {}", err); + Err(()) + } } + })); - if !succeed { - panic!("All proxy servers are failed!"); - } - } + Ok(()) } - pub fn run_tcp(&self) { - self.run_server(self.config.local.expect("Require local config"), - |stream, server_addr, pwd, encrypt_method, conf| { - Scheduler::spawn(move || { - TcpRelayLocal::handle_tcp_client(stream, server_addr, pwd, encrypt_method, conf); - }); - }); - } + // Runs TCP relay local server + pub fn run(self, handle: Handle) -> Box> { + let listener = { + let local_addr = self.config.local.as_ref().unwrap(); + let listener = TcpListener::bind(local_addr, &handle).unwrap(); + info!("ShadowSocks TCP Listening on {}", local_addr); + listener + }; - pub fn run_http(&self) { - self.run_server(self.config.http_proxy.expect("Require local config"), - |stream, server_addr, pwd, encrypt_method, _| { - Scheduler::spawn(move || { - TcpRelayLocal::handle_http_client(stream, server_addr, pwd, encrypt_method); - }); - }); + let mut servers = RoundRobin::new(self.config); + let listening = listener.incoming() + .for_each(move |(socket, addr)| { + let server_cfg = servers.pick_server(); + trace!("Got connection, addr: {}", addr); + trace!("Picked proxy server: {:?}", server_cfg); + Socks5RelayLocal::handle_client(&handle, socket, addr, server_cfg) + }); + + Box::new(listening.map_err(|err| { + error!("Socks5 server run failed: {}", err); + err + })) + } +} + +/// HTTP local server +pub struct HttpRelayServer { + config: Arc, +} + +impl HttpRelayServer { + pub fn new(config: Arc) -> HttpRelayServer { + HttpRelayServer { config: config } + } + + fn handle_connect(handle: Handle, + (r, w): (ReadHalf, WriteHalf), + req: http::HttpRequest, + addr: Address, + remains: Vec, + svr_cfg: Arc) + -> BoxFuture<(), io::Error> { + let cloned_addr = addr.clone(); + let http_version = req.version; + super::connect_proxy_server(&handle, svr_cfg, addr) + .and_then(move |(svr_r, svr_w)| { + let handshake_resp = format!("{} 200 Connection Established\r\n\r\n", http_version); + write_all(w, handshake_resp.into_bytes()).and_then(|(w, _)| flush(w)).map(|w| (svr_r, svr_w, w)) + }) + .and_then(move |(svr_r, svr_w, w)| { + req.write_to(svr_w) + .and_then(|svr_w| write_all(svr_w, remains)) + .map(move |(svr_w, _)| (svr_r, svr_w, w)) + }) + .and_then(move |(svr_r, svr_w, w)| { + let c2s = copy(r, svr_w); + let s2c = copy(svr_r, w); + c2s.join(s2c) + .and_then(move |(c2s_amt, s2c_amt)| { + trace!("Relayed {} client -> remote {}bytes", cloned_addr, c2s_amt); + trace!("Relayed {} client <- remote {}bytes", cloned_addr, s2c_amt); + Ok(()) + }) + }) + .boxed() + } + + fn handle_http_again((r, w): (ReadHalf, WriteHalf), + (svr_r, svr_w): (super::DecryptedHalf, super::EncryptedHalf), + remains: Vec) + -> BoxFuture<(), io::Error> { + RequestReader::with_buf(r, remains) + .and_then(move |(r, mut req, remains)| { + trace!("Got HTTP Request, version: {}, method: {}, uri: {}", + req.version, + req.method, + req.request_uri); + + match http::get_address(&req.request_uri) { + Ok(..) => { + req.clear_request_uri_host(); + let content_length = req.headers.get::().unwrap_or(&ContentLength(0)).0; + req.write_to(svr_w) + .and_then(move |svr_w| { + if content_length == 0 { + // Do nothing because this request does not have a body + futures::finished((r, svr_w, remains)).boxed() + } else { + write_all(svr_w, remains) + .and_then(move |(svr_w, remains)| { + let remain_len = content_length - remains.len() as u64; + super::copy_exact(r, svr_w, remain_len as usize) + }) + .map(|(r, svr_w)| (r, svr_w, vec![])) + .boxed() + } + }) + .map(move |(r, svr_w, remains)| { + HttpRelayServer::handle_http_again((r, w), (svr_r, svr_w), remains) + }) + .boxed() + } + Err(status_code) => { + http::write_response(w, req.version, status_code) + .then(|_| { + let err = io::Error::new(io::ErrorKind::Other, "Invalid Uri"); + Err(err) + }) + .boxed() + } + } + }) + .then(|res| { + if let Err(err) = res { + error!("HTTP again: {}", err); + } + Ok(()) + }) + .boxed() + } + + fn handle_http_proxy(handle: Handle, + (r, w): (ReadHalf, WriteHalf), + req: http::HttpRequest, + addr: Address, + remains: Vec, + svr_cfg: Arc) + -> BoxFuture<(), io::Error> { + let content_length = req.headers.get::().unwrap_or(&ContentLength(0)).0; + + super::connect_proxy_server(&handle, svr_cfg, addr) + .and_then(move |(svr_r, svr_w)| { + trace!("Going to pass req to server: {:?}", req); + req.write_to(svr_w) + .and_then(move |svr_w| { + trace!("Going to relay request body, len: {}", content_length); + if content_length == 0 { + // Do nothing because this request does not have a body + futures::finished((r, svr_w, remains)).boxed() + } else { + write_all(svr_w, remains) + .and_then(move |(svr_w, remains)| { + let remain_len = content_length - remains.len() as u64; + super::copy_exact(r, svr_w, remain_len as usize) + }) + .map(|(r, svr_w)| (r, svr_w, vec![])) + .boxed() + } + }) + .map(move |(r, svr_w, remains)| { + HttpRelayServer::handle_http_again((r, w), (svr_r, svr_w), remains) + }) + }) + .map(|_| ()) + .boxed() + } + + fn handle_client(handle: &Handle, socket: TcpStream, _: SocketAddr, svr_cfg: Arc) -> io::Result<()> { + let cloned_handle = handle.clone(); + let fut = futures::lazy(|| Ok(socket.split())) + .and_then(|(r, w)| { + RequestReader::new(r).and_then(move |(r, mut req, remains)| { + trace!("Got HTTP Request, version: {}, method: {}, uri: {}", + req.version, + req.method, + req.request_uri); + + match http::get_address(&req.request_uri) { + Ok(addr) => { + req.clear_request_uri_host(); + futures::finished((r, w, req, addr, remains)).boxed() + } + Err(status_code) => { + http::write_response(w, req.version, status_code) + .then(|_| { + let err = io::Error::new(io::ErrorKind::Other, "Invalid Uri"); + Err(err) + }) + .boxed() + } + } + }) + }) + .and_then(move |(r, w, req, addr, remains)| { + match req.method.clone() { + Method::Connect => { + info!("CONNECT (Http) {}", addr); + HttpRelayServer::handle_connect(cloned_handle, (r, w), req, addr, remains, svr_cfg) + } + met => { + info!("{} (Http) {}", met, addr); + HttpRelayServer::handle_http_proxy(cloned_handle, (r, w), req, addr, remains, svr_cfg) + } + } + }); + + handle.spawn(fut.then(|res| { + match res { + Ok(..) => Ok(()), + Err(err) => { + error!("Failed to handle client: {}", err); + Err(()) + } + } + })); + + Ok(()) + } + + pub fn run(self, handle: Handle) -> Box> { + let listener = { + let local_addr = self.config.http_proxy.as_ref().unwrap(); + let listener = TcpListener::bind(local_addr, &handle).unwrap(); + info!("ShadowSocks HTTP Listening on {}", local_addr); + listener + }; + + let mut servers = RoundRobin::new(self.config); + let listening = listener.incoming() + .for_each(move |(socket, addr)| { + let server_cfg = servers.pick_server(); + trace!("Got connection, addr: {}", addr); + trace!("Picked proxy server: {:?}", server_cfg); + HttpRelayServer::handle_client(&handle, socket, addr, server_cfg) + }); + + Box::new(listening.map_err(|err| { + error!("HTTP server run failed: {}", err); + err + })) } } diff --git a/src/relay/tcprelay/mod.rs b/src/relay/tcprelay/mod.rs index 8f329c00..010aea00 100644 --- a/src/relay/tcprelay/mod.rs +++ b/src/relay/tcprelay/mod.rs @@ -21,184 +21,193 @@ //! TcpRelay implementation -use std::net::SocketAddr; use std::io::{self, Read, Write}; +use std::sync::Arc; +use std::mem; -use crypto::cipher::{self, CipherType}; +use crypto::cipher; use crypto::CryptoMode; use relay::socks5::Address; +use config::ServerConfig; -use coio::net::TcpStream; +use tokio_core::net::TcpStream; +use tokio_core::reactor::Handle; +use tokio_core::io::{read_exact, write_all, flush}; +use tokio_core::io::{ReadHalf, WriteHalf}; +use tokio_core::io::Io; -use self::stream::{DecryptedReader, EncryptedWriter}; +use futures::{Future, BoxFuture, Poll}; -mod cached_dns; +use self::stream::{EncryptedWriter, DecryptedReader}; + +// use coio::net::TcpStream; + +// use self::stream::{DecryptedReader, EncryptedWriter}; + +// mod cached_dns; pub mod local; pub mod server; mod stream; mod http; -fn connect_proxy_server(server_addr: &SocketAddr, - encrypt_method: CipherType, - pwd: &[u8], - relay_addr: &Address) - -> io::Result<(DecryptedReader, EncryptedWriter)> { - let mut remote_stream = try!(TcpStream::connect(&server_addr)); +type DecryptedHalf = DecryptedReader>; +type EncryptedHalf = EncryptedWriter>; - // Encrypt data to remote server +fn connect_proxy_server(handle: &Handle, + svr_cfg: Arc, + relay_addr: Address) + -> BoxFuture<(DecryptedHalf, EncryptedHalf), io::Error> { + TcpStream::connect(&svr_cfg.addr, handle) + .and_then(move |remote_stream| { + let (r, w) = remote_stream.split(); - // Send initialize vector to remote and create encryptor - let mut encrypt_stream = { - let local_iv = encrypt_method.gen_init_vec(); - trace!("Going to send initialize vector: {:?}", local_iv); - let encryptor = cipher::with_type(encrypt_method, pwd, &local_iv[..], CryptoMode::Encrypt); - if let Err(err) = remote_stream.write_all(&local_iv[..]) { - error!("Error occurs while writing initialize vector: {}", err); - return Err(err); + // Encrypt data to remote server + + // Send initialize vector to remote and create encryptor + + let local_iv = svr_cfg.method.gen_init_vec(); + trace!("Going to send initialize vector: {:?}", local_iv); + + + write_all(w, local_iv) + .and_then(move |(w, local_iv)| { + let encryptor = cipher::with_type(svr_cfg.method, + svr_cfg.password.as_bytes(), + &local_iv[..], + CryptoMode::Encrypt); + + Ok((svr_cfg, r, EncryptedWriter::new(w, encryptor))) + }) + .and_then(|(svr_cfg, r, enc_w)| { + trace!("Got encrypt stream and going to send addr: {:?}", + relay_addr); + + // Send relay address to remote + relay_addr.write_to(Vec::new()) + .and_then(|addr_buf| { + write_all(enc_w, addr_buf) + .and_then(|(enc_w, _)| flush(enc_w)) + .and_then(|enc_w| Ok((svr_cfg, r, enc_w))) + }) + }) + }) + .and_then(|(svr_cfg, r, enc_w)| { + // Decrypt data from remote server + + let iv_len = svr_cfg.method.iv_size(); + read_exact(r, vec![0u8; iv_len]).and_then(move |(r, remote_iv)| { + trace!("Got initialize vector {:?}", remote_iv); + + let decryptor = cipher::with_type(svr_cfg.method, + svr_cfg.password.as_bytes(), + &remote_iv[..], + CryptoMode::Decrypt); + let decrypt_stream = DecryptedReader::new(r, decryptor); + + trace!("Finished creating remote encrypt stream pair"); + + Ok((decrypt_stream, enc_w)) + }) + }) + .boxed() +} + +/// Copy exactly N bytes +pub enum CopyExact + where R: Read, + W: Write +{ + Pending { + reader: R, + writer: W, + buf: [u8; 4096], + remain: usize, + pos: usize, + cap: usize, + }, + Empty, +} + +impl CopyExact + where R: Read, + W: Write +{ + pub fn new(r: R, w: W, amt: usize) -> CopyExact { + CopyExact::Pending { + reader: r, + writer: w, + buf: [0u8; 4096], + remain: amt, + pos: 0, + cap: 0, } - - let remote_writer = match remote_stream.try_clone() { - Ok(s) => s, - Err(err) => { - error!("Error occurs while cloning remote stream: {}", err); - return Err(err); - } - }; - EncryptedWriter::new(remote_writer, encryptor) - }; - - trace!("Got encrypt stream and going to send addr: {:?}", - relay_addr); - - // Send relay address to remote - let mut addr_buf = Vec::new(); - try!(relay_addr.write_to(&mut addr_buf)); - if let Err(err) = encrypt_stream.write_all(&addr_buf).and_then(|_| encrypt_stream.flush()) { - error!("Error occurs while writing address: {}", err); - return Err(err); } +} - // Decrypt data from remote server +impl Future for CopyExact + where R: Read, + W: Write +{ + type Item = (R, W); + type Error = io::Error; - let remote_iv = { - let mut iv = Vec::with_capacity(encrypt_method.block_size()); - unsafe { - iv.set_len(encrypt_method.block_size()); - } + fn poll(&mut self) -> Poll { + match self { + &mut CopyExact::Empty => panic!("poll after CopyExact is finished"), + &mut CopyExact::Pending { ref mut reader, + ref mut writer, + ref mut buf, + ref mut remain, + ref mut pos, + ref mut cap } => { + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if *pos == *cap && *remain != 0 { + let buf_len = if *remain > buf.len() { + buf.len() + } else { + *remain + }; + let n = try_nb!(reader.read(&mut buf[..buf_len])); + if n == 0 { + // Unexpected EOF! + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "Unexpected Eof"); + return Err(err); + } else { + *pos = 0; + *cap = n; + *remain -= n; + } + } - let mut total_len = 0; - while total_len < encrypt_method.block_size() { - match remote_stream.read(&mut iv[total_len..]) { - Ok(0) => { - error!("Unexpected EOF while reading initialize vector"); - debug!("Already read: {:?}", &iv[..total_len]); + // If our buffer has some data, let's write it out! + while *pos < *cap { + let i = try_nb!(writer.write(&buf[*pos..*cap])); + *pos += i; + } - let err = io::Error::new(io::ErrorKind::UnexpectedEof, - "Unexpected EOF while reading initialize vector"); - return Err(err); - } - Ok(n) => total_len += n, - Err(err) => { - error!("Error while reading initialize vector: {}", err); - return Err(err); + // If we've written al the data and we've seen EOF, flush out the + // data and finish the transfer. + // done with the entire transfer. + if *pos == *cap && *remain == 0 { + try_nb!(writer.flush()); + break; // The only path to execute the following logic + } } } } - iv - }; - trace!("Got initialize vector {:?}", remote_iv); - - let decryptor = cipher::with_type(encrypt_method, pwd, &remote_iv[..], CryptoMode::Decrypt); - let decrypt_stream = DecryptedReader::new(remote_stream, decryptor); - - trace!("Finished creating remote encrypt stream pair"); - Ok((decrypt_stream, encrypt_stream)) -} - -#[cfg(debug_assertions)] -mod stat { - use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}; - - static GLOBAL_TCP_WORK_COUNT: AtomicUsize = ATOMIC_USIZE_INIT; - static GLOBAL_HTTP_WORK_COUNT: AtomicUsize = ATOMIC_USIZE_INIT; - - pub fn global_tcp_work_count_add() { - GLOBAL_TCP_WORK_COUNT.fetch_add(1, Ordering::Relaxed); - } - pub fn global_tcp_work_count_sub() { - GLOBAL_TCP_WORK_COUNT.fetch_sub(1, Ordering::Relaxed); - } - - pub fn global_tcp_work_count_get() -> usize { - GLOBAL_TCP_WORK_COUNT.load(Ordering::Relaxed) - } - - pub fn global_http_work_count_add() { - GLOBAL_HTTP_WORK_COUNT.fetch_add(1, Ordering::Relaxed); - } - pub fn global_http_work_count_sub() { - GLOBAL_HTTP_WORK_COUNT.fetch_sub(1, Ordering::Relaxed); - } - - pub fn global_http_work_count_get() -> usize { - GLOBAL_HTTP_WORK_COUNT.load(Ordering::Relaxed) + match mem::replace(self, CopyExact::Empty) { + CopyExact::Pending { reader, writer, .. } => Ok((reader, writer).into()), + CopyExact::Empty => unreachable!(), + } } } -#[cfg(not(debug_assertions))] -mod stat { - pub fn global_tcp_work_count_add() {} - pub fn global_tcp_work_count_sub() {} - - pub fn global_tcp_work_count_get() -> usize { - 0 - } - - pub fn global_http_work_count_add() {} - pub fn global_http_work_count_sub() {} - - pub fn global_http_work_count_get() -> usize { - 0 - } -} - -struct TcpWorkCounter; - -impl TcpWorkCounter { - fn new() -> TcpWorkCounter { - stat::global_tcp_work_count_add(); - TcpWorkCounter - } -} - -impl Drop for TcpWorkCounter { - fn drop(&mut self) { - stat::global_tcp_work_count_sub(); - } -} - -struct HttpWorkCounter; - -impl HttpWorkCounter { - fn new() -> HttpWorkCounter { - stat::global_http_work_count_add(); - HttpWorkCounter - } -} - -impl Drop for HttpWorkCounter { - fn drop(&mut self) { - stat::global_http_work_count_sub(); - } -} - -/// Get total TCP relay work count -pub fn global_tcp_work_count() -> usize { - stat::global_tcp_work_count_get() -} - -/// Get total HTTP relay work count -pub fn global_http_work_count() -> usize { - stat::global_http_work_count_get() +pub fn copy_exact(r: R, w: W, amt: usize) -> CopyExact + where R: Read, + W: Write +{ + CopyExact::new(r, w, amt) } diff --git a/src/relay/tcprelay/server.rs b/src/relay/tcprelay/server.rs index a56839cf..97ea52c2 100644 --- a/src/relay/tcprelay/server.rs +++ b/src/relay/tcprelay/server.rs @@ -21,314 +21,213 @@ //! TcpRelay server that running on the server side +use std::io; +use std::net::{SocketAddr, ToSocketAddrs}; use std::sync::Arc; -use std::io::{self, Read, Write, BufReader}; -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::collections::HashSet; +use config::{Config, ServerConfig}; + +use crypto::CryptoMode; +use crypto::cipher; + +use super::stream::{EncryptedWriter, DecryptedReader}; + +use relay::socks5::Address; + +use futures::{self, Future, BoxFuture}; +use futures::stream::Stream; + +use futures_cpupool::CpuPool; + +use tokio_core::reactor::Handle; +use tokio_core::net::{TcpStream, TcpListener}; +use tokio_core::io::Io; +use tokio_core::io::{ReadHalf, WriteHalf}; +use tokio_core::io::{read_exact, write_all, copy}; + use ip::IpAddr; -use coio::Scheduler; -use coio::net::{TcpListener, TcpStream, Shutdown}; +type ClientRead = ReadHalf; +type ClientWrite = WriteHalf; -use config::{Config, ServerConfig}; -use relay::socks5; -use relay::tcprelay::cached_dns::CachedDns; -use relay::tcprelay::stream::{DecryptedReader, EncryptedWriter}; -use crypto::cipher; -use crypto::CryptoMode; +type EncryptedHalf = EncryptedWriter; +type DecryptedHalf = DecryptedReader; -#[derive(Clone)] +/// TCP Relay backend pub struct TcpRelayServer { - config: Config, + config: Arc, + cpu_pool: CpuPool, } impl TcpRelayServer { - pub fn new(c: Config) -> TcpRelayServer { - if c.server.is_empty() { - panic!("You have to provide a server configuration"); + /// Creates an instance + pub fn new(config: Arc, threads: usize) -> TcpRelayServer { + TcpRelayServer { + config: config, + cpu_pool: CpuPool::new(threads), } - TcpRelayServer { config: c } } - fn accept_loop(s: ServerConfig, forbidden_ip: Arc>) { - let acceptor = TcpListener::bind(&(&s.addr[..], s.port)) - .unwrap_or_else(|err| panic!("Failed to bind a TCP socket: {}", err)); + fn handshake((r, w): (ClientRead, ClientWrite), + svr_cfg: Arc) + -> BoxFuture<(DecryptedHalf, EncryptedHalf), io::Error> { + let iv_len = svr_cfg.method.iv_size(); + read_exact(r, vec![0u8; iv_len]) + .and_then(move |(r, iv)| { + trace!("Got handshake iv: {:?}", iv); + let decryptor = cipher::with_type(svr_cfg.method, + svr_cfg.password.as_bytes(), + &iv[..], + CryptoMode::Decrypt); + let decrypt_stream = DecryptedReader::new(r, decryptor); - info!("Shadowsocks listening on {}:{}", s.addr, s.port); + Ok((svr_cfg, decrypt_stream)) + }) + .and_then(|(svr_cfg, enc_r)| { + let iv = svr_cfg.method.gen_init_vec(); + trace!("Going to send handshake iv: {:?}", iv); + write_all(w, iv).and_then(move |(w, iv)| { + let encryptor = cipher::with_type(svr_cfg.method, + svr_cfg.password.as_bytes(), + &iv[..], + CryptoMode::Encrypt); + let encrypt_stream = EncryptedWriter::new(w, encryptor); - let dnscache_arc = Arc::new(CachedDns::with_capacity(s.dns_cache_capacity)); + Ok((enc_r, encrypt_stream)) + }) + }) + .boxed() + } - let pwd = s.method.bytes_to_key(s.password.as_bytes()); - let timeout = s.timeout; - let method = s.method; + fn resolve_address(addr: Address, cpu_pool: CpuPool) -> BoxFuture { + match addr { + Address::SocketAddress(addr) => futures::finished(addr).boxed(), + Address::DomainNameAddress(dname, port) => { + cpu_pool.spawn(futures::lazy(move || { + let dname = format!("{}:{}", dname, port); + let mut addrs = try!(dname.to_socket_addrs()); + addrs.next().ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed to resolve domain")) + })) + .boxed() + } + } + } - info!("Method {}, Timeout: {:?}", method, timeout); + fn resolve_remote(cpu_pool: CpuPool, + addr: Address, + forbidden_ip: Arc>) + -> BoxFuture { + TcpRelayServer::resolve_address(addr, cpu_pool) + .and_then(move |addr| { + trace!("Resolved address as {}", addr); + let ipaddr = match addr.clone() { + SocketAddr::V4(v4) => IpAddr::V4(v4.ip().clone()), + SocketAddr::V6(v6) => IpAddr::V6(v6.ip().clone()), + }; - for s in acceptor.incoming() { - let mut stream = match s { - Ok((s, addr)) => { - debug!("Got connection from {:?}", addr); - s + if forbidden_ip.contains(&ipaddr) { + info!("{} has been forbidden", ipaddr); + let err = io::Error::new(io::ErrorKind::Other, "Forbidden IP"); + Err(err) + } else { + Ok(addr) } + }) + .boxed() + } + + fn connect_remote(cpu_pool: CpuPool, + handle: Handle, + addr: Address, + forbidden_ip: Arc>) + -> Box> { + trace!("Connecting to remote {}", addr); + Box::new(TcpRelayServer::resolve_remote(cpu_pool, addr, forbidden_ip) + .and_then(move |addr| TcpStream::connect(&addr, &handle))) + } + + pub fn handle_client(handle: &Handle, + cpu_pool: CpuPool, + s: TcpStream, + svr_cfg: Arc, + forbidden_ip: Arc>) + -> io::Result<()> { + let peer_addr = try!(s.peer_addr()); + trace!("Got connection from {}", peer_addr); + + let cloned_handle = handle.clone(); + let fut = futures::lazy(|| Ok(s.split())) + .and_then(move |(r, w)| TcpRelayServer::handshake((r, w), svr_cfg)) + .and_then(|(r, w)| Address::read_from(r).map(|(r, addr)| (r, w, addr)).map_err(From::from)) + .and_then(move |(r, w, addr)| { + info!("Connecting {}", addr); + let cloned_addr = addr.clone(); + TcpRelayServer::connect_remote(cpu_pool, cloned_handle, addr, forbidden_ip) + .map(|svr_s| (svr_s, r, w, cloned_addr)) + }) + .and_then(|(svr_s, r, w, addr)| { + let (svr_r, svr_w) = svr_s.split(); + let c2s = copy(r, svr_w); + let s2c = copy(svr_r, w); + c2s.join(s2c) + .and_then(move |(c2s_amt, s2c_amt)| { + trace!("Relayed {} client -> remote {}bytes", addr, c2s_amt); + trace!("Relayed {} client <- remote {}bytes", addr, s2c_amt); + Ok(()) + }) + }); + + handle.spawn(fut.then(|res| { + match res { + Ok(..) => Ok(()), Err(err) => { - panic!("Error occurs while accepting: {}", err); + error!("Failed to handle client: {}", err); + Err(()) } + } + })); + + Ok(()) + } + + /// Runs the server + pub fn run(self, handle: Handle) -> Box> { + let mut fut: Option>> = None; + + for svr_cfg in &self.config.server { + let listener = { + let addr = &svr_cfg.addr; + let listener = TcpListener::bind(addr, &handle).unwrap(); + trace!("ShadowSocks TCP Listening on {}", addr); + listener }; - if let Err(err) = stream.set_read_timeout(timeout) { - error!("Failed to set read timeout: {:?}", err); - continue; - } + let svr_cfg = svr_cfg.clone(); + let handle = handle.clone(); + let forbidden_ip = self.config.forbidden_ip.clone(); + let cpu_pool = self.cpu_pool.clone(); + let listening = listener.incoming() + .for_each(move |(socket, addr)| { + let server_cfg = svr_cfg.clone(); + let forbidden_ip = forbidden_ip.clone(); + let cpu_pool = cpu_pool.clone(); - if let Err(err) = stream.set_nodelay(true) { - error!("Failed to set no delay: {}", err); - continue; - } - - let pwd = pwd.clone(); - let encrypt_method = method; - let dnscache = dnscache_arc.clone(); - let forbidden_ip = forbidden_ip.clone(); - - Scheduler::spawn(move || { - let remote_iv = { - let mut iv = Vec::with_capacity(encrypt_method.block_size()); - unsafe { - iv.set_len(encrypt_method.block_size()); - } - - let mut total_len = 0; - while total_len < encrypt_method.block_size() { - match stream.read(&mut iv[total_len..]) { - Ok(0) => { - error!("Unexpected EOF while reading initialize vector"); - return; - } - Ok(n) => total_len += n, - Err(err) => { - error!("Error while reading initialize vector: {}", err); - return; - } - } - } - iv - }; - let decryptor = cipher::with_type(encrypt_method, - &pwd[..], - &remote_iv[..], - CryptoMode::Decrypt); - - let mut client_writer = match stream.try_clone() { - Ok(s) => s, - Err(err) => { - error!("Error occurs while cloning client stream: {}", err); - return; - } - }; - let client_reader = stream; - - let iv = encrypt_method.gen_init_vec(); - let encryptor = cipher::with_type(encrypt_method, &pwd[..], &iv[..], CryptoMode::Encrypt); - if let Err(err) = client_writer.write_all(&iv[..]) { - error!("Error occurs while writing initialize vector: {}", err); - return; - } - - let mut decrypt_stream = DecryptedReader::new(client_reader, decryptor); - - let addr = match socks5::Address::read_from(&mut decrypt_stream) { - Ok(addr) => addr, - Err(err) => { - error!("Error occurs while parsing request header, maybe wrong crypto \ - method or password: {}", - err); - return; - } - }; - - info!("Connecting to {}", addr); - - let remote_stream = match &addr { - &socks5::Address::SocketAddress(ref addr) => { - if forbidden_ip.contains(&::relay::take_ip_addr(addr)) { - info!("{} has been blocked by `forbidden_ip`", addr); - return; - } - - match TcpStream::connect(&addr) { - Ok(stream) => stream, - Err(err) => { - error!("Unable to connect {:?}: {}", addr, err); - return; - } - } - } - &socks5::Address::DomainNameAddress(ref dname, ref port) => { - let addrs = match dnscache.resolve(&dname) { - Some(addrs) => addrs, - None => return, - }; - - let processing = || { - let mut last_err: Option> = None; - for addr in addrs.into_iter() { - let addr = match addr { - SocketAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr.ip().clone(), *port)), - SocketAddr::V6(addr) => { - SocketAddr::V6(SocketAddrV6::new(addr.ip().clone(), - *port, - addr.flowinfo(), - addr.scope_id())) - } - }; - - if forbidden_ip.contains(&::relay::take_ip_addr(&addr)) { - info!("{} has been blocked by `forbidden_ip`", addr); - last_err = Some(Err(io::Error::new(io::ErrorKind::Other, - "Blocked by `forbidden_ip`"))); - continue; - } - - match TcpStream::connect(addr) { - Ok(stream) => return Ok(stream), - Err(err) => { - error!("Unable to connect {:?}: {}", addr, err); - last_err = Some(Err(err)); - } - } - } - - last_err.unwrap() - }; - - match processing() { - Ok(s) => s, - Err(_) => return, - } - } - }; - - let mut remote_writer = match remote_stream.try_clone() { - Ok(s) => s, - Err(err) => { - error!("Error occurs while cloning remote stream: {}", err); - return; - } - }; - let addr_cloned = addr.clone(); - - Scheduler::spawn(move || { - let mut remote_reader = BufReader::new(remote_stream); - let mut encrypt_stream = EncryptedWriter::new(client_writer, encryptor); - - match ::relay::copy_once(&mut remote_reader, &mut encrypt_stream) { - Ok(0) => {} - Ok(n) => { - let remote_addr = encrypt_stream.get_ref().peer_addr().unwrap(); - let client_addr = remote_reader.get_ref().peer_addr().unwrap(); - - trace!("{} local <- remote: relayed {} bytes from {} to {}", - addr, - n, - remote_addr, - client_addr); - - loop { - match ::relay::copy_once(&mut remote_reader, &mut encrypt_stream) { - Ok(0) => { - trace!("{} local <- remote: EOF", addr); - break; - } - Ok(n) => { - trace!("{} local <- remote: relayed {} bytes from {} to {}", - addr, - n, - remote_addr, - client_addr) - } - Err(err) => { - error!("{} local <- remote: {}", addr, err); - break; - } - } - } - } - Err(err) => { - error!("{} local <- remote: {}", addr, err); - } - } - - debug!("{} local <- remote is closing", addr); - - let _ = encrypt_stream.get_mut().shutdown(Shutdown::Both); - let _ = remote_reader.get_mut().shutdown(Shutdown::Both); + trace!("Got connection, addr: {}", addr); + trace!("Picked proxy server: {:?}", server_cfg); + TcpRelayServer::handle_client(&handle, cpu_pool, socket, server_cfg, forbidden_ip) + }) + .map_err(|err| { + error!("Server run failed: {}", err); + err }); - Scheduler::spawn(move || { - match ::relay::copy_once(&mut decrypt_stream, &mut remote_writer) { - Ok(0) => {} - Ok(n) => { - let remote_addr = remote_writer.peer_addr().unwrap(); - let client_addr = decrypt_stream.get_ref().peer_addr().unwrap(); - - debug!("{} local -> remote: relayed {} bytes from {} to {}", - addr_cloned, - n, - remote_addr, - client_addr); - - loop { - match ::relay::copy_once(&mut decrypt_stream, &mut remote_writer) { - Ok(0) => { - trace!("{} local -> remote: EOF", addr_cloned); - break; - } - Ok(n) => { - debug!("{} local -> remote: relayed {} bytes from {} to {}", - addr_cloned, - n, - remote_addr, - client_addr); - } - Err(err) => { - error!("{} local -> remote: {}", addr_cloned, err); - break; - } - } - } - } - Err(err) => { - error!("{} local -> remote: {}", addr_cloned, err); - } - } - - debug!("{} local -> remote is closing", addr_cloned); - - let _ = remote_writer.shutdown(Shutdown::Both); - let _ = decrypt_stream.get_mut().shutdown(Shutdown::Both); - }); - }); - } - } -} - -impl TcpRelayServer { - pub fn run(&self) { - let mut futs = Vec::with_capacity(self.config.server.len()); - let forbidden_ip = Arc::new(self.config.forbidden_ip.clone()); - - for s in &self.config.server { - let s = s.clone(); - let forbidden_ip = forbidden_ip.clone(); - let fut = Scheduler::spawn(move || { - TcpRelayServer::accept_loop(s, forbidden_ip); - }); - futs.push(fut); - } - - for fut in futs { - fut.join().unwrap(); + fut = Some(match fut.take() { + Some(fut) => Box::new(fut.join(listening).map(|_| ())) as Box>, + None => Box::new(listening) as Box>, + }) } + + fut.expect("Must have at least one server") } } diff --git a/src/relay/tcprelay/stream.rs b/src/relay/tcprelay/stream.rs index e3af59fe..ad9a5e52 100644 --- a/src/relay/tcprelay/stream.rs +++ b/src/relay/tcprelay/stream.rs @@ -19,12 +19,17 @@ // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +#![allow(dead_code)] + use std::io::{self, Read, BufRead, Write}; use std::cmp; -use crypto::cipher::{Cipher, CipherVariant}; +use crypto::{Cipher, CipherVariant}; -pub struct DecryptedReader { +/// Reader wrapper that will decrypt data automatically +pub struct DecryptedReader + where R: Read + 'static +{ reader: R, buffer: Vec, cipher: CipherVariant, @@ -34,7 +39,9 @@ pub struct DecryptedReader { const BUFFER_SIZE: usize = 2048; -impl DecryptedReader { +impl DecryptedReader + where R: Read + 'static +{ pub fn new(r: R, cipher: CipherVariant) -> DecryptedReader { DecryptedReader { reader: r, @@ -59,18 +66,20 @@ impl DecryptedReader { &mut self.reader } - // /// Unwraps this `DecryptedReader`, returning the underlying reader. - // /// - // /// The internal buffer is flushed before returning the reader. Any leftover - // /// data in the read buffer is lost. - // pub fn into_inner(self) -> R { - // self.reader - // } + /// Unwraps this `DecryptedReader`, returning the underlying reader. + /// + /// The internal buffer is flushed before returning the reader. Any leftover + /// data in the read buffer is lost. + pub fn into_inner(self) -> R { + self.reader + } } -impl BufRead for DecryptedReader { +impl BufRead for DecryptedReader + where R: Read + 'static +{ fn fill_buf<'b>(&'b mut self) -> io::Result<&'b [u8]> { - while self.pos == self.buffer.len() { + while self.pos >= self.buffer.len() { if self.sent_final { return Ok(&[]); } @@ -81,21 +90,17 @@ impl BufRead for DecryptedReader { Ok(0) => { // EOF try!(self.cipher - .finalize(&mut self.buffer) - .map_err(|err| io::Error::new(io::ErrorKind::Other, - err.desc))); + .finalize(&mut self.buffer)); self.sent_final = true; - }, + } Ok(l) => { try!(self.cipher - .update(&incoming[..l], &mut self.buffer) - .map_err(|err| io::Error::new(io::ErrorKind::Other, - err.desc))); - }, + .update(&incoming[..l], &mut self.buffer)); + } Err(err) => { return Err(err); } - }; + } self.pos = 0; } @@ -108,7 +113,9 @@ impl BufRead for DecryptedReader { } } -impl Read for DecryptedReader { +impl Read for DecryptedReader + where R: Read + 'static +{ fn read(&mut self, buf: &mut [u8]) -> io::Result { let nread = { let mut available = try!(self.fill_buf()); @@ -119,13 +126,19 @@ impl Read for DecryptedReader { } } -pub struct EncryptedWriter { +/// Writer wrapper that will encrypt data automatically +pub struct EncryptedWriter + where W: Write + 'static +{ writer: W, cipher: CipherVariant, buffer: Vec, } -impl EncryptedWriter { +impl EncryptedWriter + where W: Write + 'static +{ + /// Creates a new EncryptedWriter pub fn new(w: W, cipher: CipherVariant) -> EncryptedWriter { EncryptedWriter { writer: w, @@ -134,21 +147,20 @@ impl EncryptedWriter { } } + /// Finalize the cipher, which will writes the final block into buffer pub fn finalize(&mut self) -> io::Result<()> { self.buffer.clear(); match self.cipher.finalize(&mut self.buffer) { Ok(..) => { - self.writer.write_all(&self.buffer[..]) + self.writer + .write_all(&self.buffer[..]) .and_then(|_| self.writer.flush()) - }, - Err(err) => { - Err(io::Error::new( - io::ErrorKind::Other, - err.desc)) } + Err(err) => Err(From::from(err)), } } + /// Get reference to the inner writer pub fn get_ref(&self) -> &W { &self.writer } @@ -164,23 +176,19 @@ impl EncryptedWriter { } } -impl Write for EncryptedWriter { +impl Write for EncryptedWriter + where W: Write + 'static +{ fn write(&mut self, buf: &[u8]) -> io::Result { self.buffer.clear(); match self.cipher.update(buf, &mut self.buffer) { Ok(..) => { match self.writer.write_all(&self.buffer[..]) { - Ok(..) => { - Ok(buf.len()) - }, + Ok(..) => Ok(buf.len()), Err(err) => Err(err), } - }, - Err(err) => { - Err(io::Error::new( - io::ErrorKind::Other, - err.desc)) } + Err(err) => Err(From::from(err)), } } @@ -189,7 +197,9 @@ impl Write for EncryptedWriter { } } -impl Drop for EncryptedWriter { +impl Drop for EncryptedWriter + where W: Write + 'static +{ fn drop(&mut self) { let _ = self.finalize(); }