mirror of
https://github.com/mimblewimble/grin.git
synced 2025-01-21 11:31:08 +03:00
Wrap MsgHeader in MsgHeaderWrapper for Known/Unknown msg type support (#2791)
* wrap MsgHeader in MsgHeaderWrapper for Known/Unknown msg type support. * cleanup based on feedback
This commit is contained in:
parent
4ba7b0a46a
commit
ff1c55193f
3 changed files with 114 additions and 46 deletions
|
@ -28,7 +28,10 @@ use std::{cmp, thread, time};
|
||||||
|
|
||||||
use crate::core::ser;
|
use crate::core::ser;
|
||||||
use crate::core::ser::FixedLength;
|
use crate::core::ser::FixedLength;
|
||||||
use crate::msg::{read_body, read_header, read_item, write_to_buf, MsgHeader, Type};
|
use crate::msg::{
|
||||||
|
read_body, read_discard, read_header, read_item, write_to_buf, MsgHeader, MsgHeaderWrapper,
|
||||||
|
Type,
|
||||||
|
};
|
||||||
use crate::types::Error;
|
use crate::types::Error;
|
||||||
use crate::util::read_write::{read_exact, write_all};
|
use crate::util::read_write::{read_exact, write_all};
|
||||||
use crate::util::{RateCounter, RwLock};
|
use crate::util::{RateCounter, RwLock};
|
||||||
|
@ -251,8 +254,9 @@ fn poll<H>(
|
||||||
let mut retry_send = Err(());
|
let mut retry_send = Err(());
|
||||||
loop {
|
loop {
|
||||||
// check the read end
|
// check the read end
|
||||||
if let Some(h) = try_break!(read_header(&mut reader, None)) {
|
match try_break!(read_header(&mut reader, None)) {
|
||||||
let msg = Message::from_header(h, &mut reader);
|
Some(MsgHeaderWrapper::Known(header)) => {
|
||||||
|
let msg = Message::from_header(header, &mut reader);
|
||||||
|
|
||||||
trace!(
|
trace!(
|
||||||
"Received message header, type {:?}, len {}.",
|
"Received message header, type {:?}, len {}.",
|
||||||
|
@ -261,18 +265,24 @@ fn poll<H>(
|
||||||
);
|
);
|
||||||
|
|
||||||
// Increase received bytes counter
|
// Increase received bytes counter
|
||||||
let received = received_bytes.clone();
|
received_bytes
|
||||||
{
|
.write()
|
||||||
let mut received_bytes = received_bytes.write();
|
.inc(MsgHeader::LEN as u64 + msg.header.msg_len);
|
||||||
received_bytes.inc(MsgHeader::LEN as u64 + msg.header.msg_len);
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(Some(resp)) =
|
if let Some(Some(resp)) =
|
||||||
try_break!(handler.consume(msg, &mut writer, received))
|
try_break!(handler.consume(msg, &mut writer, received_bytes.clone()))
|
||||||
{
|
{
|
||||||
try_break!(resp.write(sent_bytes.clone()));
|
try_break!(resp.write(sent_bytes.clone()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Some(MsgHeaderWrapper::Unknown(msg_len)) => {
|
||||||
|
// Increase received bytes counter
|
||||||
|
received_bytes.write().inc(MsgHeader::LEN as u64 + msg_len);
|
||||||
|
|
||||||
|
try_break!(read_discard(msg_len, &mut reader));
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
|
||||||
// check the write end, use or_else so try_recv is lazily eval'd
|
// check the write end, use or_else so try_recv is lazily eval'd
|
||||||
let maybe_data = retry_send.or_else(|_| send_rx.try_recv());
|
let maybe_data = retry_send.or_else(|_| send_rx.try_recv());
|
||||||
|
|
109
p2p/src/msg.rs
109
p2p/src/msg.rs
|
@ -82,6 +82,11 @@ fn max_block_size() -> u64 {
|
||||||
(global::max_block_weight() / consensus::BLOCK_OUTPUT_WEIGHT * 708) as u64
|
(global::max_block_weight() / consensus::BLOCK_OUTPUT_WEIGHT * 708) as u64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Max msg size when msg type is unknown.
|
||||||
|
fn default_max_msg_size() -> u64 {
|
||||||
|
max_block_size()
|
||||||
|
}
|
||||||
|
|
||||||
// Max msg size for each msg type.
|
// Max msg size for each msg type.
|
||||||
fn max_msg_size(msg_type: Type) -> u64 {
|
fn max_msg_size(msg_type: Type) -> u64 {
|
||||||
match msg_type {
|
match msg_type {
|
||||||
|
@ -120,24 +125,20 @@ fn magic() -> [u8; 2] {
|
||||||
/// Read a header from the provided stream without blocking if the
|
/// Read a header from the provided stream without blocking if the
|
||||||
/// underlying stream is async. Typically headers will be polled for, so
|
/// underlying stream is async. Typically headers will be polled for, so
|
||||||
/// we do not want to block.
|
/// we do not want to block.
|
||||||
pub fn read_header(stream: &mut dyn Read, msg_type: Option<Type>) -> Result<MsgHeader, Error> {
|
///
|
||||||
|
/// Note: We return a MsgHeaderWrapper here as we may encounter an unknown msg type.
|
||||||
|
///
|
||||||
|
pub fn read_header(
|
||||||
|
stream: &mut dyn Read,
|
||||||
|
msg_type: Option<Type>,
|
||||||
|
) -> Result<MsgHeaderWrapper, Error> {
|
||||||
let mut head = vec![0u8; MsgHeader::LEN];
|
let mut head = vec![0u8; MsgHeader::LEN];
|
||||||
if Some(Type::Hand) == msg_type {
|
if Some(Type::Hand) == msg_type {
|
||||||
read_exact(stream, &mut head, time::Duration::from_millis(10), true)?;
|
read_exact(stream, &mut head, time::Duration::from_millis(10), true)?;
|
||||||
} else {
|
} else {
|
||||||
read_exact(stream, &mut head, time::Duration::from_secs(10), false)?;
|
read_exact(stream, &mut head, time::Duration::from_secs(10), false)?;
|
||||||
}
|
}
|
||||||
let header = ser::deserialize::<MsgHeader>(&mut &head[..])?;
|
let header = ser::deserialize::<MsgHeaderWrapper>(&mut &head[..])?;
|
||||||
let max_len = max_msg_size(header.msg_type);
|
|
||||||
|
|
||||||
// TODO 4x the limits for now to leave ourselves space to change things
|
|
||||||
if header.msg_len > max_len * 4 {
|
|
||||||
error!(
|
|
||||||
"Too large read {}, had {}, wanted {}.",
|
|
||||||
header.msg_type as u8, max_len, header.msg_len
|
|
||||||
);
|
|
||||||
return Err(Error::Serialization(ser::Error::TooLargeReadErr));
|
|
||||||
}
|
|
||||||
Ok(header)
|
Ok(header)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,13 +160,28 @@ pub fn read_body<T: Readable>(h: &MsgHeader, stream: &mut dyn Read) -> Result<T,
|
||||||
ser::deserialize(&mut &body[..]).map_err(From::from)
|
ser::deserialize(&mut &body[..]).map_err(From::from)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Read (an unknown) message from the provided stream and discard it.
|
||||||
|
pub fn read_discard(msg_len: u64, stream: &mut dyn Read) -> Result<(), Error> {
|
||||||
|
let mut buffer = vec![0u8; msg_len as usize];
|
||||||
|
read_exact(stream, &mut buffer, time::Duration::from_secs(20), true)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Reads a full message from the underlying stream.
|
/// Reads a full message from the underlying stream.
|
||||||
pub fn read_message<T: Readable>(stream: &mut dyn Read, msg_type: Type) -> Result<T, Error> {
|
pub fn read_message<T: Readable>(stream: &mut dyn Read, msg_type: Type) -> Result<T, Error> {
|
||||||
let header = read_header(stream, Some(msg_type))?;
|
match read_header(stream, Some(msg_type))? {
|
||||||
if header.msg_type != msg_type {
|
MsgHeaderWrapper::Known(header) => {
|
||||||
return Err(Error::BadMessage);
|
if header.msg_type == msg_type {
|
||||||
}
|
|
||||||
read_body(&header, stream)
|
read_body(&header, stream)
|
||||||
|
} else {
|
||||||
|
Err(Error::BadMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MsgHeaderWrapper::Unknown(msg_len) => {
|
||||||
|
read_discard(msg_len, stream)?;
|
||||||
|
Err(Error::BadMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn write_to_buf<T: Writeable>(msg: T, msg_type: Type) -> Result<Vec<u8>, Error> {
|
pub fn write_to_buf<T: Writeable>(msg: T, msg_type: Type) -> Result<Vec<u8>, Error> {
|
||||||
|
@ -192,7 +208,19 @@ pub fn write_message<T: Writeable>(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A wrapper around a message header. If the header is for an unknown msg type
|
||||||
|
/// then we will be unable to parse the msg itself (just a bunch of random bytes).
|
||||||
|
/// But we need to know how many bytes to discard to discard the full message.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum MsgHeaderWrapper {
|
||||||
|
/// A "known" msg type with deserialized msg header.
|
||||||
|
Known(MsgHeader),
|
||||||
|
/// An unknown msg type with corresponding msg size in bytes.
|
||||||
|
Unknown(u64),
|
||||||
|
}
|
||||||
|
|
||||||
/// Header of any protocol message, used to identify incoming messages.
|
/// Header of any protocol message, used to identify incoming messages.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct MsgHeader {
|
pub struct MsgHeader {
|
||||||
magic: [u8; 2],
|
magic: [u8; 2],
|
||||||
/// Type of the message.
|
/// Type of the message.
|
||||||
|
@ -213,7 +241,8 @@ impl MsgHeader {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FixedLength for MsgHeader {
|
impl FixedLength for MsgHeader {
|
||||||
const LEN: usize = 1 + 1 + 1 + 8;
|
// 2 magic bytes + 1 type byte + 8 bytes (msg_len)
|
||||||
|
const LEN: usize = 2 + 1 + 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Writeable for MsgHeader {
|
impl Writeable for MsgHeader {
|
||||||
|
@ -229,19 +258,49 @@ impl Writeable for MsgHeader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Readable for MsgHeader {
|
impl Readable for MsgHeaderWrapper {
|
||||||
fn read(reader: &mut dyn Reader) -> Result<MsgHeader, ser::Error> {
|
fn read(reader: &mut dyn Reader) -> Result<MsgHeaderWrapper, ser::Error> {
|
||||||
let m = magic();
|
let m = magic();
|
||||||
reader.expect_u8(m[0])?;
|
reader.expect_u8(m[0])?;
|
||||||
reader.expect_u8(m[1])?;
|
reader.expect_u8(m[1])?;
|
||||||
let (t, len) = ser_multiread!(reader, read_u8, read_u64);
|
|
||||||
|
// Read the msg header.
|
||||||
|
// We do not yet know if the msg type is one we support locally.
|
||||||
|
let (t, msg_len) = ser_multiread!(reader, read_u8, read_u64);
|
||||||
|
|
||||||
|
// Attempt to convert the msg type byte into one of our known msg type enum variants.
|
||||||
|
// Check the msg_len while we are at it.
|
||||||
match Type::from_u8(t) {
|
match Type::from_u8(t) {
|
||||||
Some(ty) => Ok(MsgHeader {
|
Some(msg_type) => {
|
||||||
|
// TODO 4x the limits for now to leave ourselves space to change things.
|
||||||
|
let max_len = max_msg_size(msg_type) * 4;
|
||||||
|
if msg_len > max_len {
|
||||||
|
error!(
|
||||||
|
"Too large read {:?}, max_len: {}, msg_len: {}.",
|
||||||
|
msg_type, max_len, msg_len
|
||||||
|
);
|
||||||
|
return Err(ser::Error::TooLargeReadErr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(MsgHeaderWrapper::Known(MsgHeader {
|
||||||
magic: m,
|
magic: m,
|
||||||
msg_type: ty,
|
msg_type,
|
||||||
msg_len: len,
|
msg_len,
|
||||||
}),
|
}))
|
||||||
None => Err(ser::Error::CorruptedData),
|
}
|
||||||
|
None => {
|
||||||
|
// Unknown msg type, but we still want to limit how big the msg is.
|
||||||
|
let max_len = default_max_msg_size() * 4;
|
||||||
|
if msg_len > max_len {
|
||||||
|
error!(
|
||||||
|
"Too large read (unknown msg type) {:?}, max_len: {}, msg_len: {}.",
|
||||||
|
t, max_len, msg_len
|
||||||
|
);
|
||||||
|
return Err(ser::Error::TooLargeReadErr);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(MsgHeaderWrapper::Unknown(msg_len))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -353,9 +353,8 @@ impl MessageHandler for Protocol {
|
||||||
|
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
Type::Error | Type::Hand | Type::Shake => {
|
||||||
_ => {
|
debug!("Received an unexpected msg: {:?}", msg.header.msg_type);
|
||||||
debug!("unknown message type {:?}", msg.header.msg_type);
|
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue