mirror of
https://github.com/shadowsocks/shadowsocks-rust.git
synced 2026-02-09 01:59:16 +08:00
fixed stack overflow, separate impl copy
This commit is contained in:
@@ -36,11 +36,11 @@ use config::{ServerConfig, ServerAddr};
|
||||
|
||||
use tokio_core::net::TcpStream;
|
||||
use tokio_core::reactor::{Handle, Timeout};
|
||||
use tokio_core::io::{read_exact, write_all, read};
|
||||
use tokio_core::io::{read_exact, write_all, copy};
|
||||
use tokio_core::io::{ReadHalf, WriteHalf};
|
||||
use tokio_core::io::Io;
|
||||
|
||||
use futures::{self, Future, Poll};
|
||||
use futures::{self, Future, Poll, Async};
|
||||
|
||||
use net2::TcpBuilder;
|
||||
|
||||
@@ -305,20 +305,140 @@ fn io_timeout<T, F>(fut: F, dur: Duration, handle: &Handle) -> BoxIoFuture<T>
|
||||
boxed_future(fut)
|
||||
}
|
||||
|
||||
pub struct CopyTimeout<R, W>
|
||||
where R: Read + 'static,
|
||||
W: Write + 'static
|
||||
{
|
||||
r: R,
|
||||
w: W,
|
||||
timeout: Duration,
|
||||
handle: Handle,
|
||||
amt: u64,
|
||||
timer: Option<Timeout>,
|
||||
buf: [u8; BUFFER_SIZE],
|
||||
pos: usize,
|
||||
cap: usize,
|
||||
}
|
||||
|
||||
impl<R, W> CopyTimeout<R, W>
|
||||
where R: Read + 'static,
|
||||
W: Write + 'static
|
||||
{
|
||||
fn new(r: R, w: W, timeout: Duration, handle: Handle) -> CopyTimeout<R, W> {
|
||||
CopyTimeout {
|
||||
r: r,
|
||||
w: w,
|
||||
timeout: timeout,
|
||||
handle: handle,
|
||||
amt: 0,
|
||||
timer: None,
|
||||
buf: [0u8; BUFFER_SIZE],
|
||||
pos: 0,
|
||||
cap: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_poll_timeout(&mut self) -> io::Result<()> {
|
||||
match self.timer.as_mut() {
|
||||
None => Ok(()),
|
||||
Some(t) => {
|
||||
match t.poll() {
|
||||
Err(err) => Err(err),
|
||||
Ok(Async::Ready(..)) => Err(io::Error::new(io::ErrorKind::TimedOut, "timeout")),
|
||||
Ok(Async::NotReady) => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_timer(&mut self) {
|
||||
let _ = self.timer.take();
|
||||
}
|
||||
|
||||
fn read_or_set_timeout(&mut self) -> io::Result<usize> {
|
||||
// First, return if timeout
|
||||
try!(self.try_poll_timeout());
|
||||
|
||||
// Then, unset the previous timeout
|
||||
self.clear_timer();
|
||||
|
||||
match self.r.read(&mut self.buf) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
self.timer = Some(Timeout::new(self.timeout, &self.handle).unwrap());
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn write_or_set_timeout(&mut self, beg: usize, end: usize) -> io::Result<usize> {
|
||||
// First, return if timeout
|
||||
try!(self.try_poll_timeout());
|
||||
|
||||
// Then, unset the previous timeout
|
||||
self.clear_timer();
|
||||
|
||||
match self.w.write(&self.buf[beg..end]) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
self.timer = Some(Timeout::new(self.timeout, &self.handle).unwrap());
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, W> Future for CopyTimeout<R, W>
|
||||
where R: Read + 'static,
|
||||
W: Write + 'static
|
||||
{
|
||||
type Item = u64;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
loop {
|
||||
if self.pos == self.cap {
|
||||
let n = try_nb!(self.read_or_set_timeout());
|
||||
|
||||
if n == 0 {
|
||||
// If we've written al the data and we've seen EOF, flush out the
|
||||
// data and finish the transfer.
|
||||
// done with the entire transfer.
|
||||
try_nb!(self.w.flush());
|
||||
return Ok(self.amt.into());
|
||||
}
|
||||
|
||||
self.pos = 0;
|
||||
self.cap = n;
|
||||
|
||||
// Clear it before write
|
||||
self.clear_timer();
|
||||
}
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let (pos, cap) = (self.pos, self.cap);
|
||||
let i = try_nb!(self.write_or_set_timeout(pos, cap));
|
||||
self.pos += i;
|
||||
self.amt += i as u64;
|
||||
}
|
||||
|
||||
// Clear it before read
|
||||
self.clear_timer();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_timeout<R, W>(r: R, w: W, timeout: Option<Duration>, handle: Handle) -> BoxIoFuture<u64>
|
||||
where R: Read + 'static,
|
||||
W: Write + 'static
|
||||
{
|
||||
let fut = try_timeout(read(r, vec![0u8; BUFFER_SIZE]), timeout.clone(), &handle)
|
||||
.and_then(move |(r, mut buf, n)| {
|
||||
if n == 0 {
|
||||
boxed_future(futures::finished(n as u64))
|
||||
} else {
|
||||
buf.resize(n, 0);
|
||||
let fut = try_timeout(write_all(w, buf), timeout.clone(), &handle)
|
||||
.and_then(move |(w, _)| copy_timeout(r, w, timeout, handle).map(move |x| x + n as u64));
|
||||
boxed_future(fut)
|
||||
}
|
||||
});
|
||||
boxed_future(fut)
|
||||
match timeout {
|
||||
None => boxed_future(copy(r, w)),
|
||||
Some(timeout) => boxed_future(CopyTimeout::new(r, w, timeout, handle)),
|
||||
}
|
||||
}
|
||||
@@ -28,12 +28,11 @@ use std::time::Duration;
|
||||
|
||||
use crypto::{Cipher, CipherVariant};
|
||||
|
||||
use futures::{self, Future, Poll};
|
||||
use futures::{Future, Poll, Async};
|
||||
|
||||
use tokio_core::io::{read, write_all};
|
||||
use tokio_core::reactor::Handle;
|
||||
use tokio_core::reactor::{Handle, Timeout};
|
||||
|
||||
use super::{BUFFER_SIZE, BoxIoFuture, try_timeout, boxed_future};
|
||||
use super::{BUFFER_SIZE, BoxIoFuture, boxed_future};
|
||||
|
||||
/// Reader wrapper that will decrypt data automatically
|
||||
pub struct DecryptedReader<R>
|
||||
@@ -202,32 +201,13 @@ impl<W> EncryptedWriter<W>
|
||||
}
|
||||
|
||||
/// Copy all data from reader with timeout
|
||||
pub fn copy_from_encrypted_timeout<R>(mut self, r: R, timeout: Option<Duration>, handle: Handle) -> BoxIoFuture<u64>
|
||||
where R: Read + 'static
|
||||
pub fn copy_from_encrypted_timeout<R>(self, r: R, timeout: Option<Duration>, handle: Handle) -> BoxIoFuture<u64>
|
||||
where R: Read + Send + 'static
|
||||
{
|
||||
let buf = try_timeout(read(r, vec![0u8; BUFFER_SIZE]), timeout.clone(), &handle)
|
||||
.and_then(move |(r, buf, n)| {
|
||||
let mut enc_buf = Vec::new();
|
||||
let cont = if n == 0 {
|
||||
try!(self.cipher_finalize(&mut enc_buf));
|
||||
false
|
||||
} else {
|
||||
try!(self.cipher_update(&buf[..n], &mut enc_buf));
|
||||
true
|
||||
};
|
||||
|
||||
Ok((r, self, enc_buf, n, cont))
|
||||
})
|
||||
.and_then(move |(r, w, buf, n, cont)| {
|
||||
write_all(w, buf).and_then(move |(w, _)| {
|
||||
if cont {
|
||||
boxed_future(w.copy_from_encrypted_timeout(r, timeout, handle).map(move |x| x + n as u64))
|
||||
} else {
|
||||
boxed_future(futures::finished(n as u64))
|
||||
}
|
||||
})
|
||||
});
|
||||
boxed_future(buf)
|
||||
match timeout {
|
||||
Some(timeout) => boxed_future(EncryptedCopyTimeout::new(r, self, timeout, handle)),
|
||||
None => boxed_future(self.copy_from_encrypted(r)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -293,7 +273,7 @@ impl<W, B> Future for EncryptedWriteAll<W, B>
|
||||
}
|
||||
|
||||
/// Encrypted copy
|
||||
pub struct EncryptedCopy<R: Read, W: Write + 'static> {
|
||||
pub struct EncryptedCopy<R: Read + 'static, W: Write + 'static> {
|
||||
reader: R,
|
||||
writer: EncryptedWriter<W>,
|
||||
read_done: bool,
|
||||
@@ -303,7 +283,7 @@ pub struct EncryptedCopy<R: Read, W: Write + 'static> {
|
||||
buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<R: Read, W: Write> Future for EncryptedCopy<R, W> {
|
||||
impl<R: Read + 'static, W: Write + 'static> Future for EncryptedCopy<R, W> {
|
||||
type Item = u64;
|
||||
type Error = io::Error;
|
||||
|
||||
@@ -332,6 +312,133 @@ impl<R: Read, W: Write> Future for EncryptedCopy<R, W> {
|
||||
self.amt += i as u64;
|
||||
}
|
||||
|
||||
// If we've written al the data and we've seen EOF, flush out the
|
||||
// data and finish the transfer.
|
||||
// done with the entire transfer.
|
||||
if self.pos == self.cap && self.read_done {
|
||||
try_nb!(self.writer.flush());
|
||||
return Ok(self.amt.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypted copy
|
||||
pub struct EncryptedCopyTimeout<R: Read + 'static, W: Write + 'static> {
|
||||
reader: R,
|
||||
writer: EncryptedWriter<W>,
|
||||
read_done: bool,
|
||||
amt: u64,
|
||||
pos: usize,
|
||||
cap: usize,
|
||||
write_buf: Vec<u8>,
|
||||
timeout: Duration,
|
||||
handle: Handle,
|
||||
timer: Option<Timeout>,
|
||||
read_buf: [u8; BUFFER_SIZE],
|
||||
}
|
||||
|
||||
impl<R: Read + 'static, W: Write + 'static> EncryptedCopyTimeout<R, W> {
|
||||
fn new(r: R, w: EncryptedWriter<W>, dur: Duration, handle: Handle) -> EncryptedCopyTimeout<R, W> {
|
||||
EncryptedCopyTimeout {
|
||||
reader: r,
|
||||
writer: w,
|
||||
read_done: false,
|
||||
amt: 0,
|
||||
pos: 0,
|
||||
cap: 0,
|
||||
write_buf: Vec::new(),
|
||||
timeout: dur,
|
||||
handle: handle,
|
||||
timer: None,
|
||||
read_buf: [0u8; BUFFER_SIZE],
|
||||
}
|
||||
}
|
||||
|
||||
fn try_poll_timeout(&mut self) -> io::Result<()> {
|
||||
match self.timer.as_mut() {
|
||||
None => Ok(()),
|
||||
Some(t) => {
|
||||
match t.poll() {
|
||||
Err(err) => Err(err),
|
||||
Ok(Async::Ready(..)) => Err(io::Error::new(io::ErrorKind::TimedOut, "timeout")),
|
||||
Ok(Async::NotReady) => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_timer(&mut self) {
|
||||
let _ = self.timer.take();
|
||||
}
|
||||
|
||||
fn read_or_set_timeout(&mut self) -> io::Result<usize> {
|
||||
// First, return if timeout
|
||||
try!(self.try_poll_timeout());
|
||||
|
||||
// Then, unset the previous timeout
|
||||
self.clear_timer();
|
||||
|
||||
match self.reader.read(&mut self.read_buf) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
self.timer = Some(Timeout::new(self.timeout, &self.handle).unwrap());
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn write_or_set_timeout(&mut self, beg: usize, end: usize) -> io::Result<usize> {
|
||||
// First, return if timeout
|
||||
try!(self.try_poll_timeout());
|
||||
|
||||
// Then, unset the previous timeout
|
||||
self.clear_timer();
|
||||
|
||||
match self.writer.write(&self.write_buf[beg..end]) {
|
||||
Ok(n) => Ok(n),
|
||||
Err(e) => {
|
||||
if e.kind() == io::ErrorKind::WouldBlock {
|
||||
self.timer = Some(Timeout::new(self.timeout, &self.handle).unwrap());
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Read, W: Write> Future for EncryptedCopyTimeout<R, W> {
|
||||
type Item = u64;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
|
||||
loop {
|
||||
// If our buffer is empty, then we need to read some data to
|
||||
// continue.
|
||||
if self.pos == self.cap && !self.read_done {
|
||||
let n = try_nb!(self.read_or_set_timeout());
|
||||
self.write_buf.clear();
|
||||
if n == 0 {
|
||||
self.read_done = true;
|
||||
try!(self.writer.cipher_finalize(&mut self.write_buf));
|
||||
} else {
|
||||
try!(self.writer.cipher_update(&self.read_buf[..n], &mut self.write_buf));
|
||||
}
|
||||
self.pos = 0;
|
||||
self.cap = self.write_buf.len();
|
||||
}
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let (pos, cap) = (self.pos, self.cap);
|
||||
let i = try_nb!(self.write_or_set_timeout(pos, cap));
|
||||
self.pos += i;
|
||||
self.amt += i as u64;
|
||||
}
|
||||
|
||||
// If we've written al the data and we've seen EOF, flush out the
|
||||
// data and finish the transfer.
|
||||
// done with the entire transfer.
|
||||
|
||||
Reference in New Issue
Block a user