// Copyright 2016 The Grin Developers // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! Provides a connection wrapper that handles the lower level tasks in sending or //! receiving data from the TCP socket, as well as dealing with timeouts. use std::iter; use std::sync::{Mutex, Arc}; use futures; use futures::{Stream, Future}; use futures::stream; use futures::sync::mpsc::{Sender, UnboundedSender, UnboundedReceiver}; use tokio_core::io::{Io, WriteHalf, ReadHalf, write_all, read_exact}; use tokio_core::net::TcpStream; use core::ser; use msg::*; /// Handler to provide to the connection, will be called back anytime a message is /// received. The provided sender can be use to immediately send back another /// message. pub trait Handler: Sync + Send { /// Handle function to implement to process incoming messages. A sender to reply /// immediately as well as the message header and its unparsed body are provided. fn handle(&self, sender: UnboundedSender>, header: MsgHeader, body: Vec) -> Result<(), ser::Error>; } impl Handler for F where F: Fn(UnboundedSender>, MsgHeader, Vec) -> Result<(), ser::Error>, F: Sync + Send { fn handle(&self, sender: UnboundedSender>, header: MsgHeader, body: Vec) -> Result<(), ser::Error> { self(sender, header, body) } } /// A higher level connection wrapping the TcpStream. Maintains the amount of data /// transmitted and deals with the low-level task of sending and receiving /// data, parsing message headers and timeouts. pub struct Connection { // Channel to push bytes to the remote peer outbound_chan: UnboundedSender>, // Close the connection with the remote peer close_chan: Sender<()>, // Bytes we've sent. sent_bytes: Arc>, // Bytes we've received. received_bytes: Arc>, // Counter for read errors. error_count: Mutex, } impl Connection { /// Start listening on the provided connection and wraps it. Does not hang the /// current thread, instead just returns a future and the Connection itself. pub fn listen(conn: TcpStream, handler: F) -> (Connection, Box>) where F: Handler + 'static { let (reader, writer) = conn.split(); // prepare the channel that will transmit data to the connection writer let (tx, rx) = futures::sync::mpsc::unbounded(); // same for closing the connection let (close_tx, close_rx) = futures::sync::mpsc::channel(1); let close_conn = close_rx.for_each(|_| Ok(())).map_err(|_| ser::Error::CorruptedData); let me = Connection { outbound_chan: tx.clone(), close_chan: close_tx, sent_bytes: Arc::new(Mutex::new(0)), received_bytes: Arc::new(Mutex::new(0)), error_count: Mutex::new(0), }; // setup the reading future, getting messages from the peer and processing them let read_msg = me.read_msg(tx, reader, handler).map(|_| ()); // setting the writing future, getting messages from our system and sending // them out let write_msg = me.write_msg(rx, writer).map(|_| ()); // select between our different futures and return them let fut = Box::new(close_conn.select(read_msg.select(write_msg).map(|_| ()).map_err(|(e, _)| e)) .map(|_| ()) .map_err(|(e, _)| e)); (me, fut) } /// Prepares the future that gets message data produced by our system and /// sends it to the peer connection fn write_msg(&self, rx: UnboundedReceiver>, writer: WriteHalf) -> Box, Error = ser::Error>> { let sent_bytes = self.sent_bytes.clone(); let send_data = rx.map(move |data| { // add the count of bytes sent let mut sent_bytes = sent_bytes.lock().unwrap(); *sent_bytes += data.len() as u64; data }) // write the data and make sure the future returns the right types .fold(writer, |writer, data| write_all(writer, data).map_err(|_| ()).map(|(writer, buf)| writer)) .map_err(|_| ser::Error::CorruptedData); Box::new(send_data) } /// Prepares the future reading from the peer connection, parsing each /// message and forwarding them appropriately based on their type fn read_msg(&self, sender: UnboundedSender>, reader: ReadHalf, handler: F) -> Box, Error = ser::Error>> where F: Handler + 'static { // infinite iterator stream so we repeat the message reading logic until the // peer is stopped let iter = stream::iter(iter::repeat(()).map(Ok::<(), ser::Error>)); // setup the reading future, getting messages from the peer and processing them let recv_bytes = self.received_bytes.clone(); let handler = Arc::new(handler); let read_msg = iter.fold(reader, move |reader, _| { let recv_bytes = recv_bytes.clone(); let handler = handler.clone(); let sender_inner = sender.clone(); // first read the message header read_exact(reader, vec![0u8; HEADER_LEN as usize]) .map_err(|e| ser::Error::IOErr(e)) .and_then(move |(reader, buf)| { let header = try!(ser::deserialize::(&mut &buf[..])); Ok((reader, header)) }) .and_then(move |(reader, header)| { // now that we have a size, proceed with the body read_exact(reader, vec![0u8; header.msg_len as usize]) .map(|(reader, buf)| (reader, header, buf)) .map_err(|e| ser::Error::IOErr(e)) }) .map(move |(reader, header, buf)| { // add the count of bytes received let mut recv_bytes = recv_bytes.lock().unwrap(); *recv_bytes += header.serialized_len() + header.msg_len; // and handle the different message types let msg_type = header.msg_type; if let Err(e) = handler.handle(sender_inner.clone(), header, buf) { debug!("Invalid {:?} message: {}", msg_type, e); } reader }) }); Box::new(read_msg) } /// Utility function to send any Writeable. Handles adding the header and /// serialization. pub fn send_msg(&self, t: Type, body: &ser::Writeable) -> Result<(), ser::Error> { let mut body_data = vec![]; try!(ser::serialize(&mut body_data, body)); let mut data = vec![]; try!(ser::serialize(&mut data, &MsgHeader::new(t, body_data.len() as u64))); data.append(&mut body_data); self.outbound_chan.send(data).map_err(|_| ser::Error::CorruptedData) } /// Bytes sent and received by this peer to the remote peer. pub fn transmitted_bytes(&self) -> (u64, u64) { let sent = *self.sent_bytes.lock().unwrap(); let recv = *self.received_bytes.lock().unwrap(); (sent, recv) } }