Streaming headers (#1989)

* headers msg is now "streamed" off the tcp stream

* rustfmt

* cleanup

* move StreamingReader into ser.rs
extract read_exact out into util crate

* rustfmt
This commit is contained in:
Antioch Peverell 2018-11-16 11:00:39 +00:00 committed by GitHub
parent a011450825
commit f0fa410273
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 280 additions and 321 deletions

View file

@ -34,7 +34,7 @@ use core::{
use global;
use keychain::{self, BlindingFactor};
use pow::{Difficulty, Proof, ProofOfWork};
use ser::{self, FixedLength, HashOnlyPMMRable, Readable, Reader, Writeable, Writer};
use ser::{self, HashOnlyPMMRable, Readable, Reader, Writeable, Writer};
use util::{secp, static_secp_instance};
/// Errors thrown by Block validation
@ -140,34 +140,6 @@ pub struct BlockHeader {
pub pow: ProofOfWork,
}
const FIXED_HEADER_SIZE: usize = 2 // version
+ 8 // height
+ 8 // timestamp
+ 5 * Hash::LEN // prev_hash, prev_root, output_root, range_proof_root, kernel_root
+ BlindingFactor::LEN // total_kernel_offset
+ 2 * 8 // output_mmr_size, kernel_mmr_size
+ Difficulty::LEN // total_difficulty
+ 4 // secondary_scaling
+ 8; // nonce
/// Serialized size of fixed part of a BlockHeader, i.e. without pow
fn fixed_size_of_serialized_header(_version: u16) -> usize {
FIXED_HEADER_SIZE
}
/// Serialized size of a BlockHeader
pub fn serialized_size_of_header(version: u16, edge_bits: u8) -> usize {
let mut size = fixed_size_of_serialized_header(version);
size += 1; // pow.edge_bits
let bitvec_len = global::proofsize() * edge_bits as usize;
size += bitvec_len / 8; // pow.nonces
if bitvec_len % 8 != 0 {
size += 1;
}
size
}
impl Default for BlockHeader {
fn default() -> BlockHeader {
BlockHeader {
@ -292,20 +264,6 @@ impl BlockHeader {
pub fn total_kernel_offset(&self) -> BlindingFactor {
self.total_kernel_offset
}
/// Serialized size of this header
pub fn serialized_size(&self) -> usize {
let mut size = fixed_size_of_serialized_header(self.version);
size += 1; // pow.edge_bits
let nonce_bits = self.pow.edge_bits() as usize;
let bitvec_len = global::proofsize() * nonce_bits;
size += bitvec_len / 8; // pow.nonces
if bitvec_len % 8 != 0 {
size += 1;
}
size
}
}
/// A block as expressed in the MimbleWimble protocol. The reward is

View file

@ -19,6 +19,8 @@
//! To use it simply implement `Writeable` or `Readable` and then use the
//! `serialize` or `deserialize` functions on them as appropriate.
use std::time::Duration;
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
use consensus;
use core::hash::{Hash, Hashed};
@ -27,6 +29,7 @@ use std::fmt::Debug;
use std::io::{self, Read, Write};
use std::marker;
use std::{cmp, error, fmt};
use util::read_write::read_exact;
use util::secp::constants::{
AGG_SIGNATURE_SIZE, MAX_PROOF_SIZE, PEDERSEN_COMMITMENT_SIZE, SECRET_KEY_SIZE,
};
@ -206,7 +209,8 @@ pub trait Writeable {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error>;
}
struct IteratingReader<'a, T> {
/// Reader that exposes an Iterator interface.
pub struct IteratingReader<'a, T> {
count: u64,
curr: u64,
reader: &'a mut Reader,
@ -214,7 +218,9 @@ struct IteratingReader<'a, T> {
}
impl<'a, T> IteratingReader<'a, T> {
fn new(reader: &'a mut Reader, count: u64) -> IteratingReader<'a, T> {
/// Constructor to create a new iterating reader for the provided underlying reader.
/// Takes a count so we know how many to iterate over.
pub fn new(reader: &'a mut Reader, count: u64) -> IteratingReader<'a, T> {
let curr = 0;
IteratingReader {
count,
@ -270,7 +276,7 @@ where
fn read(reader: &mut Reader) -> Result<Self, Error>;
}
/// Deserializes a Readeable from any std::io::Read implementation.
/// Deserializes a Readable from any std::io::Read implementation.
pub fn deserialize<T: Readable>(source: &mut Read) -> Result<T, Error> {
let mut reader = BinReader { source };
T::read(&mut reader)
@ -325,13 +331,14 @@ impl<'a> Reader for BinReader<'a> {
let len = self.read_u64()?;
self.read_fixed_bytes(len as usize)
}
/// Read a fixed number of bytes.
fn read_fixed_bytes(&mut self, length: usize) -> Result<Vec<u8>, Error> {
// not reading more than 100k in a single read
if length > 100000 {
fn read_fixed_bytes(&mut self, len: usize) -> Result<Vec<u8>, Error> {
// not reading more than 100k bytes in a single read
if len > 100_000 {
return Err(Error::TooLargeReadErr);
}
let mut buf = vec![0; length];
let mut buf = vec![0; len];
self.source
.read_exact(&mut buf)
.map(move |_| buf)
@ -351,6 +358,90 @@ impl<'a> Reader for BinReader<'a> {
}
}
/// A reader that reads straight off a stream.
/// Tracks total bytes read so we can verify we read the right number afterwards.
pub struct StreamingReader<'a> {
total_bytes_read: u64,
stream: &'a mut Read,
timeout: Duration,
}
impl<'a> StreamingReader<'a> {
/// Create a new streaming reader with the provided underlying stream.
/// Also takes a duration to be used for each individual read_exact call.
pub fn new(stream: &'a mut Read, timeout: Duration) -> StreamingReader<'a> {
StreamingReader {
total_bytes_read: 0,
stream,
timeout,
}
}
/// Returns the total bytes read via this streaming reader.
pub fn total_bytes_read(&self) -> u64 {
self.total_bytes_read
}
}
impl<'a> Reader for StreamingReader<'a> {
fn read_u8(&mut self) -> Result<u8, Error> {
let buf = self.read_fixed_bytes(1)?;
deserialize(&mut &buf[..])
}
fn read_u16(&mut self) -> Result<u16, Error> {
let buf = self.read_fixed_bytes(2)?;
deserialize(&mut &buf[..])
}
fn read_u32(&mut self) -> Result<u32, Error> {
let buf = self.read_fixed_bytes(4)?;
deserialize(&mut &buf[..])
}
fn read_i32(&mut self) -> Result<i32, Error> {
let buf = self.read_fixed_bytes(4)?;
deserialize(&mut &buf[..])
}
fn read_u64(&mut self) -> Result<u64, Error> {
let buf = self.read_fixed_bytes(8)?;
deserialize(&mut &buf[..])
}
fn read_i64(&mut self) -> Result<i64, Error> {
let buf = self.read_fixed_bytes(8)?;
deserialize(&mut &buf[..])
}
/// Read a variable size vector from the underlying stream. Expects a usize
fn read_bytes_len_prefix(&mut self) -> Result<Vec<u8>, Error> {
let len = self.read_u64()?;
self.total_bytes_read += 8;
self.read_fixed_bytes(len as usize)
}
/// Read a fixed number of bytes.
fn read_fixed_bytes(&mut self, len: usize) -> Result<Vec<u8>, Error> {
let mut buf = vec![0u8; len];
read_exact(&mut self.stream, &mut buf, self.timeout, true)?;
self.total_bytes_read += len as u64;
Ok(buf)
}
fn expect_u8(&mut self, val: u8) -> Result<u8, Error> {
let b = self.read_u8()?;
if b == val {
Ok(b)
} else {
Err(Error::UnexpectedData {
expected: vec![val],
received: vec![b],
})
}
}
}
impl Readable for Commitment {
fn read(reader: &mut Reader) -> Result<Commitment, Error> {
let a = reader.read_fixed_bytes(PEDERSEN_COMMITMENT_SIZE)?;

View file

@ -28,8 +28,9 @@ use std::{cmp, thread, time};
use core::ser;
use core::ser::FixedLength;
use msg::{read_body, read_exact, read_header, write_all, write_to_buf, MsgHeader, Type};
use msg::{read_body, read_header, read_item, write_to_buf, MsgHeader, Type};
use types::Error;
use util::read_write::{read_exact, write_all};
use util::{RateCounter, RwLock};
/// A trait to be implemented in order to receive messages from the
@ -69,17 +70,15 @@ impl<'a> Message<'a> {
Message { header, conn }
}
/// Get the TcpStream
pub fn get_conn(&mut self) -> TcpStream {
return self.conn.try_clone().unwrap();
/// Read the message body from the underlying connection
pub fn body<T: ser::Readable>(&mut self) -> Result<T, Error> {
read_body(&self.header, self.conn)
}
/// Read the message body from the underlying connection
pub fn body<T>(&mut self) -> Result<T, Error>
where
T: ser::Readable,
{
read_body(&self.header, self.conn)
/// Read a single "thing" from the underlying connection.
/// Return the thing and the total bytes read.
pub fn streaming_read<T: ser::Readable>(&mut self) -> Result<(T, u64), Error> {
read_item(self.conn)
}
pub fn copy_attachment(&mut self, len: usize, writer: &mut Write) -> Result<usize, Error> {

View file

@ -15,17 +15,17 @@
//! Message types that transit over the network and related serialization code.
use num::FromPrimitive;
use std::io::{self, Read, Write};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpStream};
use std::{thread, time};
use std::io::{Read, Write};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time;
use core::consensus;
use core::core::hash::Hash;
use core::core::BlockHeader;
use core::pow::Difficulty;
use core::ser::{self, FixedLength, Readable, Reader, Writeable, Writer};
use core::ser::{self, FixedLength, Readable, Reader, StreamingReader, Writeable, Writer};
use types::{Capabilities, Error, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS};
use util::read_write::read_exact;
/// Current latest version of the protocol
pub const PROTOCOL_VERSION: u32 = 1;
@ -97,111 +97,19 @@ fn max_msg_size(msg_type: Type) -> u64 {
}
}
/// The default implementation of read_exact is useless with async TcpStream as
/// it will return as soon as something has been read, regardless of
/// whether the buffer has been filled (and then errors). This implementation
/// will block until it has read exactly `len` bytes and returns them as a
/// `vec<u8>`. Except for a timeout, this implementation will never return a
/// partially filled buffer.
///
/// The timeout in milliseconds aborts the read when it's met. Note that the
/// time is not guaranteed to be exact. To support cases where we want to poll
/// instead of blocking, a `block_on_empty` boolean, when false, ensures
/// `read_exact` returns early with a `io::ErrorKind::WouldBlock` if nothing
/// has been read from the socket.
pub fn read_exact(
conn: &mut TcpStream,
mut buf: &mut [u8],
timeout: time::Duration,
block_on_empty: bool,
) -> io::Result<()> {
let sleep_time = time::Duration::from_micros(10);
let mut count = time::Duration::new(0, 0);
let mut read = 0;
loop {
match conn.read(buf) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"read_exact",
));
}
Ok(n) => {
let tmp = buf;
buf = &mut tmp[n..];
read += n;
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if read == 0 && !block_on_empty {
return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact"));
}
}
Err(e) => return Err(e),
}
if !buf.is_empty() {
thread::sleep(sleep_time);
count += sleep_time;
} else {
break;
}
if count > timeout {
return Err(io::Error::new(
io::ErrorKind::TimedOut,
"reading from tcp stream",
));
}
}
Ok(())
}
/// Same as `read_exact` but for writing.
pub fn write_all(conn: &mut Write, mut buf: &[u8], timeout: time::Duration) -> io::Result<()> {
let sleep_time = time::Duration::from_micros(10);
let mut count = time::Duration::new(0, 0);
while !buf.is_empty() {
match conn.write(buf) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write whole buffer",
))
}
Ok(n) => buf = &buf[n..],
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => return Err(e),
}
if !buf.is_empty() {
thread::sleep(sleep_time);
count += sleep_time;
} else {
break;
}
if count > timeout {
return Err(io::Error::new(
io::ErrorKind::TimedOut,
"reading from tcp stream",
));
}
}
Ok(())
}
/// Read a header from the provided connection without blocking if the
/// Read a header from the provided stream without blocking if the
/// underlying stream is async. Typically headers will be polled for, so
/// we do not want to block.
pub fn read_header(conn: &mut TcpStream, msg_type: Option<Type>) -> Result<MsgHeader, Error> {
pub fn read_header(stream: &mut Read, msg_type: Option<Type>) -> Result<MsgHeader, Error> {
let mut head = vec![0u8; MsgHeader::LEN];
if Some(Type::Hand) == msg_type {
read_exact(conn, &mut head, time::Duration::from_millis(10), true)?;
read_exact(stream, &mut head, time::Duration::from_millis(10), true)?;
} else {
read_exact(conn, &mut head, time::Duration::from_secs(10), false)?;
read_exact(stream, &mut head, time::Duration::from_secs(10), false)?;
}
let header = ser::deserialize::<MsgHeader>(&mut &head[..])?;
let max_len = max_msg_size(header.msg_type);
// TODO 4x the limits for now to leave ourselves space to change things
if header.msg_len > max_len * 4 {
error!(
@ -213,33 +121,34 @@ pub fn read_header(conn: &mut TcpStream, msg_type: Option<Type>) -> Result<MsgHe
Ok(header)
}
/// Read a message body from the provided connection, always blocking
/// Read a single item from the provided stream, always blocking until we
/// have a result (or timeout).
/// Returns the item and the total bytes read.
pub fn read_item<T: Readable>(stream: &mut Read) -> Result<(T, u64), Error> {
let timeout = time::Duration::from_secs(20);
let mut reader = StreamingReader::new(stream, timeout);
let res = T::read(&mut reader)?;
Ok((res, reader.total_bytes_read()))
}
/// Read a message body from the provided stream, always blocking
/// until we have a result (or timeout).
pub fn read_body<T>(h: &MsgHeader, conn: &mut TcpStream) -> Result<T, Error>
where
T: Readable,
{
pub fn read_body<T: Readable>(h: &MsgHeader, stream: &mut Read) -> Result<T, Error> {
let mut body = vec![0u8; h.msg_len as usize];
read_exact(conn, &mut body, time::Duration::from_secs(20), true)?;
read_exact(stream, &mut body, time::Duration::from_secs(20), true)?;
ser::deserialize(&mut &body[..]).map_err(From::from)
}
/// Reads a full message from the underlying connection.
pub fn read_message<T>(conn: &mut TcpStream, msg_type: Type) -> Result<T, Error>
where
T: Readable,
{
let header = read_header(conn, Some(msg_type))?;
/// Reads a full message from the underlying stream.
pub fn read_message<T: Readable>(stream: &mut Read, msg_type: Type) -> Result<T, Error> {
let header = read_header(stream, Some(msg_type))?;
if header.msg_type != msg_type {
return Err(Error::BadMessage);
}
read_body(&header, conn)
read_body(&header, stream)
}
pub fn write_to_buf<T>(msg: T, msg_type: Type) -> Vec<u8>
where
T: Writeable,
{
pub fn write_to_buf<T: Writeable>(msg: T, msg_type: Type) -> Vec<u8> {
// prepare the body first so we know its serialized length
let mut body_buf = vec![];
ser::serialize(&mut body_buf, &msg).unwrap();
@ -253,13 +162,13 @@ where
msg_buf
}
pub fn write_message<T>(conn: &mut TcpStream, msg: T, msg_type: Type) -> Result<(), Error>
where
T: Writeable + 'static,
{
pub fn write_message<T: Writeable>(
stream: &mut Write,
msg: T,
msg_type: Type,
) -> Result<(), Error> {
let buf = write_to_buf(msg, msg_type);
// send the whole thing
conn.write_all(&buf[..])?;
stream.write_all(&buf[..])?;
Ok(())
}
@ -605,24 +514,6 @@ impl Writeable for Headers {
}
}
impl Readable for Headers {
fn read(reader: &mut Reader) -> Result<Headers, ser::Error> {
let len = reader.read_u16()?;
if (len as u32) > MAX_BLOCK_HEADERS + 1 {
return Err(ser::Error::TooLargeReadErr);
}
let mut headers: Vec<BlockHeader> = Vec::with_capacity(len as usize);
for n in 0..len as usize {
let header = BlockHeader::read(reader)?;
if n > 0 && header.height != headers[n - 1].height + 1 {
return Err(ser::Error::CorruptedData);
}
headers.push(header);
}
Ok(Headers { headers: headers })
}
}
pub struct Ping {
/// total difficulty accumulated by the sender, used to check whether sync
/// may be needed

View file

@ -15,20 +15,18 @@
use std::cmp;
use std::env;
use std::fs::File;
use std::io::{self, BufWriter};
use std::net::{SocketAddr, TcpStream};
use std::io::BufWriter;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time;
use chrono::prelude::Utc;
use conn::{Message, MessageHandler, Response};
use core::core::{self, hash::Hash, CompactBlock};
use core::{global, ser};
use util::{RateCounter, RwLock};
use msg::{
read_exact, BanReason, GetPeerAddrs, Headers, Locator, PeerAddrs, Ping, Pong, SockAddr,
TxHashSetArchive, TxHashSetRequest, Type,
BanReason, GetPeerAddrs, Headers, Locator, PeerAddrs, Ping, Pong, SockAddr, TxHashSetArchive,
TxHashSetRequest, Type,
};
use types::{Error, NetAdapter};
@ -194,32 +192,34 @@ impl MessageHandler for Protocol {
// we can go request it from some of our peers
Type::Header => {
let header: core::BlockHeader = msg.body()?;
adapter.header_received(header, self.addr);
// we do not return a hash here as we never request a single header
// a header will always arrive unsolicited
Ok(None)
}
Type::Headers => {
let conn = &mut msg.get_conn();
let mut total_bytes_read = 0;
let header_size: u64 = headers_header_size(conn, msg.header.msg_len)?;
let mut total_read: u64 = 2;
let mut reserved: Vec<u8> = vec![];
// Read the count (u16) so we now how many headers to read.
let (count, bytes_read): (u16, _) = msg.streaming_read()?;
total_bytes_read += bytes_read;
while total_read < msg.header.msg_len || reserved.len() > 0 {
let headers: Headers = headers_streaming_body(
conn,
msg.header.msg_len,
32,
&mut total_read,
&mut reserved,
header_size,
)?;
adapter.headers_received(headers.headers, self.addr);
// Read chunks of headers off the stream and pass them off to the adapter.
let chunk_size = 32;
for chunk in (0..count).collect::<Vec<_>>().chunks(chunk_size) {
let mut headers = vec![];
for _ in chunk {
let (header, bytes_read) = msg.streaming_read()?;
headers.push(header);
total_bytes_read += bytes_read;
}
adapter.headers_received(headers, self.addr);
}
// Now check we read the correct total number of bytes off the stream.
if total_bytes_read != msg.header.msg_len {
return Err(Error::MsgLen);
}
Ok(None)
}
@ -344,97 +344,3 @@ impl MessageHandler for Protocol {
}
}
}
/// Read the Headers Vec size from the underlying connection, and calculate maximum header_size of one Header
fn headers_header_size(conn: &mut TcpStream, msg_len: u64) -> Result<u64, Error> {
let mut size = vec![0u8; 2];
// read size of Vec<BlockHeader>
read_exact(conn, &mut size, time::Duration::from_millis(10), true)?;
let total_headers = size[0] as u64 * 256 + size[1] as u64;
if total_headers == 0 || total_headers > 10_000 {
return Err(Error::Connection(io::Error::new(
io::ErrorKind::InvalidData,
"headers_header_size",
)));
}
let average_header_size = (msg_len - 2) / total_headers;
// support size of Cuck(at)oo: from Cuck(at)oo 29 to Cuck(at)oo 35, with version 2
// having slightly larger headers
let min_size = core::serialized_size_of_header(1, global::min_edge_bits());
let max_size = min_size + 6;
if average_header_size < min_size as u64 || average_header_size > max_size as u64 {
debug!(
"headers_header_size - size of Vec: {}, average_header_size: {}, min: {}, max: {}",
total_headers, average_header_size, min_size, max_size,
);
return Err(Error::Connection(io::Error::new(
io::ErrorKind::InvalidData,
"headers_header_size",
)));
}
return Ok(max_size as u64);
}
/// Read the Headers streaming body from the underlying connection
fn headers_streaming_body(
conn: &mut TcpStream, // (i) underlying connection
msg_len: u64, // (i) length of whole 'Headers'
headers_num: u64, // (i) how many BlockHeader(s) do you want to read
total_read: &mut u64, // (i/o) how many bytes already read on this 'Headers' message
reserved: &mut Vec<u8>, // (i/o) reserved part of previous read, which is not a whole header
max_header_size: u64, // (i) maximum possible size of single BlockHeader
) -> Result<Headers, Error> {
if headers_num == 0 || msg_len < *total_read || *total_read < 2 {
return Err(Error::Connection(io::Error::new(
io::ErrorKind::InvalidInput,
"headers_streaming_body",
)));
}
// Note:
// As we allow Cuckoo sizes greater than 30 now, the proof of work part of the header
// could be 30*42 bits, 31*42 bits, 32*42 bits, etc.
// So, for compatibility with variable size of block header, we read max possible size, for
// up to Cuckoo 36.
//
let mut read_size = headers_num * max_header_size - reserved.len() as u64;
if *total_read + read_size > msg_len {
read_size = msg_len - *total_read;
}
// 1st part
let mut body = vec![0u8; 2]; // for Vec<> size
let mut final_headers_num = (read_size + reserved.len() as u64) / max_header_size;
let remaining = msg_len - *total_read - read_size;
if final_headers_num == 0 && remaining == 0 {
final_headers_num = 1;
}
body[0] = (final_headers_num >> 8) as u8;
body[1] = (final_headers_num & 0x00ff) as u8;
// 2nd part
body.append(reserved);
// 3rd part
let mut read_body = vec![0u8; read_size as usize];
if read_size > 0 {
read_exact(conn, &mut read_body, time::Duration::from_secs(20), true)?;
*total_read += read_size;
}
body.append(&mut read_body);
// deserialize these assembled 3 parts
let result: Result<Headers, Error> = ser::deserialize(&mut &body[..]).map_err(From::from);
let headers = result?;
// remaining data
let mut deserialized_size = 2; // for Vec<> size
for header in &headers.headers {
deserialized_size += header.serialized_size();
}
*reserved = body[deserialized_size..].to_vec();
Ok(headers)
}

View file

@ -55,6 +55,7 @@ pub enum Error {
Connection(io::Error),
/// Header type does not match the expected message type
BadMessage,
MsgLen,
Banned,
ConnectionClose,
Timeout,

View file

@ -57,6 +57,9 @@ pub use types::{LogLevel, LoggingConfig};
pub mod macros;
// read_exact and write_all impls
pub mod read_write;
// other utils
#[allow(unused_imports)]
use std::ops::Deref;

110
util/src/read_write.rs Normal file
View file

@ -0,0 +1,110 @@
// Copyright 2018 The Grin Developers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Custom impls of read_exact and write_all to work around async stream restrictions.
use std::io;
use std::io::prelude::*;
use std::thread;
use std::time::Duration;
/// The default implementation of read_exact is useless with an async stream (TcpStream) as
/// it will return as soon as something has been read, regardless of
/// whether the buffer has been filled (and then errors). This implementation
/// will block until it has read exactly `len` bytes and returns them as a
/// `vec<u8>`. Except for a timeout, this implementation will never return a
/// partially filled buffer.
///
/// The timeout in milliseconds aborts the read when it's met. Note that the
/// time is not guaranteed to be exact. To support cases where we want to poll
/// instead of blocking, a `block_on_empty` boolean, when false, ensures
/// `read_exact` returns early with a `io::ErrorKind::WouldBlock` if nothing
/// has been read from the socket.
pub fn read_exact(
stream: &mut Read,
mut buf: &mut [u8],
timeout: Duration,
block_on_empty: bool,
) -> io::Result<()> {
let sleep_time = Duration::from_micros(10);
let mut count = Duration::new(0, 0);
let mut read = 0;
loop {
match stream.read(buf) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"read_exact",
));
}
Ok(n) => {
let tmp = buf;
buf = &mut tmp[n..];
read += n;
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if read == 0 && !block_on_empty {
return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact"));
}
}
Err(e) => return Err(e),
}
if !buf.is_empty() {
thread::sleep(sleep_time);
count += sleep_time;
} else {
break;
}
if count > timeout {
return Err(io::Error::new(
io::ErrorKind::TimedOut,
"reading from stream",
));
}
}
Ok(())
}
/// Same as `read_exact` but for writing.
pub fn write_all(stream: &mut Write, mut buf: &[u8], timeout: Duration) -> io::Result<()> {
let sleep_time = Duration::from_micros(10);
let mut count = Duration::new(0, 0);
while !buf.is_empty() {
match stream.write(buf) {
Ok(0) => {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write whole buffer",
))
}
Ok(n) => buf = &buf[n..],
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
Err(e) => return Err(e),
}
if !buf.is_empty() {
thread::sleep(sleep_time);
count += sleep_time;
} else {
break;
}
if count > timeout {
return Err(io::Error::new(io::ErrorKind::TimedOut, "writing to stream"));
}
}
Ok(())
}