From 2ad57fa5ee7e30c67c5250eeb7c779566ad65133 Mon Sep 17 00:00:00 2001 From: "Y. T. Chung" Date: Sat, 11 Feb 2017 23:50:19 +0800 Subject: [PATCH] revert EncryptedWriter as Write --- src/relay/tcprelay/mod.rs | 2 +- src/relay/tcprelay/stream.rs | 215 +++++++++++++++++++++++------------ 2 files changed, 143 insertions(+), 74 deletions(-) diff --git a/src/relay/tcprelay/mod.rs b/src/relay/tcprelay/mod.rs index e10e1bc2..6ecd2d0d 100644 --- a/src/relay/tcprelay/mod.rs +++ b/src/relay/tcprelay/mod.rs @@ -111,7 +111,7 @@ pub fn proxy_server_handshake(remote_stream: TcpStream, // Send relay address to remote let local_buf = Vec::new(); relay_addr.write_to(local_buf) - .and_then(move |buf| try_timeout(write_all(enc_w, buf), timeout, &handle)) + .and_then(move |buf| try_timeout(enc_w.write_all_encrypted(buf), timeout, &handle)) .map(|(w, _)| w) }); diff --git a/src/relay/tcprelay/stream.rs b/src/relay/tcprelay/stream.rs index d326b6dd..50711f62 100644 --- a/src/relay/tcprelay/stream.rs +++ b/src/relay/tcprelay/stream.rs @@ -22,6 +22,7 @@ #![allow(dead_code)] use std::io::{self, Read, BufRead, Write}; +use std::mem; use std::cmp; use std::time::Duration; @@ -30,7 +31,6 @@ use crypto::{Cipher, CipherVariant}; use futures::{Future, Poll, Async}; use tokio_core::reactor::{Handle, Timeout}; -use tokio_core::io::copy; use super::{BUFFER_SIZE, BoxIoFuture, boxed_future}; @@ -138,8 +138,6 @@ pub struct EncryptedWriter { writer: W, cipher: CipherVariant, - buf: Vec, - finalized: bool, } impl EncryptedWriter @@ -150,56 +148,51 @@ impl EncryptedWriter EncryptedWriter { writer: w, cipher: cipher, + } + } + + #[doc(hidden)] + pub fn cipher_update(&mut self, data: &[u8], buf: &mut Vec) -> io::Result<()> { + self.cipher.update(data, buf).map_err(From::from) + } + + #[doc(hidden)] + pub fn cipher_finalize(&mut self, buf: &mut Vec) -> io::Result<()> { + self.cipher.finalize(buf).map_err(From::from) + } + + /// write_all + pub fn write_all_encrypted>(self, buf: B) -> EncryptedWriteAll { + EncryptedWriteAll::Writing { + writer: self, + buf: buf, + pos: 0, + enc_buf: Vec::new(), + encrypted: false, + } + } + + /// Copy all data from reader + pub fn copy_from_encrypted(self, r: R) -> EncryptedCopy { + EncryptedCopy { + reader: r, + writer: self, + read_done: false, + amt: 0, + pos: 0, + cap: 0, buf: Vec::new(), - finalized: false, } } - fn cipher_finalize(&mut self) -> io::Result<()> { - if self.finalized { - return Ok(()); - } - - self.cipher - .finalize(&mut self.buf) - .and_then(|_| { - self.finalized = true; - Ok(()) - }) - .map_err(From::from) + /// Write raw bytes + pub fn write_raw(&mut self, data: &[u8]) -> io::Result { + self.writer.write(data) } - fn flush_buf(&mut self) -> io::Result { - let mut written = 0; - let expected_len = self.buf.len(); - let mut ret = Ok(()); - while written < expected_len { - match self.writer.write(&self.buf[written..]) { - Ok(0) => { - ret = Err(io::Error::new(io::ErrorKind::WriteZero, - "failed to write the buffered data")); - break; - } - Ok(n) => written += n, - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(e) => { - ret = Err(e); - break; - } - } - } - - if written > 0 { - self.buf.drain(..written); - } - - ret.map(|_| written) - } - - fn fill_buf(&mut self, data: &[u8]) -> io::Result<()> { - assert!(!self.finalized, "Called fill_buf after finalized!"); - - self.cipher.update(data, &mut self.buf).map_err(From::from) + /// Flush data + pub fn flush(&mut self) -> io::Result<()> { + self.writer.flush() } } @@ -212,41 +205,107 @@ impl EncryptedWriter { match timeout { Some(timeout) => boxed_future(EncryptedCopyTimeout::new(r, self, timeout, handle)), - None => boxed_future(copy(r, self)), + None => boxed_future(self.copy_from_encrypted(r)), } } } -impl Write for EncryptedWriter - where W: Write +/// write_all and encrypt data +pub enum EncryptedWriteAll + where W: Write, + B: AsRef<[u8]> { - fn write(&mut self, buf: &[u8]) -> io::Result { - if !self.buf.is_empty() { - self.flush_buf()?; + Writing { + writer: EncryptedWriter, + buf: B, + pos: usize, + enc_buf: Vec, + encrypted: bool, + }, + Empty, +} + +impl Future for EncryptedWriteAll + where W: Write, + B: AsRef<[u8]> +{ + type Item = (EncryptedWriter, B); + type Error = io::Error; + + fn poll(&mut self) -> Poll { + match *self { + EncryptedWriteAll::Empty => panic!("poll after EncryptedWriteAll finished"), + EncryptedWriteAll::Writing { ref mut writer, ref buf, ref mut pos, ref mut enc_buf, ref mut encrypted } => { + if !*encrypted { + *encrypted = true; + try!(writer.cipher_update(buf.as_ref(), enc_buf)); + } + + while *pos < enc_buf.len() { + let n = try_nb!(writer.write_raw(&enc_buf[*pos..])); + *pos += n; + if n == 0 { + let err = io::Error::new(io::ErrorKind::Other, "zero-length write"); + return Err(err); + } + } + } } - self.fill_buf(buf)?; - match self.flush_buf() { - Ok(..) => {} - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {} - Err(err) => return Err(err), + match mem::replace(self, EncryptedWriteAll::Empty) { + EncryptedWriteAll::Writing { writer, buf, .. } => Ok((writer, buf).into()), + EncryptedWriteAll::Empty => unreachable!(), } - - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { - self.flush_buf().and_then(|_| self.writer.flush()) } } -impl Drop for EncryptedWriter - where W: Write -{ - fn drop(&mut self) { - if let Ok(..) = self.cipher_finalize() { - // I don't care if it is failed to write - let _ = self.flush_buf(); +/// Encrypted copy +pub struct EncryptedCopy { + reader: R, + writer: EncryptedWriter, + read_done: bool, + amt: u64, + pos: usize, + cap: usize, + buf: Vec, +} + +impl Future for EncryptedCopy { + type Item = u64; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + let mut local_buf = [0u8; BUFFER_SIZE]; + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if self.pos == self.cap && !self.read_done { + let n = try_nb!(self.reader.read(&mut local_buf[..])); + self.buf.clear(); + if n == 0 { + self.read_done = true; + try!(self.writer.cipher_finalize(&mut self.buf)); + } else { + try!(self.writer.cipher_update(&local_buf[..n], &mut self.buf)); + } + self.pos = 0; + self.cap = self.buf.len(); + } + + // If our buffer has some data, let's write it out! + while self.pos < self.cap { + let i = try_nb!(self.writer.write_raw(&self.buf[self.pos..self.cap])); + self.pos += i; + self.amt += i as u64; + } + + // If we've written al the data and we've seen EOF, flush out the + // data and finish the transfer. + // done with the entire transfer. + if self.pos == self.cap && self.read_done { + try_nb!(self.writer.flush()); + return Ok(self.amt.into()); + } } } } @@ -263,6 +322,7 @@ pub struct EncryptedCopyTimeout { handle: Handle, timer: Option, read_buf: [u8; BUFFER_SIZE], + write_buf: Vec, } impl EncryptedCopyTimeout { @@ -278,6 +338,7 @@ impl EncryptedCopyTimeout { handle: handle, timer: None, read_buf: [0u8; BUFFER_SIZE], + write_buf: Vec::new(), } } @@ -305,9 +366,17 @@ impl EncryptedCopyTimeout { // Then, unset the previous timeout self.clear_timer(); + self.write_buf.clear(); match self.reader.read(&mut self.read_buf) { + Ok(0) => { + self.writer.cipher_finalize(&mut self.write_buf)?; + self.cap = self.write_buf.len(); + self.pos = 0; + Ok(0) + } Ok(n) => { - self.cap = n; + self.writer.cipher_update(&self.read_buf[..n], &mut self.write_buf)?; + self.cap = self.write_buf.len(); self.pos = 0; Ok(n) } @@ -327,7 +396,7 @@ impl EncryptedCopyTimeout { // Then, unset the previous timeout self.clear_timer(); - match self.writer.write(&self.read_buf[self.pos..self.cap]) { + match self.writer.write_raw(&self.write_buf[self.pos..self.cap]) { Ok(n) => { self.pos += n; Ok(n)