diff --git a/core/src/ser.rs b/core/src/ser.rs index 4a4648f09..34e07d271 100644 --- a/core/src/ser.rs +++ b/core/src/ser.rs @@ -25,6 +25,7 @@ use core::hash::{Hash, Hashed}; use keychain::{BlindingFactor, Identifier, IDENTIFIER_SIZE}; use std::fmt::Debug; use std::io::{self, Read, Write}; +use std::marker; use std::{cmp, error, fmt}; use util::secp::constants::{ AGG_SIGNATURE_SIZE, MAX_PROOF_SIZE, PEDERSEN_COMMITMENT_SIZE, SECRET_KEY_SIZE, @@ -46,6 +47,8 @@ pub enum Error { }, /// Data wasn't in a consumable format CorruptedData, + /// Incorrect number of elements (when deserializing a vec via read_multi say). + CountError, /// When asked to read too much data TooLargeReadErr, /// Consensus rule failure (currently sort order) @@ -75,6 +78,7 @@ impl fmt::Display for Error { received: ref r, } => write!(f, "expected {:?}, got {:?}", e, r), Error::CorruptedData => f.write_str("corrupted data"), + Error::CountError => f.write_str("count error"), Error::TooLargeReadErr => f.write_str("too large read"), Error::ConsensusError(ref e) => write!(f, "consensus error {:?}", e), Error::HexError(ref e) => write!(f, "hex error {:?}", e), @@ -95,6 +99,7 @@ impl error::Error for Error { Error::IOErr(ref e, _) => e, Error::UnexpectedData { .. } => "unexpected data", Error::CorruptedData => "corrupted data", + Error::CountError => "count error", Error::TooLargeReadErr => "too large read", Error::ConsensusError(_) => "consensus error (sort order)", Error::HexError(_) => "hex error", @@ -201,10 +206,44 @@ pub trait Writeable { fn write(&self, writer: &mut W) -> Result<(), Error>; } +struct IteratingReader<'a, T> { + count: u64, + curr: u64, + reader: &'a mut Reader, + _marker: marker::PhantomData, +} + +impl<'a, T> IteratingReader<'a, T> { + fn new(reader: &'a mut Reader, count: u64) -> IteratingReader<'a, T> { + let curr = 0; + IteratingReader { + count, + curr, + reader, + _marker: marker::PhantomData, + } + } +} + +impl<'a, T> Iterator for IteratingReader<'a, T> +where + T: Readable, +{ + type Item = T; + + fn next(&mut self) -> Option { + if self.curr >= self.count { + return None; + } + self.curr += 1; + T::read(self.reader).ok() + } +} + /// Reads multiple serialized items into a Vec. pub fn read_multi(reader: &mut Reader, count: u64) -> Result, Error> where - T: Readable + Hashed + Writeable, + T: Readable, { // Very rudimentary check to ensure we do not overflow anything // attempting to read huge amounts of data. @@ -213,8 +252,11 @@ where return Err(Error::TooLargeReadErr); } - let result: Vec = try!((0..count).map(|_| T::read(reader)).collect()); - Ok(result) + let res: Vec = IteratingReader::new(reader, count).collect(); + if res.len() as u64 != count { + return Err(Error::CountError); + } + Ok(res) } /// Trait that every type that can be deserialized from binary must implement.