share tcp stream handle

This commit is contained in:
Y. T. Chung
2016-10-22 00:51:00 +08:00
parent 05d3dc2f87
commit 2ed53d2bd9
4 changed files with 195 additions and 179 deletions

View File

@@ -235,7 +235,7 @@ fn main() {
config.server.push(sc);
has_provided_server_config = true;
} else if matches.value_of("SERVER_ADDR").is_none() && matches.value_of("SERVER_PORT").is_none() &&
matches.value_of("PASSWORD").is_none() && matches.value_of("ENCRYPT_METHOD").is_none() {
matches.value_of("PASSWORD").is_none() && matches.value_of("ENCRYPT_METHOD").is_none() {
// Does not provide server config
} else {
panic!("`server-addr`, `server-port`, `method` and `password` should be provided together");

View File

@@ -23,7 +23,7 @@
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::net::lookup_host;
use std::io::{self, BufWriter, BufReader, Read, Write};
use std::io::{self, Read, Write};
use std::collections::BTreeMap;
use std::sync::Arc;
@@ -41,6 +41,7 @@ use relay::socks5::{self, Address};
use relay::loadbalancing::server::{LoadBalancer, RoundRobin};
use super::http::HttpRequest;
use super::SharedTcpStream;
use crypto::cipher::CipherType;
@@ -94,30 +95,15 @@ impl TcpRelayLocal {
Ok(())
}
fn handle_tcp_client(stream: TcpStream,
fn handle_tcp_client((stream, sockname): (TcpStream, SocketAddr),
server_addr: SocketAddr,
password: Vec<u8>,
encrypt_method: CipherType,
conf: Arc<Config>) {
let sockname = match stream.peer_addr() {
Ok(sockname) => sockname,
Err(err) => {
error!("Failed to get peer addr: {}", err);
return;
}
};
let mut local_reader = SharedTcpStream::new(stream);
let mut local_writer = local_reader.clone();
let stream_writer = match stream.try_clone() {
Ok(s) => s,
Err(err) => {
error!("Failed to clone local stream: {}", err);
return;
}
};
let mut local_reader = BufReader::new(stream);
let mut local_writer = BufWriter::new(stream_writer);
if let Err(err) = TcpRelayLocal::do_handshake(&mut local_reader, &mut local_writer) {
if let Err(err) = TcpRelayLocal::do_handshake(&mut local_reader, &mut (&*local_writer)) {
error!("Error occurs while doing handshake: {}", err);
return;
}
@@ -192,18 +178,10 @@ impl TcpRelayLocal {
debug!("SYSTEM Connect {} local -> remote is closing", addr_cloned);
let _ = encrypt_stream.get_ref().shutdown(Shutdown::Both);
let _ = local_reader.get_ref().shutdown(Shutdown::Both);
let _ = local_reader.shutdown(Shutdown::Both);
});
Scheduler::spawn(move || {
let mut local_writer = match local_writer.into_inner() {
Ok(writer) => writer,
Err(err) => {
error!("Error occurs while taking out local writer: {}", err);
return;
}
};
loop {
match ::relay::copy_once(&mut decrypt_stream, &mut local_writer) {
Ok(0) => {
@@ -249,8 +227,7 @@ impl TcpRelayLocal {
}
}
fn handle_http_connect(stream: TcpStream,
stream_writer: TcpStream,
fn handle_http_connect((stream, sockname): (SharedTcpStream, SocketAddr),
addr: Address,
server_addr: SocketAddr,
password: Vec<u8>,
@@ -258,9 +235,10 @@ impl TcpRelayLocal {
remain: &[u8])
-> io::Result<()> {
info!("CONNECT (HTTP) {}", addr);
trace!("HTTP Connect: connection from {}", sockname);
let mut local_reader = BufReader::new(stream);
let mut local_writer = stream_writer;
let mut local_reader = stream;
let mut local_writer = local_reader.clone();
const HANDSHAKE: &'static [u8] = b"HTTP/1.1 200 Connection Established\r\n\r\n";
@@ -313,7 +291,7 @@ impl TcpRelayLocal {
addr_cloned);
let _ = encrypt_stream.get_ref().shutdown(Shutdown::Both);
let _ = local_reader.get_ref().shutdown(Shutdown::Both);
let _ = local_reader.shutdown(Shutdown::Both);
});
Scheduler::spawn(move || {
@@ -347,8 +325,7 @@ impl TcpRelayLocal {
}
fn handle_http_others(mut req: HttpRequest,
stream: TcpStream,
stream_writer: TcpStream,
(stream, sockname): (SharedTcpStream, SocketAddr),
addr: Address,
server_addr: SocketAddr,
password: Vec<u8>,
@@ -356,9 +333,10 @@ impl TcpRelayLocal {
remain: &[u8])
-> io::Result<()> {
info!("{} (HTTP) {}", req.method, addr);
trace!("HTTP {}: Got connection from {}", req.method, sockname);
let mut local_reader = BufReader::new(stream);
let mut local_writer = stream_writer;
let mut local_reader = stream;
let mut local_writer = local_reader.clone();
let (mut decrypt_stream, mut encrypt_stream) =
match super::connect_proxy_server(&server_addr, encrypt_method, &password[..], &addr) {
@@ -467,7 +445,7 @@ impl TcpRelayLocal {
debug!("SYSTEM Connect {} local -> remote is closing", addr_cloned);
let _ = encrypt_stream.get_ref().shutdown(Shutdown::Both);
let _ = local_reader.get_ref().shutdown(Shutdown::Both);
let _ = local_reader.shutdown(Shutdown::Both);
});
Scheduler::spawn(move || {
@@ -498,19 +476,14 @@ impl TcpRelayLocal {
Ok(())
}
fn handle_http_client(mut stream: TcpStream,
fn handle_http_client((stream, sockname): (TcpStream, SocketAddr),
server_addr: SocketAddr,
password: Vec<u8>,
encrypt_method: CipherType) {
use super::http::{get_address, write_response};
let mut stream_writer = match stream.try_clone() {
Ok(s) => s,
Err(err) => {
error!("Failed to clone stream: {:?}", err);
return;
}
};
let mut stream = SharedTcpStream::new(stream);
let mut stream_writer = stream.clone();
let mut req_buf = Vec::with_capacity(8192);
let mut got_header = false;
@@ -551,8 +524,7 @@ impl TcpRelayLocal {
match request.method.clone() {
Method::Connect => {
let _ = TcpRelayLocal::handle_http_connect(stream,
stream_writer,
let _ = TcpRelayLocal::handle_http_connect((stream, sockname),
addr,
server_addr,
password,
@@ -561,8 +533,7 @@ impl TcpRelayLocal {
}
_ => {
let _ = TcpRelayLocal::handle_http_others(request,
stream,
stream_writer,
(stream, sockname),
addr,
server_addr,
password,
@@ -591,7 +562,11 @@ impl TcpRelayLocal {
impl TcpRelayLocal {
fn run_server<F>(&self, local_conf: SocketAddr, handler: F)
where F: Fn(TcpStream, SocketAddr, Vec<u8>, CipherType, Arc<Config>)
where F: Fn((TcpStream, SocketAddr),
SocketAddr,
Vec<u8>,
CipherType,
Arc<Config>)
{
let mut server_load_balancer = RoundRobin::new(self.config.server.clone());
@@ -608,10 +583,10 @@ impl TcpRelayLocal {
let mut cached_proxy: BTreeMap<String, SocketAddr> = BTreeMap::new();
for s in acceptor.incoming() {
let stream = match s {
let (stream, sockname) = match s {
Ok((s, addr)) => {
debug!("Got connection from client {:?}", addr);
s
(s, addr)
}
Err(err) => {
panic!("Error occurs while accepting: {:?}", err);
@@ -681,7 +656,7 @@ impl TcpRelayLocal {
let pwd = encrypt_method.bytes_to_key(server_cfg.password.as_bytes());
let conf = self.config.clone();
handler(stream, server_addr, pwd, encrypt_method, conf);
handler((stream, sockname), server_addr, pwd, encrypt_method, conf);
succeed = true;
break;

View File

@@ -23,6 +23,8 @@
use std::net::SocketAddr;
use std::io::{self, Read, Write};
use std::sync::Arc;
use std::ops::Deref;
use crypto::cipher::{self, CipherType};
use crypto::CryptoMode;
@@ -38,12 +40,44 @@ pub mod server;
mod stream;
mod http;
#[derive(Clone)]
pub struct SharedTcpStream(Arc<TcpStream>);
impl SharedTcpStream {
pub fn new(s: TcpStream) -> SharedTcpStream {
SharedTcpStream(Arc::new(s))
}
}
impl Read for SharedTcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
(&*self.0).read(buf)
}
}
impl Write for SharedTcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
(&*self.0).write(buf)
}
fn flush(&mut self) -> io::Result<()> {
(&*self.0).flush()
}
}
impl Deref for SharedTcpStream {
type Target = TcpStream;
fn deref(&self) -> &TcpStream {
&*self.0
}
}
fn connect_proxy_server(server_addr: &SocketAddr,
encrypt_method: CipherType,
pwd: &[u8],
relay_addr: &Address)
-> io::Result<(DecryptedReader<TcpStream>, EncryptedWriter<TcpStream>)> {
let mut remote_stream = try!(TcpStream::connect(&server_addr));
-> io::Result<(DecryptedReader<SharedTcpStream>, EncryptedWriter<SharedTcpStream>)> {
let mut remote_stream = SharedTcpStream::new(try!(TcpStream::connect(&server_addr)));
// Encrypt data to remote server
@@ -56,15 +90,7 @@ fn connect_proxy_server(server_addr: &SocketAddr,
error!("Error occurs while writing initialize vector: {}", err);
return Err(err);
}
let remote_writer = match remote_stream.try_clone() {
Ok(s) => s,
Err(err) => {
error!("Error occurs while cloning remote stream: {}", err);
return Err(err);
}
};
EncryptedWriter::new(remote_writer, encryptor)
EncryptedWriter::new(remote_stream.clone(), encryptor)
};
trace!("Got encrypt stream and going to send addr: {:?}",

View File

@@ -35,9 +35,11 @@ use config::{Config, ServerConfig};
use relay::socks5;
use relay::tcprelay::cached_dns::CachedDns;
use relay::tcprelay::stream::{DecryptedReader, EncryptedWriter};
use crypto::cipher;
use crypto::cipher::{self, CipherType};
use crypto::CryptoMode;
use super::SharedTcpStream;
#[derive(Clone)]
pub struct TcpRelayServer {
config: Config,
@@ -51,6 +53,117 @@ impl TcpRelayServer {
TcpRelayServer { config: c }
}
fn handshake(mut stream: SharedTcpStream,
encrypt_method: CipherType,
pwd: &[u8])
-> io::Result<(DecryptedReader<SharedTcpStream>, EncryptedWriter<SharedTcpStream>)> {
// Decrypt
let remote_iv = {
let mut iv = Vec::with_capacity(encrypt_method.block_size());
unsafe {
iv.set_len(encrypt_method.block_size());
}
let mut total_len = 0;
while total_len < encrypt_method.block_size() {
match stream.read(&mut iv[total_len..]) {
Ok(0) => {
error!("Unexpected EOF while reading initialize vector");
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "Unexpected Eof");
return Err(err);
}
Ok(n) => total_len += n,
Err(err) => {
error!("Error while reading initialize vector: {}", err);
return Err(err);
}
}
}
iv
};
let decryptor = cipher::with_type(encrypt_method, pwd, &remote_iv[..], CryptoMode::Decrypt);
let decrypt_stream = DecryptedReader::new(stream.clone(), decryptor);
// Encrypt
let iv = encrypt_method.gen_init_vec();
let encryptor = cipher::with_type(encrypt_method, pwd, &iv[..], CryptoMode::Encrypt);
if let Err(err) = stream.write_all(&iv[..]) {
error!("Error occurs while writing initialize vector: {}", err);
return Err(err);
}
let encrypt_stream = EncryptedWriter::new(stream, encryptor);
Ok((decrypt_stream, encrypt_stream))
}
fn connect_remote(addr: &socks5::Address,
dnscache: &Arc<CachedDns>,
forbidden_ip: &Arc<HashSet<IpAddr>>)
-> io::Result<TcpStream> {
info!("Connecting to {}", addr);
match addr {
&socks5::Address::SocketAddress(ref addr) => {
if forbidden_ip.contains(&::relay::take_ip_addr(addr)) {
info!("{} has been blocked by `forbidden_ip`", addr);
let err = io::Error::new(io::ErrorKind::Other, "IP blocked");
return Err(err);
}
match TcpStream::connect(&addr) {
Ok(stream) => Ok(stream),
Err(err) => {
error!("Unable to connect {:?}: {}", addr, err);
Err(err)
}
}
}
&socks5::Address::DomainNameAddress(ref dname, ref port) => {
let addrs = match dnscache.resolve(&dname) {
Some(addrs) => addrs,
None => {
error!("Failed to resolve {}", dname);
let err = io::Error::new(io::ErrorKind::Other, "DNS resolve error");
return Err(err);
}
};
let processing = || {
let mut last_err: Option<io::Result<TcpStream>> = None;
for addr in addrs.into_iter() {
let addr = match addr {
SocketAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr.ip().clone(), *port)),
SocketAddr::V6(addr) => {
SocketAddr::V6(SocketAddrV6::new(addr.ip().clone(),
*port,
addr.flowinfo(),
addr.scope_id()))
}
};
if forbidden_ip.contains(&::relay::take_ip_addr(&addr)) {
info!("{} has been blocked by `forbidden_ip`", addr);
last_err = Some(Err(io::Error::new(io::ErrorKind::Other, "Blocked by `forbidden_ip`")));
continue;
}
match TcpStream::connect(addr) {
Ok(stream) => return Ok(stream),
Err(err) => {
error!("Unable to connect {:?}: {}", addr, err);
last_err = Some(Err(err));
}
}
}
last_err.unwrap()
};
processing()
}
}
}
fn accept_loop(s: ServerConfig, forbidden_ip: Arc<HashSet<IpAddr>>) {
let acceptor = TcpListener::bind(&(&s.addr[..], s.port))
.unwrap_or_else(|err| panic!("Failed to bind a TCP socket: {}", err));
@@ -66,7 +179,7 @@ impl TcpRelayServer {
info!("Method {}, Timeout: {:?}", method, timeout);
for s in acceptor.incoming() {
let mut stream = match s {
let stream = match s {
Ok((s, addr)) => {
debug!("Got connection from {:?}", addr);
s
@@ -92,50 +205,14 @@ impl TcpRelayServer {
let forbidden_ip = forbidden_ip.clone();
Scheduler::spawn(move || {
let remote_iv = {
let mut iv = Vec::with_capacity(encrypt_method.block_size());
unsafe {
iv.set_len(encrypt_method.block_size());
}
let mut total_len = 0;
while total_len < encrypt_method.block_size() {
match stream.read(&mut iv[total_len..]) {
Ok(0) => {
error!("Unexpected EOF while reading initialize vector");
return;
}
Ok(n) => total_len += n,
Err(err) => {
error!("Error while reading initialize vector: {}", err);
return;
}
let (mut decrypt_stream, mut encrypt_stream) =
match TcpRelayServer::handshake(SharedTcpStream::new(stream), encrypt_method, &pwd[..]) {
Ok(x) => x,
Err(err) => {
error!("Failed to do handshake, {}", err);
return;
}
}
iv
};
let decryptor = cipher::with_type(encrypt_method,
&pwd[..],
&remote_iv[..],
CryptoMode::Decrypt);
let mut client_writer = match stream.try_clone() {
Ok(s) => s,
Err(err) => {
error!("Error occurs while cloning client stream: {}", err);
return;
}
};
let client_reader = stream;
let iv = encrypt_method.gen_init_vec();
let encryptor = cipher::with_type(encrypt_method, &pwd[..], &iv[..], CryptoMode::Encrypt);
if let Err(err) = client_writer.write_all(&iv[..]) {
error!("Error occurs while writing initialize vector: {}", err);
return;
}
let mut decrypt_stream = DecryptedReader::new(client_reader, decryptor);
};
let addr = match socks5::Address::read_from(&mut decrypt_stream) {
Ok(addr) => addr,
@@ -147,81 +224,19 @@ impl TcpRelayServer {
}
};
info!("Connecting to {}", addr);
let remote_stream = match &addr {
&socks5::Address::SocketAddress(ref addr) => {
if forbidden_ip.contains(&::relay::take_ip_addr(addr)) {
info!("{} has been blocked by `forbidden_ip`", addr);
return;
}
match TcpStream::connect(&addr) {
Ok(stream) => stream,
Err(err) => {
error!("Unable to connect {:?}: {}", addr, err);
return;
}
}
}
&socks5::Address::DomainNameAddress(ref dname, ref port) => {
let addrs = match dnscache.resolve(&dname) {
Some(addrs) => addrs,
None => return,
};
let processing = || {
let mut last_err: Option<io::Result<TcpStream>> = None;
for addr in addrs.into_iter() {
let addr = match addr {
SocketAddr::V4(addr) => SocketAddr::V4(SocketAddrV4::new(addr.ip().clone(), *port)),
SocketAddr::V6(addr) => {
SocketAddr::V6(SocketAddrV6::new(addr.ip().clone(),
*port,
addr.flowinfo(),
addr.scope_id()))
}
};
if forbidden_ip.contains(&::relay::take_ip_addr(&addr)) {
info!("{} has been blocked by `forbidden_ip`", addr);
last_err = Some(Err(io::Error::new(io::ErrorKind::Other,
"Blocked by `forbidden_ip`")));
continue;
}
match TcpStream::connect(addr) {
Ok(stream) => return Ok(stream),
Err(err) => {
error!("Unable to connect {:?}: {}", addr, err);
last_err = Some(Err(err));
}
}
}
last_err.unwrap()
};
match processing() {
Ok(s) => s,
Err(_) => return,
}
}
};
let mut remote_writer = match remote_stream.try_clone() {
Ok(s) => s,
let remote_stream = match TcpRelayServer::connect_remote(&addr, &dnscache, &forbidden_ip) {
Ok(s) => SharedTcpStream::new(s),
Err(err) => {
error!("Error occurs while cloning remote stream: {}", err);
error!("Failed to connect to {}: {}", addr, err);
return;
}
};
let mut remote_writer = remote_stream.clone();
let addr_cloned = addr.clone();
Scheduler::spawn(move || {
let mut remote_reader = BufReader::new(remote_stream);
let mut encrypt_stream = EncryptedWriter::new(client_writer, encryptor);
match ::relay::copy_once(&mut remote_reader, &mut encrypt_stream) {
Ok(0) => {}
Ok(n) => {