mirror of
https://github.com/shadowsocks/shadowsocks-rust.git
synced 2026-02-09 01:59:16 +08:00
Rewrite EncryptWriter as a BufWriter
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -1,6 +1,6 @@
|
||||
[root]
|
||||
name = "shadowsocks-rust"
|
||||
version = "1.1.0"
|
||||
version = "1.2.0"
|
||||
dependencies = [
|
||||
"base64 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"byteorder 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
|
||||
@@ -54,7 +54,7 @@ pub mod server;
|
||||
mod stream;
|
||||
pub mod client;
|
||||
|
||||
const BUFFER_SIZE: usize = 4096;
|
||||
const BUFFER_SIZE: usize = 32 * 1024; // 32K buffer
|
||||
|
||||
/// Directions in the tunnel
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
@@ -111,7 +111,7 @@ pub fn proxy_server_handshake(remote_stream: TcpStream,
|
||||
// Send relay address to remote
|
||||
let local_buf = Vec::new();
|
||||
relay_addr.write_to(local_buf)
|
||||
.and_then(move |buf| try_timeout(enc_w.write_all_encrypted(buf), timeout, &handle))
|
||||
.and_then(move |buf| try_timeout(write_all(enc_w, buf), timeout, &handle))
|
||||
.map(|(w, _)| w)
|
||||
});
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@
|
||||
|
||||
use std::io::{self, Read, BufRead, Write};
|
||||
use std::cmp;
|
||||
use std::mem;
|
||||
use std::time::Duration;
|
||||
|
||||
use crypto::{Cipher, CipherVariant};
|
||||
@@ -31,6 +30,7 @@ use crypto::{Cipher, CipherVariant};
|
||||
use futures::{Future, Poll, Async};
|
||||
|
||||
use tokio_core::reactor::{Handle, Timeout};
|
||||
use tokio_core::io::copy;
|
||||
|
||||
use super::{BUFFER_SIZE, BoxIoFuture, boxed_future};
|
||||
|
||||
@@ -138,6 +138,8 @@ pub struct EncryptedWriter<W>
|
||||
{
|
||||
writer: W,
|
||||
cipher: CipherVariant,
|
||||
buf: Vec<u8>,
|
||||
finalized: bool,
|
||||
}
|
||||
|
||||
impl<W> EncryptedWriter<W>
|
||||
@@ -148,57 +150,57 @@ impl<W> EncryptedWriter<W>
|
||||
EncryptedWriter {
|
||||
writer: w,
|
||||
cipher: cipher,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get reference to the inner writer
|
||||
pub fn get_ref(&self) -> &W {
|
||||
&self.writer
|
||||
}
|
||||
|
||||
/// Gets a mutable reference to the underlying writer.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// It is inadvisable to read directly from or write directly to the
|
||||
/// underlying writer.
|
||||
pub fn get_mut(&mut self) -> &mut W {
|
||||
&mut self.writer
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn cipher_update(&mut self, buf: &[u8], out: &mut Vec<u8>) -> io::Result<()> {
|
||||
self.cipher.update(buf, out).map_err(From::from)
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn cipher_finalize(&mut self, out: &mut Vec<u8>) -> io::Result<()> {
|
||||
self.cipher.finalize(out).map_err(From::from)
|
||||
}
|
||||
|
||||
/// write_all
|
||||
pub fn write_all_encrypted<B: AsRef<[u8]>>(self, buf: B) -> EncryptedWriteAll<W, B> {
|
||||
EncryptedWriteAll::Writing {
|
||||
writer: self,
|
||||
buf: buf,
|
||||
pos: 0,
|
||||
enc_buf: Vec::new(),
|
||||
encrypted: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Copy all data from reader
|
||||
pub fn copy_from_encrypted<R: Read>(self, r: R) -> EncryptedCopy<R, W> {
|
||||
EncryptedCopy {
|
||||
reader: r,
|
||||
writer: self,
|
||||
read_done: false,
|
||||
amt: 0,
|
||||
pos: 0,
|
||||
cap: 0,
|
||||
buf: Vec::new(),
|
||||
finalized: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn cipher_finalize(&mut self) -> io::Result<()> {
|
||||
if self.finalized {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.cipher
|
||||
.finalize(&mut self.buf)
|
||||
.and_then(|_| {
|
||||
self.finalized = true;
|
||||
Ok(())
|
||||
})
|
||||
.map_err(From::from)
|
||||
}
|
||||
|
||||
fn flush_buf(&mut self) -> io::Result<usize> {
|
||||
let mut written = 0;
|
||||
let expected_len = self.buf.len();
|
||||
let mut ret = Ok(());
|
||||
while written < expected_len {
|
||||
match self.writer.write(&self.buf[written..]) {
|
||||
Ok(0) => {
|
||||
ret = Err(io::Error::new(io::ErrorKind::WriteZero,
|
||||
"failed to write the buffered data"));
|
||||
break;
|
||||
}
|
||||
Ok(n) => written += n,
|
||||
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
|
||||
Err(e) => {
|
||||
ret = Err(e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if written > 0 {
|
||||
self.buf.drain(..written);
|
||||
}
|
||||
|
||||
ret.map(|_| written)
|
||||
}
|
||||
|
||||
fn fill_buf(&mut self, data: &[u8]) -> io::Result<()> {
|
||||
assert!(!self.finalized, "Called fill_buf after finalized!");
|
||||
|
||||
self.cipher.update(data, &mut self.buf).map_err(From::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> EncryptedWriter<W>
|
||||
@@ -210,7 +212,7 @@ impl<W> EncryptedWriter<W>
|
||||
{
|
||||
match timeout {
|
||||
Some(timeout) => boxed_future(EncryptedCopyTimeout::new(r, self, timeout, handle)),
|
||||
None => boxed_future(self.copy_from_encrypted(r)),
|
||||
None => boxed_future(copy(r, self)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -219,110 +221,32 @@ impl<W> Write for EncryptedWriter<W>
|
||||
where W: Write
|
||||
{
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.writer.write(buf)
|
||||
if !self.buf.is_empty() {
|
||||
self.flush_buf()?;
|
||||
}
|
||||
|
||||
self.fill_buf(buf)?;
|
||||
match self.flush_buf() {
|
||||
Ok(..) => {}
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {}
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
|
||||
Ok(buf.len())
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.writer.flush()
|
||||
self.flush_buf().and_then(|_| self.writer.flush())
|
||||
}
|
||||
}
|
||||
|
||||
/// write_all and encrypt data
|
||||
pub enum EncryptedWriteAll<W, B>
|
||||
where W: Write,
|
||||
B: AsRef<[u8]>
|
||||
impl<W> Drop for EncryptedWriter<W>
|
||||
where W: Write
|
||||
{
|
||||
Writing {
|
||||
writer: EncryptedWriter<W>,
|
||||
buf: B,
|
||||
pos: usize,
|
||||
enc_buf: Vec<u8>,
|
||||
encrypted: bool,
|
||||
},
|
||||
Empty,
|
||||
}
|
||||
|
||||
impl<W, B> Future for EncryptedWriteAll<W, B>
|
||||
where W: Write,
|
||||
B: AsRef<[u8]>
|
||||
{
|
||||
type Item = (EncryptedWriter<W>, B);
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
match *self {
|
||||
EncryptedWriteAll::Empty => panic!("poll after EncryptedWriteAll finished"),
|
||||
EncryptedWriteAll::Writing { ref mut writer, ref buf, ref mut pos, ref mut enc_buf, ref mut encrypted } => {
|
||||
if !*encrypted {
|
||||
*encrypted = true;
|
||||
try!(writer.cipher_update(buf.as_ref(), enc_buf));
|
||||
}
|
||||
|
||||
while *pos < enc_buf.len() {
|
||||
let n = try_nb!(writer.write(&enc_buf[*pos..]));
|
||||
*pos += n;
|
||||
if n == 0 {
|
||||
let err = io::Error::new(io::ErrorKind::Other, "zero-length write");
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match mem::replace(self, EncryptedWriteAll::Empty) {
|
||||
EncryptedWriteAll::Writing { writer, buf, .. } => Ok((writer, buf).into()),
|
||||
EncryptedWriteAll::Empty => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypted copy
|
||||
pub struct EncryptedCopy<R: Read, W: Write> {
|
||||
reader: R,
|
||||
writer: EncryptedWriter<W>,
|
||||
read_done: bool,
|
||||
amt: u64,
|
||||
pos: usize,
|
||||
cap: usize,
|
||||
buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<R: Read, W: Write> Future for EncryptedCopy<R, W> {
|
||||
type Item = u64;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
let mut local_buf = [0u8; BUFFER_SIZE];
|
||||
loop {
|
||||
// If our buffer is empty, then we need to read some data to
|
||||
// continue.
|
||||
if self.pos == self.cap && !self.read_done {
|
||||
let n = try_nb!(self.reader.read(&mut local_buf[..]));
|
||||
self.buf.clear();
|
||||
if n == 0 {
|
||||
self.read_done = true;
|
||||
try!(self.writer.cipher_finalize(&mut self.buf));
|
||||
} else {
|
||||
try!(self.writer.cipher_update(&local_buf[..n], &mut self.buf));
|
||||
}
|
||||
self.pos = 0;
|
||||
self.cap = self.buf.len();
|
||||
}
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let i = try_nb!(self.writer.write(&self.buf[self.pos..self.cap]));
|
||||
self.pos += i;
|
||||
self.amt += i as u64;
|
||||
}
|
||||
|
||||
// If we've written al the data and we've seen EOF, flush out the
|
||||
// data and finish the transfer.
|
||||
// done with the entire transfer.
|
||||
if self.pos == self.cap && self.read_done {
|
||||
try_nb!(self.writer.flush());
|
||||
return Ok(self.amt.into());
|
||||
}
|
||||
fn drop(&mut self) {
|
||||
if let Ok(..) = self.cipher_finalize() {
|
||||
// I don't care if it is failed to write
|
||||
let _ = self.flush_buf();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -335,7 +259,6 @@ pub struct EncryptedCopyTimeout<R: Read, W: Write> {
|
||||
amt: u64,
|
||||
pos: usize,
|
||||
cap: usize,
|
||||
write_buf: Vec<u8>,
|
||||
timeout: Duration,
|
||||
handle: Handle,
|
||||
timer: Option<Timeout>,
|
||||
@@ -351,7 +274,6 @@ impl<R: Read, W: Write> EncryptedCopyTimeout<R, W> {
|
||||
amt: 0,
|
||||
pos: 0,
|
||||
cap: 0,
|
||||
write_buf: Vec::new(),
|
||||
timeout: dur,
|
||||
handle: handle,
|
||||
timer: None,
|
||||
@@ -384,7 +306,11 @@ impl<R: Read, W: Write> EncryptedCopyTimeout<R, W> {
|
||||
self.clear_timer();
|
||||
|
||||
match self.reader.read(&mut self.read_buf) {
|
||||
Ok(n) => Ok(n),
|
||||
Ok(n) => {
|
||||
self.cap = n;
|
||||
self.pos = 0;
|
||||
Ok(n)
|
||||
}
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
self.timer = Some(Timeout::new(self.timeout, &self.handle).unwrap());
|
||||
@@ -394,15 +320,18 @@ impl<R: Read, W: Write> EncryptedCopyTimeout<R, W> {
|
||||
}
|
||||
}
|
||||
|
||||
fn write_or_set_timeout(&mut self, beg: usize, end: usize) -> io::Result<usize> {
|
||||
fn write_or_set_timeout(&mut self) -> io::Result<usize> {
|
||||
// First, return if timeout
|
||||
try!(self.try_poll_timeout());
|
||||
|
||||
// Then, unset the previous timeout
|
||||
self.clear_timer();
|
||||
|
||||
match self.writer.write(&self.write_buf[beg..end]) {
|
||||
Ok(n) => Ok(n),
|
||||
match self.writer.write(&self.read_buf[self.pos..self.cap]) {
|
||||
Ok(n) => {
|
||||
self.pos += n;
|
||||
Ok(n)
|
||||
}
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
self.timer = Some(Timeout::new(self.timeout, &self.handle).unwrap());
|
||||
@@ -424,22 +353,18 @@ impl<R: Read, W: Write> Future for EncryptedCopyTimeout<R, W> {
|
||||
// continue.
|
||||
if self.pos == self.cap && !self.read_done {
|
||||
let n = try_nb!(self.read_or_set_timeout());
|
||||
self.write_buf.clear();
|
||||
if n == 0 {
|
||||
self.read_done = true;
|
||||
try!(self.writer.cipher_finalize(&mut self.write_buf));
|
||||
} else {
|
||||
try!(self.writer.cipher_update(&self.read_buf[..n], &mut self.write_buf));
|
||||
}
|
||||
self.pos = 0;
|
||||
self.cap = self.write_buf.len();
|
||||
}
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let (pos, cap) = (self.pos, self.cap);
|
||||
let i = try_nb!(self.write_or_set_timeout(pos, cap));
|
||||
self.pos += i;
|
||||
let i = try_nb!(self.write_or_set_timeout());
|
||||
if i == 0 {
|
||||
let err = io::Error::new(io::ErrorKind::UnexpectedEof, "early eof");
|
||||
return Err(err);
|
||||
}
|
||||
self.amt += i as u64;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user