use sender thread consistently to send msgs to a peer (#3067)

This commit is contained in:
Antioch Peverell 2019-10-07 16:22:05 +01:00 committed by GitHub
parent 95e74c7b4b
commit a3f3fc25dc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 117 additions and 159 deletions

View file

@ -23,12 +23,11 @@
use crate::core::ser; use crate::core::ser;
use crate::core::ser::{FixedLength, ProtocolVersion}; use crate::core::ser::{FixedLength, ProtocolVersion};
use crate::msg::{ use crate::msg::{
read_body, read_discard, read_header, read_item, write_to_buf, MsgHeader, MsgHeaderWrapper, read_body, read_discard, read_header, read_item, write_message, Msg, MsgHeader,
Type, MsgHeaderWrapper,
}; };
use crate::types::Error; use crate::types::Error;
use crate::util::{RateCounter, RwLock}; use crate::util::{RateCounter, RwLock};
use std::fs::File;
use std::io::{self, Read, Write}; use std::io::{self, Read, Write};
use std::net::{Shutdown, TcpStream}; use std::net::{Shutdown, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
@ -44,12 +43,7 @@ const IO_TIMEOUT: Duration = Duration::from_millis(1000);
/// A trait to be implemented in order to receive messages from the /// A trait to be implemented in order to receive messages from the
/// connection. Allows providing an optional response. /// connection. Allows providing an optional response.
pub trait MessageHandler: Send + 'static { pub trait MessageHandler: Send + 'static {
fn consume<'a>( fn consume<'a>(&self, msg: Message<'a>, tracker: Arc<Tracker>) -> Result<Option<Msg>, Error>;
&self,
msg: Message<'a>,
writer: &'a mut dyn Write,
tracker: Arc<Tracker>,
) -> Result<Option<Response<'a>>, Error>;
} }
// Macro to simplify the boilerplate around async I/O error handling, // Macro to simplify the boilerplate around async I/O error handling,
@ -121,64 +115,6 @@ impl<'a> Message<'a> {
} }
} }
/// Response to a `Message`.
pub struct Response<'a> {
resp_type: Type,
body: Vec<u8>,
version: ProtocolVersion,
stream: &'a mut dyn Write,
attachment: Option<File>,
}
impl<'a> Response<'a> {
pub fn new<T: ser::Writeable>(
resp_type: Type,
version: ProtocolVersion,
body: T,
stream: &'a mut dyn Write,
) -> Result<Response<'a>, Error> {
let body = ser::ser_vec(&body, version)?;
Ok(Response {
resp_type,
body,
version,
stream,
attachment: None,
})
}
fn write(mut self, tracker: Arc<Tracker>) -> Result<(), Error> {
let mut msg = ser::ser_vec(
&MsgHeader::new(self.resp_type, self.body.len() as u64),
self.version,
)?;
msg.append(&mut self.body);
self.stream.write_all(&msg[..])?;
tracker.inc_sent(msg.len() as u64);
if let Some(mut file) = self.attachment {
let mut buf = [0u8; 8000];
loop {
match file.read(&mut buf[..]) {
Ok(0) => break,
Ok(n) => {
self.stream.write_all(&buf[..n])?;
// Increase sent bytes "quietly" without incrementing the counter.
// (In a loop here for the single attachment).
tracker.inc_quiet_sent(n as u64);
}
Err(e) => return Err(From::from(e)),
}
}
}
Ok(())
}
pub fn add_attachment(&mut self, file: File) {
self.attachment = Some(file);
}
}
pub const SEND_CHANNEL_CAP: usize = 100; pub const SEND_CHANNEL_CAP: usize = 100;
pub struct StopHandle { pub struct StopHandle {
@ -220,20 +156,16 @@ impl StopHandle {
} }
} }
#[derive(Clone)]
pub struct ConnHandle { pub struct ConnHandle {
/// Channel to allow sending data through the connection /// Channel to allow sending data through the connection
pub send_channel: mpsc::SyncSender<Vec<u8>>, pub send_channel: mpsc::SyncSender<Msg>,
} }
impl ConnHandle { impl ConnHandle {
pub fn send<T>(&self, body: T, msg_type: Type, version: ProtocolVersion) -> Result<u64, Error> pub fn send(&self, msg: Msg) -> Result<(), Error> {
where self.send_channel.try_send(msg)?;
T: ser::Writeable, Ok(())
{
let buf = write_to_buf(body, msg_type, version)?;
let buf_len = buf.len();
self.send_channel.try_send(buf)?;
Ok(buf_len as u64)
} }
} }
@ -294,13 +226,22 @@ where
let stopped = Arc::new(AtomicBool::new(false)); let stopped = Arc::new(AtomicBool::new(false));
let (reader_thread, writer_thread) = let conn_handle = ConnHandle {
poll(stream, version, handler, send_rx, stopped.clone(), tracker)?; send_channel: send_tx,
};
let (reader_thread, writer_thread) = poll(
stream,
conn_handle.clone(),
version,
handler,
send_rx,
stopped.clone(),
tracker,
)?;
Ok(( Ok((
ConnHandle { conn_handle,
send_channel: send_tx,
},
StopHandle { StopHandle {
stopped, stopped,
reader_thread: Some(reader_thread), reader_thread: Some(reader_thread),
@ -311,9 +252,10 @@ where
fn poll<H>( fn poll<H>(
conn: TcpStream, conn: TcpStream,
conn_handle: ConnHandle,
version: ProtocolVersion, version: ProtocolVersion,
handler: H, handler: H,
send_rx: mpsc::Receiver<Vec<u8>>, send_rx: mpsc::Receiver<Msg>,
stopped: Arc<AtomicBool>, stopped: Arc<AtomicBool>,
tracker: Arc<Tracker>, tracker: Arc<Tracker>,
) -> io::Result<(JoinHandle<()>, JoinHandle<()>)> ) -> io::Result<(JoinHandle<()>, JoinHandle<()>)>
@ -323,9 +265,11 @@ where
// Split out tcp stream out into separate reader/writer halves. // Split out tcp stream out into separate reader/writer halves.
let mut reader = conn.try_clone().expect("clone conn for reader failed"); let mut reader = conn.try_clone().expect("clone conn for reader failed");
let mut writer = conn.try_clone().expect("clone conn for writer failed"); let mut writer = conn.try_clone().expect("clone conn for writer failed");
let mut responder = conn.try_clone().expect("clone conn for writer failed");
let reader_stopped = stopped.clone(); let reader_stopped = stopped.clone();
let reader_tracker = tracker.clone();
let writer_tracker = tracker.clone();
let reader_thread = thread::Builder::new() let reader_thread = thread::Builder::new()
.name("peer_read".to_string()) .name("peer_read".to_string())
.spawn(move || { .spawn(move || {
@ -342,17 +286,16 @@ where
); );
// Increase received bytes counter // Increase received bytes counter
tracker.inc_received(MsgHeader::LEN as u64 + msg.header.msg_len); reader_tracker.inc_received(MsgHeader::LEN as u64 + msg.header.msg_len);
if let Some(Some(resp)) = let resp_msg = try_break!(handler.consume(msg, reader_tracker.clone()));
try_break!(handler.consume(msg, &mut responder, tracker.clone())) if let Some(Some(resp_msg)) = resp_msg {
{ try_break!(conn_handle.send(resp_msg));
try_break!(resp.write(tracker.clone()));
} }
} }
Some(MsgHeaderWrapper::Unknown(msg_len)) => { Some(MsgHeaderWrapper::Unknown(msg_len)) => {
// Increase received bytes counter // Increase received bytes counter
tracker.inc_received(MsgHeader::LEN as u64 + msg_len); reader_tracker.inc_received(MsgHeader::LEN as u64 + msg_len);
try_break!(read_discard(msg_len, &mut reader)); try_break!(read_discard(msg_len, &mut reader));
} }
@ -383,7 +326,8 @@ where
let maybe_data = retry_send.or_else(|_| send_rx.recv_timeout(IO_TIMEOUT)); let maybe_data = retry_send.or_else(|_| send_rx.recv_timeout(IO_TIMEOUT));
retry_send = Err(()); retry_send = Err(());
if let Ok(data) = maybe_data { if let Ok(data) = maybe_data {
let written = try_break!(writer.write_all(&data[..]).map_err(&From::from)); let written =
try_break!(write_message(&mut writer, &data, writer_tracker.clone()));
if written.is_none() { if written.is_none() {
retry_send = Ok(data); retry_send = Ok(data);
} }

View file

@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::conn::Tracker;
use crate::core::core::hash::Hash; use crate::core::core::hash::Hash;
use crate::core::pow::Difficulty; use crate::core::pow::Difficulty;
use crate::core::ser::ProtocolVersion; use crate::core::ser::ProtocolVersion;
use crate::msg::{read_message, write_message, Hand, Shake, Type, USER_AGENT}; use crate::msg::{read_message, write_message, Hand, Msg, Shake, Type, USER_AGENT};
use crate::peer::Peer; use crate::peer::Peer;
use crate::types::{Capabilities, Direction, Error, P2PConfig, PeerAddr, PeerInfo, PeerLiveInfo}; use crate::types::{Capabilities, Direction, Error, P2PConfig, PeerAddr, PeerInfo, PeerLiveInfo};
use crate::util::RwLock; use crate::util::RwLock;
@ -47,6 +48,7 @@ pub struct Handshake {
genesis: Hash, genesis: Hash,
config: P2PConfig, config: P2PConfig,
protocol_version: ProtocolVersion, protocol_version: ProtocolVersion,
tracker: Arc<Tracker>,
} }
impl Handshake { impl Handshake {
@ -58,6 +60,7 @@ impl Handshake {
genesis, genesis,
config, config,
protocol_version: ProtocolVersion::local(), protocol_version: ProtocolVersion::local(),
tracker: Arc::new(Tracker::new()),
} }
} }
@ -99,7 +102,8 @@ impl Handshake {
}; };
// write and read the handshake response // write and read the handshake response
write_message(conn, hand, Type::Hand, self.protocol_version)?; let msg = Msg::new(Type::Hand, hand, self.protocol_version)?;
write_message(conn, &msg, self.tracker.clone())?;
let shake: Shake = read_message(conn, self.protocol_version, Type::Shake)?; let shake: Shake = read_message(conn, self.protocol_version, Type::Shake)?;
if shake.genesis != self.genesis { if shake.genesis != self.genesis {
@ -196,7 +200,9 @@ impl Handshake {
user_agent: USER_AGENT.to_string(), user_agent: USER_AGENT.to_string(),
}; };
write_message(conn, shake, Type::Shake, negotiated_version)?; let msg = Msg::new(Type::Shake, shake, negotiated_version)?;
write_message(conn, &msg, self.tracker.clone())?;
trace!("Success handshake with {}.", peer_info.addr); trace!("Success handshake with {}.", peer_info.addr);
Ok(peer_info) Ok(peer_info)

View file

@ -14,6 +14,7 @@
//! Message types that transit over the network and related serialization code. //! Message types that transit over the network and related serialization code.
use crate::conn::Tracker;
use crate::core::core::hash::Hash; use crate::core::core::hash::Hash;
use crate::core::core::BlockHeader; use crate::core::core::BlockHeader;
use crate::core::pow::Difficulty; use crate::core::pow::Difficulty;
@ -25,7 +26,9 @@ use crate::types::{
Capabilities, Error, PeerAddr, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS, Capabilities, Error, PeerAddr, ReasonForBan, MAX_BLOCK_HEADERS, MAX_LOCATORS, MAX_PEER_ADDRS,
}; };
use num::FromPrimitive; use num::FromPrimitive;
use std::fs::File;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::sync::Arc;
/// Grin's user agent with current version /// Grin's user agent with current version
pub const USER_AGENT: &'static str = concat!("MW/Grin ", env!("CARGO_PKG_VERSION")); pub const USER_AGENT: &'static str = concat!("MW/Grin ", env!("CARGO_PKG_VERSION"));
@ -114,6 +117,33 @@ fn magic() -> [u8; 2] {
} }
} }
pub struct Msg {
header: MsgHeader,
body: Vec<u8>,
attachment: Option<File>,
version: ProtocolVersion,
}
impl Msg {
pub fn new<T: Writeable>(
msg_type: Type,
msg: T,
version: ProtocolVersion,
) -> Result<Msg, Error> {
let body = ser::ser_vec(&msg, version)?;
Ok(Msg {
header: MsgHeader::new(msg_type, body.len() as u64),
body,
attachment: None,
version,
})
}
pub fn add_attachment(&mut self, attachment: File) {
self.attachment = Some(attachment)
}
}
/// Read a header from the provided stream without blocking if the /// Read a header from the provided stream without blocking if the
/// underlying stream is async. Typically headers will be polled for, so /// underlying stream is async. Typically headers will be polled for, so
/// we do not want to block. /// we do not want to block.
@ -182,32 +212,31 @@ pub fn read_message<T: Readable>(
} }
} }
pub fn write_to_buf<T: Writeable>( pub fn write_message(
msg: T,
msg_type: Type,
version: ProtocolVersion,
) -> Result<Vec<u8>, Error> {
// prepare the body first so we know its serialized length
let mut body_buf = vec![];
ser::serialize(&mut body_buf, version, &msg)?;
// build and serialize the header using the body size
let mut msg_buf = vec![];
let blen = body_buf.len() as u64;
ser::serialize(&mut msg_buf, version, &MsgHeader::new(msg_type, blen))?;
msg_buf.append(&mut body_buf);
Ok(msg_buf)
}
pub fn write_message<T: Writeable>(
stream: &mut dyn Write, stream: &mut dyn Write,
msg: T, msg: &Msg,
msg_type: Type, tracker: Arc<Tracker>,
version: ProtocolVersion,
) -> Result<(), Error> { ) -> Result<(), Error> {
let buf = write_to_buf(msg, msg_type, version)?; let mut buf = ser::ser_vec(&msg.header, msg.version)?;
buf.extend(&msg.body[..]);
stream.write_all(&buf[..])?; stream.write_all(&buf[..])?;
tracker.inc_sent(buf.len() as u64);
if let Some(file) = &msg.attachment {
let mut file = file.try_clone()?;
let mut buf = [0u8; 8000];
loop {
match file.read(&mut buf[..]) {
Ok(0) => break,
Ok(n) => {
stream.write_all(&buf[..n])?;
// Increase sent bytes "quietly" without incrementing the counter.
// (In a loop here for the single attachment).
tracker.inc_quiet_sent(n as u64);
}
Err(e) => return Err(From::from(e)),
}
}
}
Ok(()) Ok(())
} }

View file

@ -29,7 +29,7 @@ use crate::core::ser::Writeable;
use crate::core::{core, global}; use crate::core::{core, global};
use crate::handshake::Handshake; use crate::handshake::Handshake;
use crate::msg::{ use crate::msg::{
self, BanReason, GetPeerAddrs, KernelDataRequest, Locator, Ping, TxHashSetRequest, Type, self, BanReason, GetPeerAddrs, KernelDataRequest, Locator, Msg, Ping, TxHashSetRequest, Type,
}; };
use crate::protocol::Protocol; use crate::protocol::Protocol;
use crate::types::{ use crate::types::{
@ -233,12 +233,8 @@ impl Peer {
/// Send a msg with given msg_type to our peer via the connection. /// Send a msg with given msg_type to our peer via the connection.
fn send<T: Writeable>(&self, msg: T, msg_type: Type) -> Result<(), Error> { fn send<T: Writeable>(&self, msg: T, msg_type: Type) -> Result<(), Error> {
let bytes = self let msg = Msg::new(msg_type, msg, self.info.version)?;
.send_handle self.send_handle.lock().send(msg)
.lock()
.send(msg, msg_type, self.info.version)?;
self.tracker.inc_sent(bytes);
Ok(())
} }
/// Send a ping to the remote peer, providing our local difficulty and /// Send a ping to the remote peer, providing our local difficulty and

View file

@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use crate::conn::{Message, MessageHandler, Response, Tracker}; use crate::conn::{Message, MessageHandler, Tracker};
use crate::core::core::{self, hash::Hash, hash::Hashed, CompactBlock}; use crate::core::core::{self, hash::Hash, hash::Hashed, CompactBlock};
use crate::msg::{ use crate::msg::{
BanReason, GetPeerAddrs, Headers, KernelDataResponse, Locator, PeerAddrs, Ping, Pong, BanReason, GetPeerAddrs, Headers, KernelDataResponse, Locator, Msg, PeerAddrs, Ping, Pong,
TxHashSetArchive, TxHashSetRequest, Type, TxHashSetArchive, TxHashSetRequest, Type,
}; };
use crate::types::{Error, NetAdapter, PeerInfo}; use crate::types::{Error, NetAdapter, PeerInfo};
@ -24,7 +24,7 @@ use chrono::prelude::Utc;
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use std::cmp; use std::cmp;
use std::fs::{self, File, OpenOptions}; use std::fs::{self, File, OpenOptions};
use std::io::{BufWriter, Seek, SeekFrom, Write}; use std::io::{BufWriter, Seek, SeekFrom};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@ -51,12 +51,7 @@ impl Protocol {
} }
impl MessageHandler for Protocol { impl MessageHandler for Protocol {
fn consume<'a>( fn consume(&self, mut msg: Message, tracker: Arc<Tracker>) -> Result<Option<Msg>, Error> {
&self,
mut msg: Message<'a>,
writer: &'a mut dyn Write,
tracker: Arc<Tracker>,
) -> Result<Option<Response<'a>>, Error> {
let adapter = &self.adapter; let adapter = &self.adapter;
// If we received a msg from a banned peer then log and drop it. // If we received a msg from a banned peer then log and drop it.
@ -75,14 +70,13 @@ impl MessageHandler for Protocol {
let ping: Ping = msg.body()?; let ping: Ping = msg.body()?;
adapter.peer_difficulty(self.peer_info.addr, ping.total_difficulty, ping.height); adapter.peer_difficulty(self.peer_info.addr, ping.total_difficulty, ping.height);
Ok(Some(Response::new( Ok(Some(Msg::new(
Type::Pong, Type::Pong,
self.peer_info.version,
Pong { Pong {
total_difficulty: adapter.total_difficulty()?, total_difficulty: adapter.total_difficulty()?,
height: adapter.total_height()?, height: adapter.total_height()?,
}, },
writer, self.peer_info.version,
)?)) )?))
} }
@ -116,11 +110,10 @@ impl MessageHandler for Protocol {
); );
let tx = adapter.get_transaction(h); let tx = adapter.get_transaction(h);
if let Some(tx) = tx { if let Some(tx) = tx {
Ok(Some(Response::new( Ok(Some(Msg::new(
Type::Transaction, Type::Transaction,
self.peer_info.version,
tx, tx,
writer, self.peer_info.version,
)?)) )?))
} else { } else {
Ok(None) Ok(None)
@ -157,12 +150,7 @@ impl MessageHandler for Protocol {
let bo = adapter.get_block(h); let bo = adapter.get_block(h);
if let Some(b) = bo { if let Some(b) = bo {
return Ok(Some(Response::new( return Ok(Some(Msg::new(Type::Block, b, self.peer_info.version)?));
Type::Block,
self.peer_info.version,
b,
writer,
)?));
} }
Ok(None) Ok(None)
} }
@ -184,11 +172,10 @@ impl MessageHandler for Protocol {
let h: Hash = msg.body()?; let h: Hash = msg.body()?;
if let Some(b) = adapter.get_block(h) { if let Some(b) = adapter.get_block(h) {
let cb: CompactBlock = b.into(); let cb: CompactBlock = b.into();
Ok(Some(Response::new( Ok(Some(Msg::new(
Type::CompactBlock, Type::CompactBlock,
self.peer_info.version,
cb, cb,
writer, self.peer_info.version,
)?)) )?))
} else { } else {
Ok(None) Ok(None)
@ -212,11 +199,10 @@ impl MessageHandler for Protocol {
let headers = adapter.locate_headers(&loc.hashes)?; let headers = adapter.locate_headers(&loc.hashes)?;
// serialize and send all the headers over // serialize and send all the headers over
Ok(Some(Response::new( Ok(Some(Msg::new(
Type::Headers, Type::Headers,
self.peer_info.version,
Headers { headers }, Headers { headers },
writer, self.peer_info.version,
)?)) )?))
} }
@ -258,11 +244,10 @@ impl MessageHandler for Protocol {
Type::GetPeerAddrs => { Type::GetPeerAddrs => {
let get_peers: GetPeerAddrs = msg.body()?; let get_peers: GetPeerAddrs = msg.body()?;
let peers = adapter.find_peer_addrs(get_peers.capabilities); let peers = adapter.find_peer_addrs(get_peers.capabilities);
Ok(Some(Response::new( Ok(Some(Msg::new(
Type::PeerAddrs, Type::PeerAddrs,
self.peer_info.version,
PeerAddrs { peers }, PeerAddrs { peers },
writer, self.peer_info.version,
)?)) )?))
} }
@ -277,11 +262,10 @@ impl MessageHandler for Protocol {
let kernel_data = self.adapter.kernel_data_read()?; let kernel_data = self.adapter.kernel_data_read()?;
let bytes = kernel_data.metadata()?.len(); let bytes = kernel_data.metadata()?.len();
let kernel_data_response = KernelDataResponse { bytes }; let kernel_data_response = KernelDataResponse { bytes };
let mut response = Response::new( let mut response = Msg::new(
Type::KernelDataResponse, Type::KernelDataResponse,
self.peer_info.version,
&kernel_data_response, &kernel_data_response,
writer, self.peer_info.version,
)?; )?;
response.add_attachment(kernel_data); response.add_attachment(kernel_data);
Ok(Some(response)) Ok(Some(response))
@ -337,15 +321,14 @@ impl MessageHandler for Protocol {
if let Some(txhashset) = txhashset { if let Some(txhashset) = txhashset {
let file_sz = txhashset.reader.metadata()?.len(); let file_sz = txhashset.reader.metadata()?.len();
let mut resp = Response::new( let mut resp = Msg::new(
Type::TxHashSetArchive, Type::TxHashSetArchive,
self.peer_info.version,
&TxHashSetArchive { &TxHashSetArchive {
height: txhashset_header.height as u64, height: txhashset_header.height as u64,
hash: txhashset_header_hash, hash: txhashset_header_hash,
bytes: file_sz, bytes: file_sz,
}, },
writer, self.peer_info.version,
)?; )?;
resp.add_attachment(txhashset.reader); resp.add_attachment(txhashset.reader);
Ok(Some(resp)) Ok(Some(resp))