From 60705eff76b1f8af10eaf35205b7bafbb4149beb Mon Sep 17 00:00:00 2001 From: Jacob Payne Date: Wed, 5 Jul 2017 11:42:55 -0700 Subject: [PATCH] Use io::Read/Write interface instead of AsyncRead/AsyncWrite + Finish Tests (#74) --- p2p/src/rate_limit.rs | 101 +++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/p2p/src/rate_limit.rs b/p2p/src/rate_limit.rs index 0cf6ca7a3..95854351e 100644 --- a/p2p/src/rate_limit.rs +++ b/p2p/src/rate_limit.rs @@ -28,7 +28,7 @@ pub struct ThrottledReader { /// Max Bytes per second max: u32, /// Stores a count of last request and last request time - allowed: isize, + allowed: usize, last_check: Instant, } @@ -39,7 +39,7 @@ impl ThrottledReader { ThrottledReader { reader: reader, max: max, - allowed: max as isize, + allowed: max as usize, last_check: Instant::now(), } } @@ -65,43 +65,35 @@ impl ThrottledReader { impl io::Read for ThrottledReader { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.reader.read(buf) - } -} - -impl AsyncRead for ThrottledReader { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.reader.prepare_uninitialized_buffer(buf) - } - - fn read_buf(&mut self, buf: &mut B) -> Poll { // Check passed Time let time_passed = self.last_check.elapsed(); self.last_check = Instant::now(); - self.allowed += time_passed.as_secs() as isize * self.max as isize; + self.allowed += time_passed.as_secs() as usize * self.max as usize; // Throttle - if self.allowed > self.max as isize { - self.allowed = self.max as isize; + if self.allowed > self.max as usize { + self.allowed = self.max as usize; } // Check if Allowed if self.allowed < 1 { - return Ok(Async::NotReady); + return Err(io::Error::new(io::ErrorKind::WouldBlock, "Reached Allowed Read Limit")) } - // Since we can't limit the scope that is read, - // we use a signed `allowed` counter in case n > allowed - let res = self.reader.read_buf(buf); + // Read Max Allowed + let buf = if buf.len() > self.allowed { &mut buf[0..self.allowed]} else { buf }; + let res = self.reader.read(buf); // Decrement Allowed amount written - if let Ok(Async::Ready(n)) = res { - self.allowed = self.allowed.saturating_sub(n as isize); + if let Ok(n) = res { + self.allowed -= n; } res } } +impl AsyncRead for ThrottledReader { } + /// A Rate Limited Writer #[derive(Debug)] pub struct ThrottledWriter { @@ -146,22 +138,6 @@ impl ThrottledWriter { impl io::Write for ThrottledWriter { fn write(&mut self, buf: &[u8]) -> io::Result { - self.writer.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.writer.flush() - } -} - -impl AsyncWrite for ThrottledWriter { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.writer.shutdown() - } - - fn write_buf(&mut self, buf: &mut B) -> Poll - where Self: Sized - { // Check passed Time let time_passed = self.last_check.elapsed(); self.last_check = Instant::now(); @@ -174,17 +150,60 @@ impl AsyncWrite for ThrottledWriter { // Check if Allowed if self.allowed < 1 { - return Ok(Async::NotReady); + return Err(io::Error::new(io::ErrorKind::WouldBlock, "Reached Allowed Write Limit")) } // Write max allowed - let ref mut lbuf = buf.by_ref().take(self.allowed); - let res = self.writer.write_buf(lbuf); + let buf = if buf.len() > self.allowed { &buf[0..self.allowed]} else { buf }; + let res = self.writer.write(buf); // Decrement Allowed amount written - if let Ok(Async::Ready(n)) = res { + if let Ok(n) = res { self.allowed -= n; } res } + + fn flush(&mut self) -> io::Result<()> { + self.writer.flush() + } } + +impl AsyncWrite for ThrottledWriter { + fn shutdown(&mut self) -> Poll<(), io::Error> { + self.writer.shutdown() + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::io::Cursor; + + #[test] + fn should_throttle_write() { + let buf = vec![0; 64]; + let mut t_buf = ThrottledWriter::new(Cursor::new(buf), 8); + + for _ in 0..16 { + let _ = t_buf.write_buf(&mut Cursor::new(vec![1; 8])); + } + + let cursor = t_buf.into_inner(); + assert_eq!(cursor.position(), 8); + } + + #[test] + fn should_throttle_read() { + let buf = vec![1; 64]; + let mut t_buf = ThrottledReader::new(Cursor::new(buf), 8); + + let mut dst = Cursor::new(vec![0; 64]); + + for _ in 0..16 { + let _ = t_buf.read_buf(&mut dst); + } + + assert_eq!(dst.position(), 8); + } +} \ No newline at end of file