[#28] Add customized write_all and copy fror EncryptedWriter

This commit is contained in:
Y. T. Chung
2016-10-27 01:47:06 +08:00
parent ea4c1878b7
commit c2d9ca53c0
4 changed files with 143 additions and 36 deletions

View File

@@ -97,7 +97,7 @@ impl Socks5RelayLocal {
.and_then(move |(svr_s, w)| {
super::proxy_server_handshake(svr_s, cloned_svr_cfg, addr).and_then(move |(svr_r, svr_w)| {
let rhalf = svr_r.and_then(move |svr_r| copy(svr_r, w));
let whalf = svr_w.and_then(move |svr_w| copy(r, svr_w));
let whalf = svr_w.and_then(move |svr_w| svr_w.copy_from(r));
tunnel(cloned_addr, whalf, rhalf)
})

View File

@@ -95,7 +95,8 @@ pub fn proxy_server_handshake(remote_stream: TcpStream,
relay_addr);
// Send relay address to remote
relay_addr.write_to(enc_w).and_then(flush)
let local_buf = Vec::new();
relay_addr.write_to(local_buf).and_then(|buf| enc_w.write_all(buf)).and_then(|(enc_w, _)| flush(enc_w))
})
.boxed();

View File

@@ -125,7 +125,7 @@ impl TcpRelayServer {
let (svr_r, svr_w) = svr_s.split();
tunnel(cloned_addr,
copy(r, svr_w),
w_fut.and_then(|w| copy(svr_r, w)))
w_fut.and_then(|w| w.copy_from(svr_r)))
})
});

View File

@@ -23,12 +23,15 @@
use std::io::{self, Read, BufRead, Write};
use std::cmp;
use std::mem;
use crypto::{Cipher, CipherVariant};
use futures::{Future, Poll};
/// Reader wrapper that will decrypt data automatically
pub struct DecryptedReader<R>
where R: Read + 'static
where R: Read
{
reader: R,
buffer: Vec<u8>,
@@ -40,7 +43,7 @@ pub struct DecryptedReader<R>
const BUFFER_SIZE: usize = 2048;
impl<R> DecryptedReader<R>
where R: Read + 'static
where R: Read
{
pub fn new(r: R, cipher: CipherVariant) -> DecryptedReader<R> {
DecryptedReader {
@@ -76,7 +79,7 @@ impl<R> DecryptedReader<R>
}
impl<R> BufRead for DecryptedReader<R>
where R: Read + 'static
where R: Read
{
fn fill_buf<'b>(&'b mut self) -> io::Result<&'b [u8]> {
while self.pos >= self.buffer.len() {
@@ -114,7 +117,7 @@ impl<R> BufRead for DecryptedReader<R>
}
impl<R> Read for DecryptedReader<R>
where R: Read + 'static
where R: Read
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let nread = {
@@ -128,34 +131,20 @@ impl<R> Read for DecryptedReader<R>
/// Writer wrapper that will encrypt data automatically
pub struct EncryptedWriter<W>
where W: Write + 'static
where W: Write
{
writer: W,
cipher: CipherVariant,
buffer: Vec<u8>,
}
impl<W> EncryptedWriter<W>
where W: Write + 'static
where W: Write
{
/// Creates a new EncryptedWriter
pub fn new(w: W, cipher: CipherVariant) -> EncryptedWriter<W> {
EncryptedWriter {
writer: w,
cipher: cipher,
buffer: Vec::new(),
}
}
/// Finalize the cipher, which will writes the final block into buffer
pub fn finalize(&mut self) -> io::Result<()> {
match self.cipher.finalize(&mut self.buffer) {
Ok(..) => {
self.writer
.write_all(&self.buffer[..])
.and_then(|_| self.writer.flush())
}
Err(err) => Err(From::from(err)),
}
}
@@ -173,20 +162,45 @@ impl<W> EncryptedWriter<W>
pub fn get_mut(&mut self) -> &mut W {
&mut self.writer
}
fn cipher_update(&mut self, buf: &[u8], out: &mut Vec<u8>) -> io::Result<()> {
self.cipher.update(buf, out).map_err(From::from)
}
fn cipher_finalize(&mut self, out: &mut Vec<u8>) -> io::Result<()> {
self.cipher.finalize(out).map_err(From::from)
}
/// write_all
pub fn write_all<B: AsRef<[u8]>>(self, buf: B) -> EncryptedWriteAll<W, B> {
EncryptedWriteAll::Writing {
writer: self,
buf: buf,
pos: 0,
enc_buf: Vec::new(),
encrypted: false,
}
}
/// Copy all data from reader
pub fn copy_from<R: Read>(self, r: R) -> EncryptedCopy<R, W> {
EncryptedCopy {
reader: r,
writer: self,
read_done: false,
amt: 0,
pos: 0,
cap: 0,
buf: Vec::new(),
}
}
}
impl<W> Write for EncryptedWriter<W>
where W: Write + 'static
where W: Write
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.cipher.update(buf, &mut self.buffer) {
Ok(..) => {
let len = try!(self.writer.write(&self.buffer[..]));
self.buffer.drain(..len);
Ok(len)
}
Err(err) => Err(From::from(err)),
}
self.writer.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
@@ -194,10 +208,102 @@ impl<W> Write for EncryptedWriter<W>
}
}
impl<W> Drop for EncryptedWriter<W>
where W: Write + 'static
/// write_all and encrypt data
pub enum EncryptedWriteAll<W, B>
where W: Write,
B: AsRef<[u8]>
{
fn drop(&mut self) {
let _ = self.finalize();
Writing {
writer: EncryptedWriter<W>,
buf: B,
pos: usize,
enc_buf: Vec<u8>,
encrypted: bool,
},
Empty,
}
impl<W, B> Future for EncryptedWriteAll<W, B>
where W: Write,
B: AsRef<[u8]>
{
type Item = (EncryptedWriter<W>, B);
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match *self {
EncryptedWriteAll::Empty => panic!("poll after EncryptedWriteAll finished"),
EncryptedWriteAll::Writing { ref mut writer, ref buf, ref mut pos, ref mut enc_buf, ref mut encrypted } => {
if !*encrypted {
*encrypted = true;
try!(writer.cipher_update(buf.as_ref(), enc_buf));
}
while *pos < enc_buf.len() {
let n = try_nb!(writer.write(&enc_buf[*pos..]));
*pos += n;
if n == 0 {
let err = io::Error::new(io::ErrorKind::Other, "zero-length write");
return Err(err);
}
}
}
}
match mem::replace(self, EncryptedWriteAll::Empty) {
EncryptedWriteAll::Writing { writer, buf, .. } => Ok((writer, buf).into()),
EncryptedWriteAll::Empty => unreachable!(),
}
}
}
/// Encrypted copy
pub struct EncryptedCopy<R: Read, W: Write> {
reader: R,
writer: EncryptedWriter<W>,
read_done: bool,
amt: u64,
pos: usize,
cap: usize,
buf: Vec<u8>,
}
impl<R: Read, W: Write> Future for EncryptedCopy<R, W> {
type Item = u64;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut local_buf = [0u8; 2048];
loop {
// If our buffer is empty, then we need to read some data to
// continue.
if self.pos == self.cap && !self.read_done {
let n = try_nb!(self.reader.read(&mut local_buf[..]));
if n == 0 {
self.read_done = true;
} else {
self.pos = 0;
self.buf.clear();
try!(self.writer.cipher_update(&local_buf[..n], &mut self.buf));
self.cap = self.buf.len();
}
}
// If our buffer has some data, let's write it out!
while self.pos < self.cap {
let i = try_nb!(self.writer.write(&self.buf[self.pos..self.cap]));
self.pos += i;
self.amt += i as u64;
}
// If we've written al the data and we've seen EOF, flush out the
// data and finish the transfer.
// done with the entire transfer.
if self.pos == self.cap && self.read_done {
try_nb!(self.writer.flush());
return Ok(self.amt.into());
}
}
}
}