From a5ab3b39743444ad57195d45fa019e26b1884d31 Mon Sep 17 00:00:00 2001 From: zonyitoo Date: Sun, 11 Oct 2020 16:25:05 +0800 Subject: [PATCH] EncryptWriter keeps buffer instead of allocating one everytime --- src/relay/tcprelay/aead.rs | 44 ++++++++++++++---------------- src/relay/tcprelay/crypto_io.rs | 4 +-- src/relay/tcprelay/stream.rs | 48 +++++++++++---------------------- 3 files changed, 37 insertions(+), 59 deletions(-) diff --git a/src/relay/tcprelay/aead.rs b/src/relay/tcprelay/aead.rs index 1967adc7..7e7c9f50 100644 --- a/src/relay/tcprelay/aead.rs +++ b/src/relay/tcprelay/aead.rs @@ -43,7 +43,7 @@ use std::{ }; use byteorder::{BigEndian, ByteOrder}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use futures::ready; use tokio::prelude::*; @@ -224,7 +224,7 @@ impl DecryptedReader { enum EncryptWriteStep { Nothing, - Writing(BytesMut), + Writing, } /// Writer wrapper that will encrypt data automatically @@ -232,17 +232,21 @@ pub struct EncryptedWriter { cipher: BoxAeadEncryptor, tag_size: usize, steps: EncryptWriteStep, - nonce: Option, + buf: BytesMut, } impl EncryptedWriter { /// Creates a new EncryptedWriter - pub fn new(t: CipherType, key: &[u8], nonce: Bytes) -> EncryptedWriter { + pub fn new(t: CipherType, key: &[u8], nonce: &[u8]) -> EncryptedWriter { + // nonce should be sent with the first packet + let mut buf = BytesMut::with_capacity(nonce.len()); + buf.put(nonce); + EncryptedWriter { - cipher: crypto::new_aead_encryptor(t, key, &nonce), + cipher: crypto::new_aead_encryptor(t, key, nonce), tag_size: t.tag_size(), steps: EncryptWriteStep::Nothing, - nonce: Some(nonce), + buf, } } @@ -279,43 +283,35 @@ impl EncryptedWriter { let output_length = self.buffer_size(data); let data_length = data.len() as u16; - // Send the first packet with nonce - let nonce_length = match self.nonce { - Some(ref n) => n.len(), - None => 0, - }; - - let mut buf = BytesMut::with_capacity(nonce_length + output_length); - - // Put nonce first - if let Some(n) = self.nonce.take() { - buf.extend(n); - } + self.buf.reserve(output_length); let mut data_len_buf = [0u8; 2]; BigEndian::write_u16(&mut data_len_buf, data_length); unsafe { - let b = slice::from_raw_parts_mut(buf.bytes_mut().as_mut_ptr() as *mut u8, output_length); + let b = slice::from_raw_parts_mut(self.buf.bytes_mut().as_mut_ptr() as *mut u8, output_length); let output_length_size = 2 + self.tag_size; self.cipher.encrypt(&data_len_buf, &mut b[..output_length_size]); self.cipher.encrypt(data, &mut b[output_length_size..output_length]); - buf.advance_mut(output_length); + self.buf.advance_mut(output_length); } - self.steps = EncryptWriteStep::Writing(buf); + self.steps = EncryptWriteStep::Writing; } - EncryptWriteStep::Writing(ref mut buf) => { - while buf.remaining() > 0 { - let n = ready!(Pin::new(&mut *w).poll_write_buf(ctx, buf))?; + EncryptWriteStep::Writing => { + while self.buf.remaining() > 0 { + let n = ready!(Pin::new(&mut *w).poll_write_buf(ctx, &mut self.buf))?; if n == 0 { use std::io::ErrorKind; return Poll::Ready(Err(ErrorKind::UnexpectedEof.into())); } } + // Reclaim buffer + // NOTE: This operation won't free allocated memory + self.buf.clear(); self.steps = EncryptWriteStep::Nothing; return Poll::Ready(Ok(())); } diff --git a/src/relay/tcprelay/crypto_io.rs b/src/relay/tcprelay/crypto_io.rs index 191f8d06..fa39b5c5 100644 --- a/src/relay/tcprelay/crypto_io.rs +++ b/src/relay/tcprelay/crypto_io.rs @@ -109,8 +109,8 @@ impl CryptoStream { let method = svr_cfg.method(); let enc = match method.category() { - CipherCategory::Stream => EncryptedWriter::Stream(StreamEncryptedWriter::new(method, svr_cfg.key(), iv)), - CipherCategory::Aead => EncryptedWriter::Aead(AeadEncryptedWriter::new(method, svr_cfg.key(), iv)), + CipherCategory::Stream => EncryptedWriter::Stream(StreamEncryptedWriter::new(method, svr_cfg.key(), &iv)), + CipherCategory::Aead => EncryptedWriter::Aead(AeadEncryptedWriter::new(method, svr_cfg.key(), &iv)), CipherCategory::None => EncryptedWriter::None, }; diff --git a/src/relay/tcprelay/stream.rs b/src/relay/tcprelay/stream.rs index b1c84b05..7d0cc950 100644 --- a/src/relay/tcprelay/stream.rs +++ b/src/relay/tcprelay/stream.rs @@ -9,7 +9,7 @@ use std::{ }; use crate::crypto::{new_stream, BoxStreamCipher, CipherType, CryptoMode}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use futures::ready; use tokio::prelude::*; @@ -87,23 +87,27 @@ impl DecryptedReader { enum EncryptWriteStep { Nothing, - Writing(BytesMut), + Writing, } /// Writer wrapper that will encrypt data automatically pub struct EncryptedWriter { cipher: BoxStreamCipher, steps: EncryptWriteStep, - iv: Option, + buf: BytesMut, } impl EncryptedWriter { /// Creates a new EncryptedWriter - pub fn new(t: CipherType, key: &[u8], iv: Bytes) -> EncryptedWriter { + pub fn new(t: CipherType, key: &[u8], iv: &[u8]) -> EncryptedWriter { + // iv should be sent with the first packet + let mut buf = BytesMut::with_capacity(iv.len()); + buf.put(iv); + EncryptedWriter { cipher: new_stream(t, key, &iv, CryptoMode::Encrypt), steps: EncryptWriteStep::Nothing, - iv: Some(iv), + buf, } } @@ -124,26 +128,13 @@ impl EncryptedWriter { loop { match self.steps { EncryptWriteStep::Nothing => { - // Send the first packet with iv - let iv_length = match self.iv { - Some(ref i) => i.len(), - None => 0, - }; - - let mut buf = BytesMut::with_capacity(iv_length + self.buffer_size(data)); - - // Put iv first - if let Some(i) = self.iv.take() { - buf.extend(i); - } - - self.cipher_update(data, &mut buf)?; - - self.steps = EncryptWriteStep::Writing(buf); + self.buf.reserve(self.buffer_size(data)); + self.cipher.update(data, &mut self.buf)?; + self.steps = EncryptWriteStep::Writing; } - EncryptWriteStep::Writing(ref mut buf) => { - while buf.remaining() > 0 { - let n = ready!(Pin::new(&mut *w).poll_write_buf(ctx, buf))?; + EncryptWriteStep::Writing => { + while self.buf.remaining() > 0 { + let n = ready!(Pin::new(&mut *w).poll_write_buf(ctx, &mut self.buf))?; if n == 0 { use std::io::ErrorKind; return Poll::Ready(Err(ErrorKind::UnexpectedEof.into())); @@ -157,15 +148,6 @@ impl EncryptedWriter { } } - fn cipher_update(&mut self, data: &[u8], buf: &mut B) -> io::Result<()> { - self.cipher.update(data, buf).map_err(From::from) - } - - #[allow(dead_code)] - fn cipher_finalize(&mut self, buf: &mut B) -> io::Result<()> { - self.cipher.finalize(buf).map_err(From::from) - } - fn buffer_size(&self, data: &[u8]) -> usize { self.cipher.buffer_size(data) }