diff --git a/p2p/src/conn.rs b/p2p/src/conn.rs index 63c041eb7..3157669c9 100644 --- a/p2p/src/conn.rs +++ b/p2p/src/conn.rs @@ -39,6 +39,7 @@ pub trait MessageHandler: Send + 'static { fn consume<'a>( &self, msg: Message<'a>, + writer: &'a mut Write, received_bytes: Arc>, ) -> Result>, Error>; } @@ -62,23 +63,23 @@ macro_rules! try_break { /// header lazily consumes the message body, handling its deserialization. pub struct Message<'a> { pub header: MsgHeader, - conn: &'a mut TcpStream, + stream: &'a mut Read, } impl<'a> Message<'a> { - fn from_header(header: MsgHeader, conn: &'a mut TcpStream) -> Message<'a> { - Message { header, conn } + fn from_header(header: MsgHeader, stream: &'a mut Read) -> Message<'a> { + Message { header, stream } } /// Read the message body from the underlying connection pub fn body(&mut self) -> Result { - read_body(&self.header, self.conn) + read_body(&self.header, self.stream) } /// Read a single "thing" from the underlying connection. /// Return the thing and the total bytes read. pub fn streaming_read(&mut self) -> Result<(T, u64), Error> { - read_item(self.conn) + read_item(self.stream) } pub fn copy_attachment(&mut self, len: usize, writer: &mut Write) -> Result { @@ -87,7 +88,7 @@ impl<'a> Message<'a> { let read_len = cmp::min(8000, len - written); let mut buf = vec![0u8; read_len]; read_exact( - &mut self.conn, + &mut self.stream, &mut buf[..], time::Duration::from_secs(10), true, @@ -97,36 +98,32 @@ impl<'a> Message<'a> { } Ok(written) } - - /// Respond to the message with the provided message type and body - pub fn respond(self, resp_type: Type, body: T) -> Response<'a> - where - T: ser::Writeable, - { - let body = ser::ser_vec(&body).unwrap(); - Response { - resp_type: resp_type, - body: body, - conn: self.conn, - attachment: None, - } - } } -/// Response to a `Message` +/// Response to a `Message`. pub struct Response<'a> { resp_type: Type, body: Vec, - conn: &'a mut TcpStream, + stream: &'a mut Write, attachment: Option, } impl<'a> Response<'a> { + pub fn new(resp_type: Type, body: T, stream: &'a mut Write) -> Response<'a> { + let body = ser::ser_vec(&body).unwrap(); + Response { + resp_type, + body, + stream, + attachment: None, + } + } + fn write(mut self, sent_bytes: Arc>) -> Result<(), Error> { let mut msg = ser::ser_vec(&MsgHeader::new(self.resp_type, self.body.len() as u64)).unwrap(); msg.append(&mut self.body); - write_all(&mut self.conn, &msg[..], time::Duration::from_secs(10))?; + write_all(&mut self.stream, &msg[..], time::Duration::from_secs(10))?; // Increase sent bytes counter { let mut sent_bytes = sent_bytes.write(); @@ -138,7 +135,7 @@ impl<'a> Response<'a> { match file.read(&mut buf[..]) { Ok(0) => break, Ok(n) => { - write_all(&mut self.conn, &buf[..n], time::Duration::from_secs(10))?; + write_all(&mut self.stream, &buf[..n], time::Duration::from_secs(10))?; // Increase sent bytes "quietly" without incrementing the counter. // (In a loop here for the single attachment). let mut sent_bytes = sent_bytes.write(); @@ -237,18 +234,20 @@ fn poll( ) where H: MessageHandler, { - let mut conn = conn; + // Split out tcp stream out into separate reader/writer halves. + let mut reader = conn.try_clone().expect("clone conn for reader failed"); + let mut writer = conn.try_clone().expect("clone conn for writer failed"); + let _ = thread::Builder::new() .name("peer".to_string()) .spawn(move || { let sleep_time = time::Duration::from_millis(1); - - let conn = &mut conn; let mut retry_send = Err(()); loop { // check the read end - if let Some(h) = try_break!(error_tx, read_header(conn, None)) { - let msg = Message::from_header(h, conn); + if let Some(h) = try_break!(error_tx, read_header(&mut reader, None)) { + let msg = Message::from_header(h, &mut reader); + trace!( "Received message header, type {:?}, len {}.", msg.header.msg_type, @@ -262,7 +261,9 @@ fn poll( received_bytes.inc(MsgHeader::LEN as u64 + msg.header.msg_len); } - if let Some(Some(resp)) = try_break!(error_tx, handler.consume(msg, received)) { + if let Some(Some(resp)) = + try_break!(error_tx, handler.consume(msg, &mut writer, received)) + { try_break!(error_tx, resp.write(sent_bytes.clone())); } } @@ -272,7 +273,7 @@ fn poll( retry_send = Err(()); if let Ok(data) = maybe_data { let written = - try_break!(error_tx, conn.write_all(&data[..]).map_err(&From::from)); + try_break!(error_tx, writer.write_all(&data[..]).map_err(&From::from)); if written.is_none() { retry_send = Ok(data); } diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index 950bbc5c5..de9b7174a 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -15,7 +15,7 @@ use std::cmp; use std::env; use std::fs::File; -use std::io::BufWriter; +use std::io::{BufWriter, Write}; use std::net::SocketAddr; use std::sync::Arc; @@ -45,6 +45,7 @@ impl MessageHandler for Protocol { fn consume<'a>( &self, mut msg: Message<'a>, + writer: &'a mut Write, received_bytes: Arc>, ) -> Result>, Error> { let adapter = &self.adapter; @@ -65,12 +66,13 @@ impl MessageHandler for Protocol { let ping: Ping = msg.body()?; adapter.peer_difficulty(self.addr, ping.total_difficulty, ping.height); - Ok(Some(msg.respond( + Ok(Some(Response::new( Type::Pong, Pong { total_difficulty: adapter.total_difficulty(), height: adapter.total_height(), }, + writer, ))) } @@ -104,7 +106,7 @@ impl MessageHandler for Protocol { ); let tx = adapter.get_transaction(h); if let Some(tx) = tx { - Ok(Some(msg.respond(Type::Transaction, tx))) + Ok(Some(Response::new(Type::Transaction, tx, writer))) } else { Ok(None) } @@ -140,7 +142,7 @@ impl MessageHandler for Protocol { let bo = adapter.get_block(h); if let Some(b) = bo { - return Ok(Some(msg.respond(Type::Block, b))); + return Ok(Some(Response::new(Type::Block, b, writer))); } Ok(None) } @@ -160,7 +162,7 @@ impl MessageHandler for Protocol { let h: Hash = msg.body()?; if let Some(b) = adapter.get_block(h) { let cb: CompactBlock = b.into(); - Ok(Some(msg.respond(Type::CompactBlock, cb))) + Ok(Some(Response::new(Type::CompactBlock, cb, writer))) } else { Ok(None) } @@ -183,9 +185,11 @@ impl MessageHandler for Protocol { let headers = adapter.locate_headers(loc.hashes); // serialize and send all the headers over - Ok(Some( - msg.respond(Type::Headers, Headers { headers: headers }), - )) + Ok(Some(Response::new( + Type::Headers, + Headers { headers }, + writer, + ))) } // "header first" block propagation - if we have not yet seen this block @@ -226,11 +230,12 @@ impl MessageHandler for Protocol { Type::GetPeerAddrs => { let get_peers: GetPeerAddrs = msg.body()?; let peer_addrs = adapter.find_peer_addrs(get_peers.capabilities); - Ok(Some(msg.respond( + Ok(Some(Response::new( Type::PeerAddrs, PeerAddrs { peers: peer_addrs.iter().map(|sa| SockAddr(*sa)).collect(), }, + writer, ))) } @@ -251,13 +256,14 @@ impl MessageHandler for Protocol { if let Some(txhashset) = txhashset { let file_sz = txhashset.reader.metadata()?.len(); - let mut resp = msg.respond( + let mut resp = Response::new( Type::TxHashSetArchive, &TxHashSetArchive { height: sm_req.height as u64, hash: sm_req.hash, bytes: file_sz, }, + writer, ); resp.add_attachment(txhashset.reader); Ok(Some(resp))