diff --git a/core/src/core/block.rs b/core/src/core/block.rs index 4ec8b2197..b8863df0b 100644 --- a/core/src/core/block.rs +++ b/core/src/core/block.rs @@ -34,7 +34,7 @@ use core::{ use global; use keychain::{self, BlindingFactor}; use pow::{Difficulty, Proof, ProofOfWork}; -use ser::{self, FixedLength, HashOnlyPMMRable, Readable, Reader, Writeable, Writer}; +use ser::{self, HashOnlyPMMRable, Readable, Reader, Writeable, Writer}; use util::{secp, static_secp_instance}; /// Errors thrown by Block validation @@ -140,34 +140,6 @@ pub struct BlockHeader { pub pow: ProofOfWork, } -const FIXED_HEADER_SIZE: usize = 2 // version - + 8 // height - + 8 // timestamp - + 5 * Hash::LEN // prev_hash, prev_root, output_root, range_proof_root, kernel_root - + BlindingFactor::LEN // total_kernel_offset - + 2 * 8 // output_mmr_size, kernel_mmr_size - + Difficulty::LEN // total_difficulty - + 4 // secondary_scaling - + 8; // nonce - -/// Serialized size of fixed part of a BlockHeader, i.e. without pow -fn fixed_size_of_serialized_header(_version: u16) -> usize { - FIXED_HEADER_SIZE -} - -/// Serialized size of a BlockHeader -pub fn serialized_size_of_header(version: u16, edge_bits: u8) -> usize { - let mut size = fixed_size_of_serialized_header(version); - - size += 1; // pow.edge_bits - let bitvec_len = global::proofsize() * edge_bits as usize; - size += bitvec_len / 8; // pow.nonces - if bitvec_len % 8 != 0 { - size += 1; - } - size -} - impl Default for BlockHeader { fn default() -> BlockHeader { BlockHeader { @@ -292,20 +264,6 @@ impl BlockHeader { pub fn total_kernel_offset(&self) -> BlindingFactor { self.total_kernel_offset } - - /// Serialized size of this header - pub fn serialized_size(&self) -> usize { - let mut size = fixed_size_of_serialized_header(self.version); - - size += 1; // pow.edge_bits - let nonce_bits = self.pow.edge_bits() as usize; - let bitvec_len = global::proofsize() * nonce_bits; - size += bitvec_len / 8; // pow.nonces - if bitvec_len % 8 != 0 { - size += 1; - } - size - } } /// A block as expressed in the MimbleWimble protocol. The reward is diff --git a/core/src/ser.rs b/core/src/ser.rs index 34e07d271..b6f9396b7 100644 --- a/core/src/ser.rs +++ b/core/src/ser.rs @@ -19,6 +19,8 @@ //! To use it simply implement `Writeable` or `Readable` and then use the //! `serialize` or `deserialize` functions on them as appropriate. +use std::time::Duration; + use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use consensus; use core::hash::{Hash, Hashed}; @@ -27,6 +29,7 @@ use std::fmt::Debug; use std::io::{self, Read, Write}; use std::marker; use std::{cmp, error, fmt}; +use util::read_write::read_exact; use util::secp::constants::{ AGG_SIGNATURE_SIZE, MAX_PROOF_SIZE, PEDERSEN_COMMITMENT_SIZE, SECRET_KEY_SIZE, }; @@ -206,7 +209,8 @@ pub trait Writeable { fn write(&self, writer: &mut W) -> Result<(), Error>; } -struct IteratingReader<'a, T> { +/// Reader that exposes an Iterator interface. +pub struct IteratingReader<'a, T> { count: u64, curr: u64, reader: &'a mut Reader, @@ -214,7 +218,9 @@ struct IteratingReader<'a, T> { } impl<'a, T> IteratingReader<'a, T> { - fn new(reader: &'a mut Reader, count: u64) -> IteratingReader<'a, T> { + /// Constructor to create a new iterating reader for the provided underlying reader. + /// Takes a count so we know how many to iterate over. 
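With `IteratingReader` now public, code outside `ser` can stream a count-prefixed collection without buffering it all first. A minimal sketch of such a helper, assuming the existing `Iterator` impl (unchanged by this diff) yields deserialized items of type `T: Readable`; the helper name and the use of `Error::CorruptedData` for a short read are illustrative, not part of this diff:

fn read_multi<T>(reader: &mut Reader, count: u64) -> Result<Vec<T>, Error>
where
	T: Readable,
{
	// The iterator stops early if any item fails to deserialize, so a
	// length check tells us whether the full count was actually read.
	let res: Vec<T> = IteratingReader::new(reader, count).collect();
	if res.len() as u64 != count {
		return Err(Error::CorruptedData);
	}
	Ok(res)
}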
+ pub fn new(reader: &'a mut Reader, count: u64) -> IteratingReader<'a, T> { let curr = 0; IteratingReader { count, @@ -270,7 +276,7 @@ where fn read(reader: &mut Reader) -> Result; } -/// Deserializes a Readeable from any std::io::Read implementation. +/// Deserializes a Readable from any std::io::Read implementation. pub fn deserialize(source: &mut Read) -> Result { let mut reader = BinReader { source }; T::read(&mut reader) @@ -325,13 +331,14 @@ impl<'a> Reader for BinReader<'a> { let len = self.read_u64()?; self.read_fixed_bytes(len as usize) } + /// Read a fixed number of bytes. - fn read_fixed_bytes(&mut self, length: usize) -> Result, Error> { - // not reading more than 100k in a single read - if length > 100000 { + fn read_fixed_bytes(&mut self, len: usize) -> Result, Error> { + // not reading more than 100k bytes in a single read + if len > 100_000 { return Err(Error::TooLargeReadErr); } - let mut buf = vec![0; length]; + let mut buf = vec![0; len]; self.source .read_exact(&mut buf) .map(move |_| buf) @@ -351,6 +358,90 @@ impl<'a> Reader for BinReader<'a> { } } +/// A reader that reads straight off a stream. +/// Tracks total bytes read so we can verify we read the right number afterwards. +pub struct StreamingReader<'a> { + total_bytes_read: u64, + stream: &'a mut Read, + timeout: Duration, +} + +impl<'a> StreamingReader<'a> { + /// Create a new streaming reader with the provided underlying stream. + /// Also takes a duration to be used for each individual read_exact call. + pub fn new(stream: &'a mut Read, timeout: Duration) -> StreamingReader<'a> { + StreamingReader { + total_bytes_read: 0, + stream, + timeout, + } + } + + /// Returns the total bytes read via this streaming reader. + pub fn total_bytes_read(&self) -> u64 { + self.total_bytes_read + } +} + +impl<'a> Reader for StreamingReader<'a> { + fn read_u8(&mut self) -> Result { + let buf = self.read_fixed_bytes(1)?; + deserialize(&mut &buf[..]) + } + + fn read_u16(&mut self) -> Result { + let buf = self.read_fixed_bytes(2)?; + deserialize(&mut &buf[..]) + } + + fn read_u32(&mut self) -> Result { + let buf = self.read_fixed_bytes(4)?; + deserialize(&mut &buf[..]) + } + + fn read_i32(&mut self) -> Result { + let buf = self.read_fixed_bytes(4)?; + deserialize(&mut &buf[..]) + } + + fn read_u64(&mut self) -> Result { + let buf = self.read_fixed_bytes(8)?; + deserialize(&mut &buf[..]) + } + + fn read_i64(&mut self) -> Result { + let buf = self.read_fixed_bytes(8)?; + deserialize(&mut &buf[..]) + } + + /// Read a variable size vector from the underlying stream. Expects a usize + fn read_bytes_len_prefix(&mut self) -> Result, Error> { + let len = self.read_u64()?; + self.total_bytes_read += 8; + self.read_fixed_bytes(len as usize) + } + + /// Read a fixed number of bytes. 
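The fixed-width reads on `StreamingReader` all funnel through `read_fixed_bytes`, which bumps the running byte counter, so after deserializing a value the caller knows how many bytes came off the stream. A small sketch of that accounting over an in-memory stream; the helper is illustrative, with `Hash` standing in for any fixed-length `Readable` type:

use std::io::Cursor;
use std::time::Duration;

fn read_hash_counted(bytes: Vec<u8>) -> Result<(Hash, u64), Error> {
	let mut stream = Cursor::new(bytes);
	let mut reader = StreamingReader::new(&mut stream, Duration::from_secs(1));
	let hash = Hash::read(&mut reader)?;
	// Hash::read performs a single fixed-size read, so total_bytes_read()
	// now reports exactly the length of the serialized hash.
	Ok((hash, reader.total_bytes_read()))
}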
+ fn read_fixed_bytes(&mut self, len: usize) -> Result, Error> { + let mut buf = vec![0u8; len]; + read_exact(&mut self.stream, &mut buf, self.timeout, true)?; + self.total_bytes_read += len as u64; + Ok(buf) + } + + fn expect_u8(&mut self, val: u8) -> Result { + let b = self.read_u8()?; + if b == val { + Ok(b) + } else { + Err(Error::UnexpectedData { + expected: vec![val], + received: vec![b], + }) + } + } +} + impl Readable for Commitment { fn read(reader: &mut Reader) -> Result { let a = reader.read_fixed_bytes(PEDERSEN_COMMITMENT_SIZE)?; diff --git a/p2p/src/conn.rs b/p2p/src/conn.rs index b3f0ca0fd..63c041eb7 100644 --- a/p2p/src/conn.rs +++ b/p2p/src/conn.rs @@ -28,8 +28,9 @@ use std::{cmp, thread, time}; use core::ser; use core::ser::FixedLength; -use msg::{read_body, read_exact, read_header, write_all, write_to_buf, MsgHeader, Type}; +use msg::{read_body, read_header, read_item, write_to_buf, MsgHeader, Type}; use types::Error; +use util::read_write::{read_exact, write_all}; use util::{RateCounter, RwLock}; /// A trait to be implemented in order to receive messages from the @@ -69,17 +70,15 @@ impl<'a> Message<'a> { Message { header, conn } } - /// Get the TcpStream - pub fn get_conn(&mut self) -> TcpStream { - return self.conn.try_clone().unwrap(); + /// Read the message body from the underlying connection + pub fn body(&mut self) -> Result { + read_body(&self.header, self.conn) } - /// Read the message body from the underlying connection - pub fn body(&mut self) -> Result - where - T: ser::Readable, - { - read_body(&self.header, self.conn) + /// Read a single "thing" from the underlying connection. + /// Return the thing and the total bytes read. + pub fn streaming_read(&mut self) -> Result<(T, u64), Error> { + read_item(self.conn) } pub fn copy_attachment(&mut self, len: usize, writer: &mut Write) -> Result { diff --git a/p2p/src/msg.rs b/p2p/src/msg.rs index 7da0ef263..098879ca8 100644 --- a/p2p/src/msg.rs +++ b/p2p/src/msg.rs @@ -15,17 +15,17 @@ //! Message types that transit over the network and related serialization code. use num::FromPrimitive; -use std::io::{self, Read, Write}; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpStream}; -use std::{thread, time}; +use std::io::{Read, Write}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::time; use core::consensus; use core::core::hash::Hash; use core::core::BlockHeader; use core::pow::Difficulty; -use core::ser::{self, FixedLength, Readable, Reader, Writeable, Writer}; - +use core::ser::{self, FixedLength, Readable, Reader, StreamingReader, Writeable, Writer}; use types::{Capabilities, Error, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS}; +use util::read_write::read_exact; /// Current latest version of the protocol pub const PROTOCOL_VERSION: u32 = 1; @@ -97,111 +97,19 @@ fn max_msg_size(msg_type: Type) -> u64 { } } -/// The default implementation of read_exact is useless with async TcpStream as -/// it will return as soon as something has been read, regardless of -/// whether the buffer has been filled (and then errors). This implementation -/// will block until it has read exactly `len` bytes and returns them as a -/// `vec`. Except for a timeout, this implementation will never return a -/// partially filled buffer. -/// -/// The timeout in milliseconds aborts the read when it's met. Note that the -/// time is not guaranteed to be exact. 
To support cases where we want to poll -/// instead of blocking, a `block_on_empty` boolean, when false, ensures -/// `read_exact` returns early with a `io::ErrorKind::WouldBlock` if nothing -/// has been read from the socket. -pub fn read_exact( - conn: &mut TcpStream, - mut buf: &mut [u8], - timeout: time::Duration, - block_on_empty: bool, -) -> io::Result<()> { - let sleep_time = time::Duration::from_micros(10); - let mut count = time::Duration::new(0, 0); - - let mut read = 0; - loop { - match conn.read(buf) { - Ok(0) => { - return Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "read_exact", - )); - } - Ok(n) => { - let tmp = buf; - buf = &mut tmp[n..]; - read += n; - } - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if read == 0 && !block_on_empty { - return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact")); - } - } - Err(e) => return Err(e), - } - if !buf.is_empty() { - thread::sleep(sleep_time); - count += sleep_time; - } else { - break; - } - if count > timeout { - return Err(io::Error::new( - io::ErrorKind::TimedOut, - "reading from tcp stream", - )); - } - } - Ok(()) -} - -/// Same as `read_exact` but for writing. -pub fn write_all(conn: &mut Write, mut buf: &[u8], timeout: time::Duration) -> io::Result<()> { - let sleep_time = time::Duration::from_micros(10); - let mut count = time::Duration::new(0, 0); - - while !buf.is_empty() { - match conn.write(buf) { - Ok(0) => { - return Err(io::Error::new( - io::ErrorKind::WriteZero, - "failed to write whole buffer", - )) - } - Ok(n) => buf = &buf[n..], - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} - Err(e) => return Err(e), - } - if !buf.is_empty() { - thread::sleep(sleep_time); - count += sleep_time; - } else { - break; - } - if count > timeout { - return Err(io::Error::new( - io::ErrorKind::TimedOut, - "reading from tcp stream", - )); - } - } - Ok(()) -} - -/// Read a header from the provided connection without blocking if the +/// Read a header from the provided stream without blocking if the /// underlying stream is async. Typically headers will be polled for, so /// we do not want to block. -pub fn read_header(conn: &mut TcpStream, msg_type: Option) -> Result { +pub fn read_header(stream: &mut Read, msg_type: Option) -> Result { let mut head = vec![0u8; MsgHeader::LEN]; if Some(Type::Hand) == msg_type { - read_exact(conn, &mut head, time::Duration::from_millis(10), true)?; + read_exact(stream, &mut head, time::Duration::from_millis(10), true)?; } else { - read_exact(conn, &mut head, time::Duration::from_secs(10), false)?; + read_exact(stream, &mut head, time::Duration::from_secs(10), false)?; } let header = ser::deserialize::(&mut &head[..])?; let max_len = max_msg_size(header.msg_type); + // TODO 4x the limits for now to leave ourselves space to change things if header.msg_len > max_len * 4 { error!( @@ -213,33 +121,34 @@ pub fn read_header(conn: &mut TcpStream, msg_type: Option) -> Result(stream: &mut Read) -> Result<(T, u64), Error> { + let timeout = time::Duration::from_secs(20); + let mut reader = StreamingReader::new(stream, timeout); + let res = T::read(&mut reader)?; + Ok((res, reader.total_bytes_read())) +} + +/// Read a message body from the provided stream, always blocking /// until we have a result (or timeout). 
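`read_item` wraps a fresh `StreamingReader` around the stream for each value and reports the bytes it consumed, so a count-prefixed sequence can be read item by item while keeping a running total. A hypothetical helper sketching that pattern (it mirrors how the Headers message is consumed in protocol.rs below); the name `read_counted` is not part of this diff:

fn read_counted<T: Readable>(stream: &mut Read) -> Result<(Vec<T>, u64), Error> {
	// Leading u16 count, as used by the Headers message.
	let (count, mut total_bytes_read): (u16, u64) = read_item(stream)?;
	let mut items = Vec::with_capacity(count as usize);
	for _ in 0..count {
		let (item, bytes_read) = read_item::<T>(stream)?;
		total_bytes_read += bytes_read;
		items.push(item);
	}
	Ok((items, total_bytes_read))
}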
-pub fn read_body(h: &MsgHeader, conn: &mut TcpStream) -> Result -where - T: Readable, -{ +pub fn read_body(h: &MsgHeader, stream: &mut Read) -> Result { let mut body = vec![0u8; h.msg_len as usize]; - read_exact(conn, &mut body, time::Duration::from_secs(20), true)?; + read_exact(stream, &mut body, time::Duration::from_secs(20), true)?; ser::deserialize(&mut &body[..]).map_err(From::from) } -/// Reads a full message from the underlying connection. -pub fn read_message(conn: &mut TcpStream, msg_type: Type) -> Result -where - T: Readable, -{ - let header = read_header(conn, Some(msg_type))?; +/// Reads a full message from the underlying stream. +pub fn read_message(stream: &mut Read, msg_type: Type) -> Result { + let header = read_header(stream, Some(msg_type))?; if header.msg_type != msg_type { return Err(Error::BadMessage); } - read_body(&header, conn) + read_body(&header, stream) } -pub fn write_to_buf(msg: T, msg_type: Type) -> Vec -where - T: Writeable, -{ +pub fn write_to_buf(msg: T, msg_type: Type) -> Vec { // prepare the body first so we know its serialized length let mut body_buf = vec![]; ser::serialize(&mut body_buf, &msg).unwrap(); @@ -253,13 +162,13 @@ where msg_buf } -pub fn write_message(conn: &mut TcpStream, msg: T, msg_type: Type) -> Result<(), Error> -where - T: Writeable + 'static, -{ +pub fn write_message( + stream: &mut Write, + msg: T, + msg_type: Type, +) -> Result<(), Error> { let buf = write_to_buf(msg, msg_type); - // send the whole thing - conn.write_all(&buf[..])?; + stream.write_all(&buf[..])?; Ok(()) } @@ -605,24 +514,6 @@ impl Writeable for Headers { } } -impl Readable for Headers { - fn read(reader: &mut Reader) -> Result { - let len = reader.read_u16()?; - if (len as u32) > MAX_BLOCK_HEADERS + 1 { - return Err(ser::Error::TooLargeReadErr); - } - let mut headers: Vec = Vec::with_capacity(len as usize); - for n in 0..len as usize { - let header = BlockHeader::read(reader)?; - if n > 0 && header.height != headers[n - 1].height + 1 { - return Err(ser::Error::CorruptedData); - } - headers.push(header); - } - Ok(Headers { headers: headers }) - } -} - pub struct Ping { /// total difficulty accumulated by the sender, used to check whether sync /// may be needed diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index b3c15261f..950bbc5c5 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -15,20 +15,18 @@ use std::cmp; use std::env; use std::fs::File; -use std::io::{self, BufWriter}; -use std::net::{SocketAddr, TcpStream}; +use std::io::BufWriter; +use std::net::SocketAddr; use std::sync::Arc; -use std::time; use chrono::prelude::Utc; use conn::{Message, MessageHandler, Response}; use core::core::{self, hash::Hash, CompactBlock}; -use core::{global, ser}; use util::{RateCounter, RwLock}; use msg::{ - read_exact, BanReason, GetPeerAddrs, Headers, Locator, PeerAddrs, Ping, Pong, SockAddr, - TxHashSetArchive, TxHashSetRequest, Type, + BanReason, GetPeerAddrs, Headers, Locator, PeerAddrs, Ping, Pong, SockAddr, TxHashSetArchive, + TxHashSetRequest, Type, }; use types::{Error, NetAdapter}; @@ -194,32 +192,34 @@ impl MessageHandler for Protocol { // we can go request it from some of our peers Type::Header => { let header: core::BlockHeader = msg.body()?; - adapter.header_received(header, self.addr); - - // we do not return a hash here as we never request a single header - // a header will always arrive unsolicited Ok(None) } Type::Headers => { - let conn = &mut msg.get_conn(); + let mut total_bytes_read = 0; - let header_size: u64 = headers_header_size(conn, 
msg.header.msg_len)?; - let mut total_read: u64 = 2; - let mut reserved: Vec = vec![]; + // Read the count (u16) so we now how many headers to read. + let (count, bytes_read): (u16, _) = msg.streaming_read()?; + total_bytes_read += bytes_read; - while total_read < msg.header.msg_len || reserved.len() > 0 { - let headers: Headers = headers_streaming_body( - conn, - msg.header.msg_len, - 32, - &mut total_read, - &mut reserved, - header_size, - )?; - adapter.headers_received(headers.headers, self.addr); + // Read chunks of headers off the stream and pass them off to the adapter. + let chunk_size = 32; + for chunk in (0..count).collect::>().chunks(chunk_size) { + let mut headers = vec![]; + for _ in chunk { + let (header, bytes_read) = msg.streaming_read()?; + headers.push(header); + total_bytes_read += bytes_read; + } + adapter.headers_received(headers, self.addr); } + + // Now check we read the correct total number of bytes off the stream. + if total_bytes_read != msg.header.msg_len { + return Err(Error::MsgLen); + } + Ok(None) } @@ -344,97 +344,3 @@ impl MessageHandler for Protocol { } } } - -/// Read the Headers Vec size from the underlying connection, and calculate maximum header_size of one Header -fn headers_header_size(conn: &mut TcpStream, msg_len: u64) -> Result { - let mut size = vec![0u8; 2]; - // read size of Vec - read_exact(conn, &mut size, time::Duration::from_millis(10), true)?; - - let total_headers = size[0] as u64 * 256 + size[1] as u64; - if total_headers == 0 || total_headers > 10_000 { - return Err(Error::Connection(io::Error::new( - io::ErrorKind::InvalidData, - "headers_header_size", - ))); - } - let average_header_size = (msg_len - 2) / total_headers; - - // support size of Cuck(at)oo: from Cuck(at)oo 29 to Cuck(at)oo 35, with version 2 - // having slightly larger headers - let min_size = core::serialized_size_of_header(1, global::min_edge_bits()); - let max_size = min_size + 6; - if average_header_size < min_size as u64 || average_header_size > max_size as u64 { - debug!( - "headers_header_size - size of Vec: {}, average_header_size: {}, min: {}, max: {}", - total_headers, average_header_size, min_size, max_size, - ); - return Err(Error::Connection(io::Error::new( - io::ErrorKind::InvalidData, - "headers_header_size", - ))); - } - return Ok(max_size as u64); -} - -/// Read the Headers streaming body from the underlying connection -fn headers_streaming_body( - conn: &mut TcpStream, // (i) underlying connection - msg_len: u64, // (i) length of whole 'Headers' - headers_num: u64, // (i) how many BlockHeader(s) do you want to read - total_read: &mut u64, // (i/o) how many bytes already read on this 'Headers' message - reserved: &mut Vec, // (i/o) reserved part of previous read, which is not a whole header - max_header_size: u64, // (i) maximum possible size of single BlockHeader -) -> Result { - if headers_num == 0 || msg_len < *total_read || *total_read < 2 { - return Err(Error::Connection(io::Error::new( - io::ErrorKind::InvalidInput, - "headers_streaming_body", - ))); - } - - // Note: - // As we allow Cuckoo sizes greater than 30 now, the proof of work part of the header - // could be 30*42 bits, 31*42 bits, 32*42 bits, etc. - // So, for compatibility with variable size of block header, we read max possible size, for - // up to Cuckoo 36. 
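For readability, the chunked loop in the new Type::Headers handler is equivalent to the explicit batching sketched below; this is an illustration of the control flow only, assuming the same `count`, `msg`, `adapter` and `total_bytes_read` bindings as the handler:

// Illustrative re-expression of the chunked header reads, not part of the diff.
let chunk_size: u16 = 32;
let mut remaining = count;
while remaining > 0 {
	let batch = cmp::min(remaining, chunk_size);
	let mut headers: Vec<core::BlockHeader> = Vec::with_capacity(batch as usize);
	for _ in 0..batch {
		// Each header is deserialized straight off the stream and its
		// byte count tallied so the final msg_len check still holds.
		let (header, bytes_read) = msg.streaming_read()?;
		total_bytes_read += bytes_read;
		headers.push(header);
	}
	adapter.headers_received(headers, self.addr);
	remaining -= batch;
}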
- // - let mut read_size = headers_num * max_header_size - reserved.len() as u64; - if *total_read + read_size > msg_len { - read_size = msg_len - *total_read; - } - - // 1st part - let mut body = vec![0u8; 2]; // for Vec<> size - let mut final_headers_num = (read_size + reserved.len() as u64) / max_header_size; - let remaining = msg_len - *total_read - read_size; - if final_headers_num == 0 && remaining == 0 { - final_headers_num = 1; - } - body[0] = (final_headers_num >> 8) as u8; - body[1] = (final_headers_num & 0x00ff) as u8; - - // 2nd part - body.append(reserved); - - // 3rd part - let mut read_body = vec![0u8; read_size as usize]; - if read_size > 0 { - read_exact(conn, &mut read_body, time::Duration::from_secs(20), true)?; - *total_read += read_size; - } - body.append(&mut read_body); - - // deserialize these assembled 3 parts - let result: Result = ser::deserialize(&mut &body[..]).map_err(From::from); - let headers = result?; - - // remaining data - let mut deserialized_size = 2; // for Vec<> size - for header in &headers.headers { - deserialized_size += header.serialized_size(); - } - *reserved = body[deserialized_size..].to_vec(); - - Ok(headers) -} diff --git a/p2p/src/types.rs b/p2p/src/types.rs index dd87ac621..100cc4583 100644 --- a/p2p/src/types.rs +++ b/p2p/src/types.rs @@ -55,6 +55,7 @@ pub enum Error { Connection(io::Error), /// Header type does not match the expected message type BadMessage, + MsgLen, Banned, ConnectionClose, Timeout, diff --git a/util/src/lib.rs b/util/src/lib.rs index 91a7306b4..67ad79886 100644 --- a/util/src/lib.rs +++ b/util/src/lib.rs @@ -57,6 +57,9 @@ pub use types::{LogLevel, LoggingConfig}; pub mod macros; +// read_exact and write_all impls +pub mod read_write; + // other utils #[allow(unused_imports)] use std::ops::Deref; diff --git a/util/src/read_write.rs b/util/src/read_write.rs new file mode 100644 index 000000000..016e9fe10 --- /dev/null +++ b/util/src/read_write.rs @@ -0,0 +1,110 @@ +// Copyright 2018 The Grin Developers +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Custom impls of read_exact and write_all to work around async stream restrictions. + +use std::io; +use std::io::prelude::*; +use std::thread; +use std::time::Duration; + +/// The default implementation of read_exact is useless with an async stream (TcpStream) as +/// it will return as soon as something has been read, regardless of +/// whether the buffer has been filled (and then errors). This implementation +/// will block until it has read exactly `len` bytes and returns them as a +/// `vec`. Except for a timeout, this implementation will never return a +/// partially filled buffer. +/// +/// The timeout in milliseconds aborts the read when it's met. Note that the +/// time is not guaranteed to be exact. To support cases where we want to poll +/// instead of blocking, a `block_on_empty` boolean, when false, ensures +/// `read_exact` returns early with a `io::ErrorKind::WouldBlock` if nothing +/// has been read from the socket. 
+pub fn read_exact( + stream: &mut Read, + mut buf: &mut [u8], + timeout: Duration, + block_on_empty: bool, +) -> io::Result<()> { + let sleep_time = Duration::from_micros(10); + let mut count = Duration::new(0, 0); + + let mut read = 0; + loop { + match stream.read(buf) { + Ok(0) => { + return Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "read_exact", + )); + } + Ok(n) => { + let tmp = buf; + buf = &mut tmp[n..]; + read += n; + } + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if read == 0 && !block_on_empty { + return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact")); + } + } + Err(e) => return Err(e), + } + if !buf.is_empty() { + thread::sleep(sleep_time); + count += sleep_time; + } else { + break; + } + if count > timeout { + return Err(io::Error::new( + io::ErrorKind::TimedOut, + "reading from stream", + )); + } + } + Ok(()) +} + +/// Same as `read_exact` but for writing. +pub fn write_all(stream: &mut Write, mut buf: &[u8], timeout: Duration) -> io::Result<()> { + let sleep_time = Duration::from_micros(10); + let mut count = Duration::new(0, 0); + + while !buf.is_empty() { + match stream.write(buf) { + Ok(0) => { + return Err(io::Error::new( + io::ErrorKind::WriteZero, + "failed to write whole buffer", + )) + } + Ok(n) => buf = &buf[n..], + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => return Err(e), + } + if !buf.is_empty() { + thread::sleep(sleep_time); + count += sleep_time; + } else { + break; + } + if count > timeout { + return Err(io::Error::new(io::ErrorKind::TimedOut, "writing to stream")); + } + } + Ok(()) +}
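The two calling modes of `read_exact` map onto their uses in the p2p layer: `block_on_empty = false` gives the non-blocking poll used while listening for the next message header, while `true` gives the blocking read used for message bodies. A usage sketch; buffer sizes and timeouts here are illustrative only:

use std::io;
use std::net::TcpStream;
use std::time::Duration;

fn poll_then_read(stream: &mut TcpStream) -> io::Result<Vec<u8>> {
	// Poll: returns WouldBlock right away if nothing has arrived yet.
	let mut header = [0u8; 16];
	read_exact(stream, &mut header, Duration::from_millis(10), false)?;

	// Block: once a header has arrived, insist on the full body or time out.
	let mut body = vec![0u8; 1024];
	read_exact(stream, &mut body, Duration::from_secs(20), true)?;
	Ok(body)
}

fn send_bytes(stream: &mut TcpStream, buf: &[u8]) -> io::Result<()> {
	// write_all retries on WouldBlock/Interrupted until the buffer drains
	// or the timeout is hit.
	write_all(stream, buf, Duration::from_secs(10))
}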