mirror of
https://github.com/mimblewimble/grin.git
synced 2025-01-20 19:11:08 +03:00
Streaming headers (#1989)
* headers msg is now "streamed" off the tcp stream * rustfmt * cleanup * move StreamingReader into ser.rs extract read_exact out into util crate * rustfmt
This commit is contained in:
parent
a011450825
commit
f0fa410273
8 changed files with 280 additions and 321 deletions
|
@ -34,7 +34,7 @@ use core::{
|
|||
use global;
|
||||
use keychain::{self, BlindingFactor};
|
||||
use pow::{Difficulty, Proof, ProofOfWork};
|
||||
use ser::{self, FixedLength, HashOnlyPMMRable, Readable, Reader, Writeable, Writer};
|
||||
use ser::{self, HashOnlyPMMRable, Readable, Reader, Writeable, Writer};
|
||||
use util::{secp, static_secp_instance};
|
||||
|
||||
/// Errors thrown by Block validation
|
||||
|
@ -140,34 +140,6 @@ pub struct BlockHeader {
|
|||
pub pow: ProofOfWork,
|
||||
}
|
||||
|
||||
const FIXED_HEADER_SIZE: usize = 2 // version
|
||||
+ 8 // height
|
||||
+ 8 // timestamp
|
||||
+ 5 * Hash::LEN // prev_hash, prev_root, output_root, range_proof_root, kernel_root
|
||||
+ BlindingFactor::LEN // total_kernel_offset
|
||||
+ 2 * 8 // output_mmr_size, kernel_mmr_size
|
||||
+ Difficulty::LEN // total_difficulty
|
||||
+ 4 // secondary_scaling
|
||||
+ 8; // nonce
|
||||
|
||||
/// Serialized size of fixed part of a BlockHeader, i.e. without pow
|
||||
fn fixed_size_of_serialized_header(_version: u16) -> usize {
|
||||
FIXED_HEADER_SIZE
|
||||
}
|
||||
|
||||
/// Serialized size of a BlockHeader
|
||||
pub fn serialized_size_of_header(version: u16, edge_bits: u8) -> usize {
|
||||
let mut size = fixed_size_of_serialized_header(version);
|
||||
|
||||
size += 1; // pow.edge_bits
|
||||
let bitvec_len = global::proofsize() * edge_bits as usize;
|
||||
size += bitvec_len / 8; // pow.nonces
|
||||
if bitvec_len % 8 != 0 {
|
||||
size += 1;
|
||||
}
|
||||
size
|
||||
}
|
||||
|
||||
impl Default for BlockHeader {
|
||||
fn default() -> BlockHeader {
|
||||
BlockHeader {
|
||||
|
@ -292,20 +264,6 @@ impl BlockHeader {
|
|||
pub fn total_kernel_offset(&self) -> BlindingFactor {
|
||||
self.total_kernel_offset
|
||||
}
|
||||
|
||||
/// Serialized size of this header
|
||||
pub fn serialized_size(&self) -> usize {
|
||||
let mut size = fixed_size_of_serialized_header(self.version);
|
||||
|
||||
size += 1; // pow.edge_bits
|
||||
let nonce_bits = self.pow.edge_bits() as usize;
|
||||
let bitvec_len = global::proofsize() * nonce_bits;
|
||||
size += bitvec_len / 8; // pow.nonces
|
||||
if bitvec_len % 8 != 0 {
|
||||
size += 1;
|
||||
}
|
||||
size
|
||||
}
|
||||
}
|
||||
|
||||
/// A block as expressed in the MimbleWimble protocol. The reward is
|
||||
|
|
105
core/src/ser.rs
105
core/src/ser.rs
|
@ -19,6 +19,8 @@
|
|||
//! To use it simply implement `Writeable` or `Readable` and then use the
|
||||
//! `serialize` or `deserialize` functions on them as appropriate.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
|
||||
use consensus;
|
||||
use core::hash::{Hash, Hashed};
|
||||
|
@ -27,6 +29,7 @@ use std::fmt::Debug;
|
|||
use std::io::{self, Read, Write};
|
||||
use std::marker;
|
||||
use std::{cmp, error, fmt};
|
||||
use util::read_write::read_exact;
|
||||
use util::secp::constants::{
|
||||
AGG_SIGNATURE_SIZE, MAX_PROOF_SIZE, PEDERSEN_COMMITMENT_SIZE, SECRET_KEY_SIZE,
|
||||
};
|
||||
|
@ -206,7 +209,8 @@ pub trait Writeable {
|
|||
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
struct IteratingReader<'a, T> {
|
||||
/// Reader that exposes an Iterator interface.
|
||||
pub struct IteratingReader<'a, T> {
|
||||
count: u64,
|
||||
curr: u64,
|
||||
reader: &'a mut Reader,
|
||||
|
@ -214,7 +218,9 @@ struct IteratingReader<'a, T> {
|
|||
}
|
||||
|
||||
impl<'a, T> IteratingReader<'a, T> {
|
||||
fn new(reader: &'a mut Reader, count: u64) -> IteratingReader<'a, T> {
|
||||
/// Constructor to create a new iterating reader for the provided underlying reader.
|
||||
/// Takes a count so we know how many to iterate over.
|
||||
pub fn new(reader: &'a mut Reader, count: u64) -> IteratingReader<'a, T> {
|
||||
let curr = 0;
|
||||
IteratingReader {
|
||||
count,
|
||||
|
@ -270,7 +276,7 @@ where
|
|||
fn read(reader: &mut Reader) -> Result<Self, Error>;
|
||||
}
|
||||
|
||||
/// Deserializes a Readeable from any std::io::Read implementation.
|
||||
/// Deserializes a Readable from any std::io::Read implementation.
|
||||
pub fn deserialize<T: Readable>(source: &mut Read) -> Result<T, Error> {
|
||||
let mut reader = BinReader { source };
|
||||
T::read(&mut reader)
|
||||
|
@ -325,13 +331,14 @@ impl<'a> Reader for BinReader<'a> {
|
|||
let len = self.read_u64()?;
|
||||
self.read_fixed_bytes(len as usize)
|
||||
}
|
||||
|
||||
/// Read a fixed number of bytes.
|
||||
fn read_fixed_bytes(&mut self, length: usize) -> Result<Vec<u8>, Error> {
|
||||
// not reading more than 100k in a single read
|
||||
if length > 100000 {
|
||||
fn read_fixed_bytes(&mut self, len: usize) -> Result<Vec<u8>, Error> {
|
||||
// not reading more than 100k bytes in a single read
|
||||
if len > 100_000 {
|
||||
return Err(Error::TooLargeReadErr);
|
||||
}
|
||||
let mut buf = vec![0; length];
|
||||
let mut buf = vec![0; len];
|
||||
self.source
|
||||
.read_exact(&mut buf)
|
||||
.map(move |_| buf)
|
||||
|
@ -351,6 +358,90 @@ impl<'a> Reader for BinReader<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// A reader that reads straight off a stream.
|
||||
/// Tracks total bytes read so we can verify we read the right number afterwards.
|
||||
pub struct StreamingReader<'a> {
|
||||
total_bytes_read: u64,
|
||||
stream: &'a mut Read,
|
||||
timeout: Duration,
|
||||
}
|
||||
|
||||
impl<'a> StreamingReader<'a> {
|
||||
/// Create a new streaming reader with the provided underlying stream.
|
||||
/// Also takes a duration to be used for each individual read_exact call.
|
||||
pub fn new(stream: &'a mut Read, timeout: Duration) -> StreamingReader<'a> {
|
||||
StreamingReader {
|
||||
total_bytes_read: 0,
|
||||
stream,
|
||||
timeout,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the total bytes read via this streaming reader.
|
||||
pub fn total_bytes_read(&self) -> u64 {
|
||||
self.total_bytes_read
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Reader for StreamingReader<'a> {
|
||||
fn read_u8(&mut self) -> Result<u8, Error> {
|
||||
let buf = self.read_fixed_bytes(1)?;
|
||||
deserialize(&mut &buf[..])
|
||||
}
|
||||
|
||||
fn read_u16(&mut self) -> Result<u16, Error> {
|
||||
let buf = self.read_fixed_bytes(2)?;
|
||||
deserialize(&mut &buf[..])
|
||||
}
|
||||
|
||||
fn read_u32(&mut self) -> Result<u32, Error> {
|
||||
let buf = self.read_fixed_bytes(4)?;
|
||||
deserialize(&mut &buf[..])
|
||||
}
|
||||
|
||||
fn read_i32(&mut self) -> Result<i32, Error> {
|
||||
let buf = self.read_fixed_bytes(4)?;
|
||||
deserialize(&mut &buf[..])
|
||||
}
|
||||
|
||||
fn read_u64(&mut self) -> Result<u64, Error> {
|
||||
let buf = self.read_fixed_bytes(8)?;
|
||||
deserialize(&mut &buf[..])
|
||||
}
|
||||
|
||||
fn read_i64(&mut self) -> Result<i64, Error> {
|
||||
let buf = self.read_fixed_bytes(8)?;
|
||||
deserialize(&mut &buf[..])
|
||||
}
|
||||
|
||||
/// Read a variable size vector from the underlying stream. Expects a usize
|
||||
fn read_bytes_len_prefix(&mut self) -> Result<Vec<u8>, Error> {
|
||||
let len = self.read_u64()?;
|
||||
self.total_bytes_read += 8;
|
||||
self.read_fixed_bytes(len as usize)
|
||||
}
|
||||
|
||||
/// Read a fixed number of bytes.
|
||||
fn read_fixed_bytes(&mut self, len: usize) -> Result<Vec<u8>, Error> {
|
||||
let mut buf = vec![0u8; len];
|
||||
read_exact(&mut self.stream, &mut buf, self.timeout, true)?;
|
||||
self.total_bytes_read += len as u64;
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
fn expect_u8(&mut self, val: u8) -> Result<u8, Error> {
|
||||
let b = self.read_u8()?;
|
||||
if b == val {
|
||||
Ok(b)
|
||||
} else {
|
||||
Err(Error::UnexpectedData {
|
||||
expected: vec![val],
|
||||
received: vec![b],
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Readable for Commitment {
|
||||
fn read(reader: &mut Reader) -> Result<Commitment, Error> {
|
||||
let a = reader.read_fixed_bytes(PEDERSEN_COMMITMENT_SIZE)?;
|
||||
|
|
|
@ -28,8 +28,9 @@ use std::{cmp, thread, time};
|
|||
|
||||
use core::ser;
|
||||
use core::ser::FixedLength;
|
||||
use msg::{read_body, read_exact, read_header, write_all, write_to_buf, MsgHeader, Type};
|
||||
use msg::{read_body, read_header, read_item, write_to_buf, MsgHeader, Type};
|
||||
use types::Error;
|
||||
use util::read_write::{read_exact, write_all};
|
||||
use util::{RateCounter, RwLock};
|
||||
|
||||
/// A trait to be implemented in order to receive messages from the
|
||||
|
@ -69,17 +70,15 @@ impl<'a> Message<'a> {
|
|||
Message { header, conn }
|
||||
}
|
||||
|
||||
/// Get the TcpStream
|
||||
pub fn get_conn(&mut self) -> TcpStream {
|
||||
return self.conn.try_clone().unwrap();
|
||||
/// Read the message body from the underlying connection
|
||||
pub fn body<T: ser::Readable>(&mut self) -> Result<T, Error> {
|
||||
read_body(&self.header, self.conn)
|
||||
}
|
||||
|
||||
/// Read the message body from the underlying connection
|
||||
pub fn body<T>(&mut self) -> Result<T, Error>
|
||||
where
|
||||
T: ser::Readable,
|
||||
{
|
||||
read_body(&self.header, self.conn)
|
||||
/// Read a single "thing" from the underlying connection.
|
||||
/// Return the thing and the total bytes read.
|
||||
pub fn streaming_read<T: ser::Readable>(&mut self) -> Result<(T, u64), Error> {
|
||||
read_item(self.conn)
|
||||
}
|
||||
|
||||
pub fn copy_attachment(&mut self, len: usize, writer: &mut Write) -> Result<usize, Error> {
|
||||
|
|
177
p2p/src/msg.rs
177
p2p/src/msg.rs
|
@ -15,17 +15,17 @@
|
|||
//! Message types that transit over the network and related serialization code.
|
||||
|
||||
use num::FromPrimitive;
|
||||
use std::io::{self, Read, Write};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpStream};
|
||||
use std::{thread, time};
|
||||
use std::io::{Read, Write};
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
|
||||
use std::time;
|
||||
|
||||
use core::consensus;
|
||||
use core::core::hash::Hash;
|
||||
use core::core::BlockHeader;
|
||||
use core::pow::Difficulty;
|
||||
use core::ser::{self, FixedLength, Readable, Reader, Writeable, Writer};
|
||||
|
||||
use core::ser::{self, FixedLength, Readable, Reader, StreamingReader, Writeable, Writer};
|
||||
use types::{Capabilities, Error, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS};
|
||||
use util::read_write::read_exact;
|
||||
|
||||
/// Current latest version of the protocol
|
||||
pub const PROTOCOL_VERSION: u32 = 1;
|
||||
|
@ -97,111 +97,19 @@ fn max_msg_size(msg_type: Type) -> u64 {
|
|||
}
|
||||
}
|
||||
|
||||
/// The default implementation of read_exact is useless with async TcpStream as
|
||||
/// it will return as soon as something has been read, regardless of
|
||||
/// whether the buffer has been filled (and then errors). This implementation
|
||||
/// will block until it has read exactly `len` bytes and returns them as a
|
||||
/// `vec<u8>`. Except for a timeout, this implementation will never return a
|
||||
/// partially filled buffer.
|
||||
///
|
||||
/// The timeout in milliseconds aborts the read when it's met. Note that the
|
||||
/// time is not guaranteed to be exact. To support cases where we want to poll
|
||||
/// instead of blocking, a `block_on_empty` boolean, when false, ensures
|
||||
/// `read_exact` returns early with a `io::ErrorKind::WouldBlock` if nothing
|
||||
/// has been read from the socket.
|
||||
pub fn read_exact(
|
||||
conn: &mut TcpStream,
|
||||
mut buf: &mut [u8],
|
||||
timeout: time::Duration,
|
||||
block_on_empty: bool,
|
||||
) -> io::Result<()> {
|
||||
let sleep_time = time::Duration::from_micros(10);
|
||||
let mut count = time::Duration::new(0, 0);
|
||||
|
||||
let mut read = 0;
|
||||
loop {
|
||||
match conn.read(buf) {
|
||||
Ok(0) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"read_exact",
|
||||
));
|
||||
}
|
||||
Ok(n) => {
|
||||
let tmp = buf;
|
||||
buf = &mut tmp[n..];
|
||||
read += n;
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
if read == 0 && !block_on_empty {
|
||||
return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact"));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
if !buf.is_empty() {
|
||||
thread::sleep(sleep_time);
|
||||
count += sleep_time;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
if count > timeout {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"reading from tcp stream",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Same as `read_exact` but for writing.
|
||||
pub fn write_all(conn: &mut Write, mut buf: &[u8], timeout: time::Duration) -> io::Result<()> {
|
||||
let sleep_time = time::Duration::from_micros(10);
|
||||
let mut count = time::Duration::new(0, 0);
|
||||
|
||||
while !buf.is_empty() {
|
||||
match conn.write(buf) {
|
||||
Ok(0) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"failed to write whole buffer",
|
||||
))
|
||||
}
|
||||
Ok(n) => buf = &buf[n..],
|
||||
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
if !buf.is_empty() {
|
||||
thread::sleep(sleep_time);
|
||||
count += sleep_time;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
if count > timeout {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"reading from tcp stream",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read a header from the provided connection without blocking if the
|
||||
/// Read a header from the provided stream without blocking if the
|
||||
/// underlying stream is async. Typically headers will be polled for, so
|
||||
/// we do not want to block.
|
||||
pub fn read_header(conn: &mut TcpStream, msg_type: Option<Type>) -> Result<MsgHeader, Error> {
|
||||
pub fn read_header(stream: &mut Read, msg_type: Option<Type>) -> Result<MsgHeader, Error> {
|
||||
let mut head = vec![0u8; MsgHeader::LEN];
|
||||
if Some(Type::Hand) == msg_type {
|
||||
read_exact(conn, &mut head, time::Duration::from_millis(10), true)?;
|
||||
read_exact(stream, &mut head, time::Duration::from_millis(10), true)?;
|
||||
} else {
|
||||
read_exact(conn, &mut head, time::Duration::from_secs(10), false)?;
|
||||
read_exact(stream, &mut head, time::Duration::from_secs(10), false)?;
|
||||
}
|
||||
let header = ser::deserialize::<MsgHeader>(&mut &head[..])?;
|
||||
let max_len = max_msg_size(header.msg_type);
|
||||
|
||||
// TODO 4x the limits for now to leave ourselves space to change things
|
||||
if header.msg_len > max_len * 4 {
|
||||
error!(
|
||||
|
@ -213,33 +121,34 @@ pub fn read_header(conn: &mut TcpStream, msg_type: Option<Type>) -> Result<MsgHe
|
|||
Ok(header)
|
||||
}
|
||||
|
||||
/// Read a message body from the provided connection, always blocking
|
||||
/// Read a single item from the provided stream, always blocking until we
|
||||
/// have a result (or timeout).
|
||||
/// Returns the item and the total bytes read.
|
||||
pub fn read_item<T: Readable>(stream: &mut Read) -> Result<(T, u64), Error> {
|
||||
let timeout = time::Duration::from_secs(20);
|
||||
let mut reader = StreamingReader::new(stream, timeout);
|
||||
let res = T::read(&mut reader)?;
|
||||
Ok((res, reader.total_bytes_read()))
|
||||
}
|
||||
|
||||
/// Read a message body from the provided stream, always blocking
|
||||
/// until we have a result (or timeout).
|
||||
pub fn read_body<T>(h: &MsgHeader, conn: &mut TcpStream) -> Result<T, Error>
|
||||
where
|
||||
T: Readable,
|
||||
{
|
||||
pub fn read_body<T: Readable>(h: &MsgHeader, stream: &mut Read) -> Result<T, Error> {
|
||||
let mut body = vec![0u8; h.msg_len as usize];
|
||||
read_exact(conn, &mut body, time::Duration::from_secs(20), true)?;
|
||||
read_exact(stream, &mut body, time::Duration::from_secs(20), true)?;
|
||||
ser::deserialize(&mut &body[..]).map_err(From::from)
|
||||
}
|
||||
|
||||
/// Reads a full message from the underlying connection.
|
||||
pub fn read_message<T>(conn: &mut TcpStream, msg_type: Type) -> Result<T, Error>
|
||||
where
|
||||
T: Readable,
|
||||
{
|
||||
let header = read_header(conn, Some(msg_type))?;
|
||||
/// Reads a full message from the underlying stream.
|
||||
pub fn read_message<T: Readable>(stream: &mut Read, msg_type: Type) -> Result<T, Error> {
|
||||
let header = read_header(stream, Some(msg_type))?;
|
||||
if header.msg_type != msg_type {
|
||||
return Err(Error::BadMessage);
|
||||
}
|
||||
read_body(&header, conn)
|
||||
read_body(&header, stream)
|
||||
}
|
||||
|
||||
pub fn write_to_buf<T>(msg: T, msg_type: Type) -> Vec<u8>
|
||||
where
|
||||
T: Writeable,
|
||||
{
|
||||
pub fn write_to_buf<T: Writeable>(msg: T, msg_type: Type) -> Vec<u8> {
|
||||
// prepare the body first so we know its serialized length
|
||||
let mut body_buf = vec![];
|
||||
ser::serialize(&mut body_buf, &msg).unwrap();
|
||||
|
@ -253,13 +162,13 @@ where
|
|||
msg_buf
|
||||
}
|
||||
|
||||
pub fn write_message<T>(conn: &mut TcpStream, msg: T, msg_type: Type) -> Result<(), Error>
|
||||
where
|
||||
T: Writeable + 'static,
|
||||
{
|
||||
pub fn write_message<T: Writeable>(
|
||||
stream: &mut Write,
|
||||
msg: T,
|
||||
msg_type: Type,
|
||||
) -> Result<(), Error> {
|
||||
let buf = write_to_buf(msg, msg_type);
|
||||
// send the whole thing
|
||||
conn.write_all(&buf[..])?;
|
||||
stream.write_all(&buf[..])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -605,24 +514,6 @@ impl Writeable for Headers {
|
|||
}
|
||||
}
|
||||
|
||||
impl Readable for Headers {
|
||||
fn read(reader: &mut Reader) -> Result<Headers, ser::Error> {
|
||||
let len = reader.read_u16()?;
|
||||
if (len as u32) > MAX_BLOCK_HEADERS + 1 {
|
||||
return Err(ser::Error::TooLargeReadErr);
|
||||
}
|
||||
let mut headers: Vec<BlockHeader> = Vec::with_capacity(len as usize);
|
||||
for n in 0..len as usize {
|
||||
let header = BlockHeader::read(reader)?;
|
||||
if n > 0 && header.height != headers[n - 1].height + 1 {
|
||||
return Err(ser::Error::CorruptedData);
|
||||
}
|
||||
headers.push(header);
|
||||
}
|
||||
Ok(Headers { headers: headers })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Ping {
|
||||
/// total difficulty accumulated by the sender, used to check whether sync
|
||||
/// may be needed
|
||||
|
|
|
@ -15,20 +15,18 @@
|
|||
use std::cmp;
|
||||
use std::env;
|
||||
use std::fs::File;
|
||||
use std::io::{self, BufWriter};
|
||||
use std::net::{SocketAddr, TcpStream};
|
||||
use std::io::BufWriter;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time;
|
||||
|
||||
use chrono::prelude::Utc;
|
||||
use conn::{Message, MessageHandler, Response};
|
||||
use core::core::{self, hash::Hash, CompactBlock};
|
||||
use core::{global, ser};
|
||||
use util::{RateCounter, RwLock};
|
||||
|
||||
use msg::{
|
||||
read_exact, BanReason, GetPeerAddrs, Headers, Locator, PeerAddrs, Ping, Pong, SockAddr,
|
||||
TxHashSetArchive, TxHashSetRequest, Type,
|
||||
BanReason, GetPeerAddrs, Headers, Locator, PeerAddrs, Ping, Pong, SockAddr, TxHashSetArchive,
|
||||
TxHashSetRequest, Type,
|
||||
};
|
||||
use types::{Error, NetAdapter};
|
||||
|
||||
|
@ -194,32 +192,34 @@ impl MessageHandler for Protocol {
|
|||
// we can go request it from some of our peers
|
||||
Type::Header => {
|
||||
let header: core::BlockHeader = msg.body()?;
|
||||
|
||||
adapter.header_received(header, self.addr);
|
||||
|
||||
// we do not return a hash here as we never request a single header
|
||||
// a header will always arrive unsolicited
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
Type::Headers => {
|
||||
let conn = &mut msg.get_conn();
|
||||
let mut total_bytes_read = 0;
|
||||
|
||||
let header_size: u64 = headers_header_size(conn, msg.header.msg_len)?;
|
||||
let mut total_read: u64 = 2;
|
||||
let mut reserved: Vec<u8> = vec![];
|
||||
// Read the count (u16) so we now how many headers to read.
|
||||
let (count, bytes_read): (u16, _) = msg.streaming_read()?;
|
||||
total_bytes_read += bytes_read;
|
||||
|
||||
while total_read < msg.header.msg_len || reserved.len() > 0 {
|
||||
let headers: Headers = headers_streaming_body(
|
||||
conn,
|
||||
msg.header.msg_len,
|
||||
32,
|
||||
&mut total_read,
|
||||
&mut reserved,
|
||||
header_size,
|
||||
)?;
|
||||
adapter.headers_received(headers.headers, self.addr);
|
||||
// Read chunks of headers off the stream and pass them off to the adapter.
|
||||
let chunk_size = 32;
|
||||
for chunk in (0..count).collect::<Vec<_>>().chunks(chunk_size) {
|
||||
let mut headers = vec![];
|
||||
for _ in chunk {
|
||||
let (header, bytes_read) = msg.streaming_read()?;
|
||||
headers.push(header);
|
||||
total_bytes_read += bytes_read;
|
||||
}
|
||||
adapter.headers_received(headers, self.addr);
|
||||
}
|
||||
|
||||
// Now check we read the correct total number of bytes off the stream.
|
||||
if total_bytes_read != msg.header.msg_len {
|
||||
return Err(Error::MsgLen);
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
|
@ -344,97 +344,3 @@ impl MessageHandler for Protocol {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read the Headers Vec size from the underlying connection, and calculate maximum header_size of one Header
|
||||
fn headers_header_size(conn: &mut TcpStream, msg_len: u64) -> Result<u64, Error> {
|
||||
let mut size = vec![0u8; 2];
|
||||
// read size of Vec<BlockHeader>
|
||||
read_exact(conn, &mut size, time::Duration::from_millis(10), true)?;
|
||||
|
||||
let total_headers = size[0] as u64 * 256 + size[1] as u64;
|
||||
if total_headers == 0 || total_headers > 10_000 {
|
||||
return Err(Error::Connection(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"headers_header_size",
|
||||
)));
|
||||
}
|
||||
let average_header_size = (msg_len - 2) / total_headers;
|
||||
|
||||
// support size of Cuck(at)oo: from Cuck(at)oo 29 to Cuck(at)oo 35, with version 2
|
||||
// having slightly larger headers
|
||||
let min_size = core::serialized_size_of_header(1, global::min_edge_bits());
|
||||
let max_size = min_size + 6;
|
||||
if average_header_size < min_size as u64 || average_header_size > max_size as u64 {
|
||||
debug!(
|
||||
"headers_header_size - size of Vec: {}, average_header_size: {}, min: {}, max: {}",
|
||||
total_headers, average_header_size, min_size, max_size,
|
||||
);
|
||||
return Err(Error::Connection(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"headers_header_size",
|
||||
)));
|
||||
}
|
||||
return Ok(max_size as u64);
|
||||
}
|
||||
|
||||
/// Read the Headers streaming body from the underlying connection
|
||||
fn headers_streaming_body(
|
||||
conn: &mut TcpStream, // (i) underlying connection
|
||||
msg_len: u64, // (i) length of whole 'Headers'
|
||||
headers_num: u64, // (i) how many BlockHeader(s) do you want to read
|
||||
total_read: &mut u64, // (i/o) how many bytes already read on this 'Headers' message
|
||||
reserved: &mut Vec<u8>, // (i/o) reserved part of previous read, which is not a whole header
|
||||
max_header_size: u64, // (i) maximum possible size of single BlockHeader
|
||||
) -> Result<Headers, Error> {
|
||||
if headers_num == 0 || msg_len < *total_read || *total_read < 2 {
|
||||
return Err(Error::Connection(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"headers_streaming_body",
|
||||
)));
|
||||
}
|
||||
|
||||
// Note:
|
||||
// As we allow Cuckoo sizes greater than 30 now, the proof of work part of the header
|
||||
// could be 30*42 bits, 31*42 bits, 32*42 bits, etc.
|
||||
// So, for compatibility with variable size of block header, we read max possible size, for
|
||||
// up to Cuckoo 36.
|
||||
//
|
||||
let mut read_size = headers_num * max_header_size - reserved.len() as u64;
|
||||
if *total_read + read_size > msg_len {
|
||||
read_size = msg_len - *total_read;
|
||||
}
|
||||
|
||||
// 1st part
|
||||
let mut body = vec![0u8; 2]; // for Vec<> size
|
||||
let mut final_headers_num = (read_size + reserved.len() as u64) / max_header_size;
|
||||
let remaining = msg_len - *total_read - read_size;
|
||||
if final_headers_num == 0 && remaining == 0 {
|
||||
final_headers_num = 1;
|
||||
}
|
||||
body[0] = (final_headers_num >> 8) as u8;
|
||||
body[1] = (final_headers_num & 0x00ff) as u8;
|
||||
|
||||
// 2nd part
|
||||
body.append(reserved);
|
||||
|
||||
// 3rd part
|
||||
let mut read_body = vec![0u8; read_size as usize];
|
||||
if read_size > 0 {
|
||||
read_exact(conn, &mut read_body, time::Duration::from_secs(20), true)?;
|
||||
*total_read += read_size;
|
||||
}
|
||||
body.append(&mut read_body);
|
||||
|
||||
// deserialize these assembled 3 parts
|
||||
let result: Result<Headers, Error> = ser::deserialize(&mut &body[..]).map_err(From::from);
|
||||
let headers = result?;
|
||||
|
||||
// remaining data
|
||||
let mut deserialized_size = 2; // for Vec<> size
|
||||
for header in &headers.headers {
|
||||
deserialized_size += header.serialized_size();
|
||||
}
|
||||
*reserved = body[deserialized_size..].to_vec();
|
||||
|
||||
Ok(headers)
|
||||
}
|
||||
|
|
|
@ -55,6 +55,7 @@ pub enum Error {
|
|||
Connection(io::Error),
|
||||
/// Header type does not match the expected message type
|
||||
BadMessage,
|
||||
MsgLen,
|
||||
Banned,
|
||||
ConnectionClose,
|
||||
Timeout,
|
||||
|
|
|
@ -57,6 +57,9 @@ pub use types::{LogLevel, LoggingConfig};
|
|||
|
||||
pub mod macros;
|
||||
|
||||
// read_exact and write_all impls
|
||||
pub mod read_write;
|
||||
|
||||
// other utils
|
||||
#[allow(unused_imports)]
|
||||
use std::ops::Deref;
|
||||
|
|
110
util/src/read_write.rs
Normal file
110
util/src/read_write.rs
Normal file
|
@ -0,0 +1,110 @@
|
|||
// Copyright 2018 The Grin Developers
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Custom impls of read_exact and write_all to work around async stream restrictions.
|
||||
|
||||
use std::io;
|
||||
use std::io::prelude::*;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
/// The default implementation of read_exact is useless with an async stream (TcpStream) as
|
||||
/// it will return as soon as something has been read, regardless of
|
||||
/// whether the buffer has been filled (and then errors). This implementation
|
||||
/// will block until it has read exactly `len` bytes and returns them as a
|
||||
/// `vec<u8>`. Except for a timeout, this implementation will never return a
|
||||
/// partially filled buffer.
|
||||
///
|
||||
/// The timeout in milliseconds aborts the read when it's met. Note that the
|
||||
/// time is not guaranteed to be exact. To support cases where we want to poll
|
||||
/// instead of blocking, a `block_on_empty` boolean, when false, ensures
|
||||
/// `read_exact` returns early with a `io::ErrorKind::WouldBlock` if nothing
|
||||
/// has been read from the socket.
|
||||
pub fn read_exact(
|
||||
stream: &mut Read,
|
||||
mut buf: &mut [u8],
|
||||
timeout: Duration,
|
||||
block_on_empty: bool,
|
||||
) -> io::Result<()> {
|
||||
let sleep_time = Duration::from_micros(10);
|
||||
let mut count = Duration::new(0, 0);
|
||||
|
||||
let mut read = 0;
|
||||
loop {
|
||||
match stream.read(buf) {
|
||||
Ok(0) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
"read_exact",
|
||||
));
|
||||
}
|
||||
Ok(n) => {
|
||||
let tmp = buf;
|
||||
buf = &mut tmp[n..];
|
||||
read += n;
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
|
||||
if read == 0 && !block_on_empty {
|
||||
return Err(io::Error::new(io::ErrorKind::WouldBlock, "read_exact"));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
if !buf.is_empty() {
|
||||
thread::sleep(sleep_time);
|
||||
count += sleep_time;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
if count > timeout {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
"reading from stream",
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Same as `read_exact` but for writing.
|
||||
pub fn write_all(stream: &mut Write, mut buf: &[u8], timeout: Duration) -> io::Result<()> {
|
||||
let sleep_time = Duration::from_micros(10);
|
||||
let mut count = Duration::new(0, 0);
|
||||
|
||||
while !buf.is_empty() {
|
||||
match stream.write(buf) {
|
||||
Ok(0) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"failed to write whole buffer",
|
||||
))
|
||||
}
|
||||
Ok(n) => buf = &buf[n..],
|
||||
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
if !buf.is_empty() {
|
||||
thread::sleep(sleep_time);
|
||||
count += sleep_time;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
if count > timeout {
|
||||
return Err(io::Error::new(io::ErrorKind::TimedOut, "writing to stream"));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in a new issue