grin/p2p/src/codec.rs

282 lines
9.1 KiB
Rust
Raw Normal View History

// Copyright 2021 The Grin Developers
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Provides a connection wrapper that handles the lower level tasks in sending
//! or receiving data from the TCP socket, as well as dealing with timeouts.
//!
//! Because of a few idiosyncrasies in the Rust `TcpStream`, this has to use
//! async I/O to be able to both read *and* write on the connection, which
//! forces us to go through some additional gymnastics to loop over the async
//! stream and make sure we get the right number of bytes out.
use crate::core::global::header_size_bytes;
use crate::core::ser::{BufReader, ProtocolVersion, Readable};
use crate::msg::{Message, MsgHeader, MsgHeaderWrapper, Type};
use crate::types::{AttachmentMeta, AttachmentUpdate, Error};
use crate::{
core::core::block::{BlockHeader, UntrustedBlockHeader},
msg::HeadersData,
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use core::ser::Reader;
use std::cmp::min;
use std::io::Read;
use std::mem;
use std::net::TcpStream;
use std::sync::Arc;
use std::time::{Duration, Instant};
use MsgHeaderWrapper::*;
use State::*;
/// Read timeout applied while waiting for the next message header.
const HEADER_IO_TIMEOUT: Duration = Duration::from_secs(2);
/// Read timeout applied while receiving anything other than a message header
/// (message bodies, block headers, attachment chunks).
pub const BODY_IO_TIMEOUT: Duration = Duration::from_secs(60);
/// Maximum number of block headers collected before they are returned as one batch.
const HEADER_BATCH_SIZE: usize = 32;
/// Decoding progress of the codec: what the next bytes read from the stream belong to.
enum State {
	/// Waiting for the next message header.
	None,
	/// A message header has been parsed; its body comes next.
	Header(MsgHeaderWrapper),
	/// Streaming the body of a `Headers` message one block header at a time.
	BlockHeaders {
		/// Body bytes of the message that have not been consumed yet.
		bytes_left: usize,
		/// Block headers announced by the item count that have not been parsed yet.
		items_left: usize,
		/// Parsed block headers not yet handed to the caller (returned in batches).
		headers: Vec<BlockHeader>,
	},
	/// Receiving raw attachment bytes: the number of bytes still expected, the
	/// attachment metadata, and the instant progress was last logged.
	Attachment(usize, Arc<AttachmentMeta>, Instant),
}
impl State {
	/// Whether the codec is between items, i.e. the next bytes are a message header.
	fn is_none(&self) -> bool {
		matches!(self, State::None)
	}
}
/// Reads and decodes p2p messages from a `TcpStream`, one item per `read` call.
pub struct Codec {
	/// Protocol version used to deserialize incoming data.
	pub version: ProtocolVersion,
	/// Underlying connection; its read timeout is adjusted before every read.
	stream: TcpStream,
	/// Bytes received from `stream` that have not been decoded yet.
	buffer: BytesMut,
	/// What the next bytes on the stream are expected to be.
	state: State,
	/// Number of bytes read from the stream during the current `read` call.
	bytes_read: usize,
}
impl Codec {
pub fn new(version: ProtocolVersion, stream: TcpStream) -> Self {
Self {
version,
stream,
buffer: BytesMut::with_capacity(8 * 1024),
state: None,
bytes_read: 0,
}
}
/// Destroy the codec and return the reader
pub fn stream(self) -> TcpStream {
self.stream
}
/// Inform codec next `len` bytes are an attachment
/// Panics if already reading a body
pub fn expect_attachment(&mut self, meta: Arc<AttachmentMeta>) {
assert!(self.state.is_none());
self.state = Attachment(meta.size, meta, Instant::now());
}
/// Length of the next item we are expecting, could be msg header, body, block header or attachment chunk
fn next_len(&self) -> usize {
match &self.state {
None => MsgHeader::LEN,
Header(Known(h)) if h.msg_type == Type::Headers => {
// If we are receiving a list of headers, read off the item count first
min(h.msg_len as usize, 2)
}
Header(Known(header)) => header.msg_len as usize,
Header(Unknown(len, _)) => *len as usize,
BlockHeaders { bytes_left, .. } => {
// The header length varies with the number of edge bits. Therefore we overestimate
// its size and only actually read the bytes we need
min(*bytes_left, header_size_bytes(63))
}
Attachment(left, _, _) => min(*left, 48_000),
}
}
/// Set stream timeout depending on the next expected item
fn set_stream_timeout(&self) -> Result<(), Error> {
let timeout = match &self.state {
None => HEADER_IO_TIMEOUT,
_ => BODY_IO_TIMEOUT,
};
self.stream.set_read_timeout(Some(timeout))?;
Ok(())
}
fn read_inner(&mut self) -> Result<Message, Error> {
self.bytes_read = 0;
loop {
let next_len = self.next_len();
let pre_len = self.buffer.len();
// Buffer could already be partially filled, calculate additional bytes we need
let to_read = next_len.saturating_sub(pre_len);
if to_read > 0 {
self.buffer.reserve(to_read);
for _ in 0..to_read {
self.buffer.put_u8(0);
}
self.set_stream_timeout()?;
if let Err(e) = self.stream.read_exact(&mut self.buffer[pre_len..]) {
// Undo reserved bytes on a failed read
self.buffer.truncate(pre_len);
return Err(e.into());
}
self.bytes_read += to_read;
}
match &mut self.state {
None => {
// Parse header and keep reading
let mut raw = self.buffer.split_to(next_len).freeze();
let mut reader = BufReader::new(&mut raw, self.version);
let header = MsgHeaderWrapper::read(&mut reader)?;
self.state = Header(header);
}
Header(Known(header)) => {
let mut raw = self.buffer.split_to(next_len).freeze();
if header.msg_type == Type::Headers {
// Special consideration for a list of headers, as we want to verify and process
// them as they come in instead of only after the full list has been received
let mut reader = BufReader::new(&mut raw, self.version);
let items_left = reader.read_u16()? as usize;
self.state = BlockHeaders {
bytes_left: header.msg_len as usize - 2,
items_left,
headers: Vec::with_capacity(min(HEADER_BATCH_SIZE, items_left)),
};
} else {
// Return full message
let msg = decode_message(header, &mut raw, self.version);
self.state = None;
return msg;
}
}
Header(Unknown(_, msg_type)) => {
// Discard body and return
let msg_type = *msg_type;
self.buffer.advance(next_len);
self.state = None;
return Ok(Message::Unknown(msg_type));
}
BlockHeaders {
bytes_left,
items_left,
headers,
} => {
if *bytes_left == 0 {
// Incorrect item count
self.state = None;
return Err(Error::BadMessage);
}
let mut reader = BufReader::new(&mut self.buffer, self.version);
let header: UntrustedBlockHeader = reader.body()?;
let bytes_read = reader.bytes_read() as usize;
headers.push(header.into());
*bytes_left = bytes_left.saturating_sub(bytes_read);
*items_left -= 1;
let remaining = *items_left as u64;
if headers.len() == HEADER_BATCH_SIZE || remaining == 0 {
let mut h = Vec::with_capacity(min(HEADER_BATCH_SIZE, *items_left));
mem::swap(headers, &mut h);
if remaining == 0 {
let bytes_left = *bytes_left;
self.state = None;
if bytes_left > 0 {
return Err(Error::BadMessage);
}
}
return Ok(Message::Headers(HeadersData {
headers: h,
remaining,
}));
}
}
Attachment(left, meta, now) => {
let raw = self.buffer.split_to(next_len).freeze();
*left -= next_len;
if now.elapsed().as_secs() > 10 {
*now = Instant::now();
debug!("attachment: {}/{}", meta.size - *left, meta.size);
}
let update = AttachmentUpdate {
read: next_len,
left: *left,
meta: Arc::clone(meta),
};
if *left == 0 {
self.state = None;
debug!("attachment: DONE");
}
return Ok(Message::Attachment(update, Some(raw)));
}
}
}
}
/// Blocking read of the next message
pub fn read(&mut self) -> (Result<Message, Error>, u64) {
let msg = self.read_inner();
(msg, self.bytes_read as u64)
}
}
// TODO: replace with a macro?
fn decode_message(
header: &MsgHeader,
body: &mut Bytes,
version: ProtocolVersion,
) -> Result<Message, Error> {
let mut msg = BufReader::new(body, version);
let c = match header.msg_type {
Type::Ping => Message::Ping(msg.body()?),
Type::Pong => Message::Pong(msg.body()?),
Type::BanReason => Message::BanReason(msg.body()?),
Type::TransactionKernel => Message::TransactionKernel(msg.body()?),
Type::GetTransaction => Message::GetTransaction(msg.body()?),
Type::Transaction => Message::Transaction(msg.body()?),
Type::StemTransaction => Message::StemTransaction(msg.body()?),
Type::GetBlock => Message::GetBlock(msg.body()?),
Type::Block => Message::Block(msg.body()?),
Type::GetCompactBlock => Message::GetCompactBlock(msg.body()?),
Type::CompactBlock => Message::CompactBlock(msg.body()?),
Type::GetHeaders => Message::GetHeaders(msg.body()?),
Type::Header => Message::Header(msg.body()?),
Type::GetPeerAddrs => Message::GetPeerAddrs(msg.body()?),
Type::PeerAddrs => Message::PeerAddrs(msg.body()?),
Type::TxHashSetRequest => Message::TxHashSetRequest(msg.body()?),
Type::TxHashSetArchive => Message::TxHashSetArchive(msg.body()?),
Type::GetOutputBitmapSegment => Message::GetOutputBitmapSegment(msg.body()?),
Type::OutputBitmapSegment => Message::OutputBitmapSegment(msg.body()?),
Type::GetOutputSegment => Message::GetOutputSegment(msg.body()?),
Type::OutputSegment => Message::OutputSegment(msg.body()?),
Type::GetRangeProofSegment => Message::GetRangeProofSegment(msg.body()?),
Type::RangeProofSegment => Message::RangeProofSegment(msg.body()?),
Type::GetKernelSegment => Message::GetKernelSegment(msg.body()?),
Type::KernelSegment => Message::KernelSegment(msg.body()?),
Type::Error | Type::Hand | Type::Shake | Type::Headers => {
return Err(Error::UnexpectedMessage)
}
};
Ok(c)
}