diff --git a/core/src/core/sumtree.rs b/core/src/core/sumtree.rs index 1f87215b6..2f13a1bd0 100644 --- a/core/src/core/sumtree.rs +++ b/core/src/core/sumtree.rs @@ -71,46 +71,46 @@ impl Summable for NoSum { } #[derive(Clone)] -enum NodeData { +enum Node { /// Node with 2^n children which are not stored with the tree Pruned(T::Sum), /// Actual data - Leaf(T), + Leaf(T::Sum), /// Node with 2^n children Internal { - lchild: Box>, - rchild: Box>, + lchild: Box>, + rchild: Box>, sum: T::Sum, }, } -impl Summable for NodeData { - type Sum = T::Sum; - fn sum(&self) -> T::Sum { - match *self { - NodeData::Pruned(ref sum) => sum.clone(), - NodeData::Leaf(ref data) => data.sum(), - NodeData::Internal { ref sum, .. } => sum.clone(), - } - } -} - -#[derive(Clone)] -struct Node { - full: bool, - data: NodeData, - hash: Hash, - depth: u8, -} - impl Summable for Node { type Sum = T::Sum; fn sum(&self) -> T::Sum { - self.data.sum() + match *self { + Node::Pruned(ref sum) => sum.clone(), + Node::Leaf(ref sum) => sum.clone(), + Node::Internal { ref sum, .. } => sum.clone(), + } } } -impl Node { +#[derive(Clone)] +struct NodeData { + full: bool, + node: Node, + hash: Hash, + depth: u8, +} + +impl Summable for NodeData { + type Sum = T::Sum; + fn sum(&self) -> T::Sum { + self.node.sum() + } +} + +impl NodeData { /// Get the root hash and sum of the node fn root_sum(&self) -> (Hash, T::Sum) { (self.hash, self.sum()) @@ -120,11 +120,11 @@ impl Node { if self.full { 1 << self.depth } else { - if let NodeData::Internal { + if let Node::Internal { ref lchild, ref rchild, .. - } = self.data + } = self.node { lchild.n_children() + rchild.n_children() } else { @@ -140,7 +140,7 @@ pub struct SumTree { /// Index mapping data to its index in the tree index: HashMap, /// Tree contents - root: Option>, + root: Option>, } impl SumTree @@ -160,7 +160,7 @@ where self.root.as_ref().map(|node| node.root_sum()) } - fn insert_right_of(mut old: Node, new: Node) -> Node { + fn insert_right_of(mut old: NodeData, new: NodeData) -> NodeData { assert!(old.depth >= new.depth); // If we are inserting next to a full node, make a parent. If we're @@ -172,15 +172,15 @@ where let parent_sum = old.sum() + new.sum(); let parent_hash = (parent_depth, &parent_sum, old.hash, new.hash).hash(); let parent_full = old.depth == new.depth; - let parent_data = NodeData::Internal { + let parent_node = Node::Internal { lchild: Box::new(old), rchild: Box::new(new), sum: parent_sum, }; - Node { + NodeData { full: parent_full, - data: parent_data, + node: parent_node, hash: parent_hash, depth: parent_depth, } @@ -188,16 +188,16 @@ where // inserting under the node, so we recurse. The right child of a partial // node is always another partial node or a leaf. } else { - if let NodeData::Internal { + if let Node::Internal { ref lchild, ref mut rchild, ref mut sum, - } = old.data + } = old.node { // Recurse - let dummy_child = Node { + let dummy_child = NodeData { full: true, - data: NodeData::Pruned(sum.clone()), + node: Node::Pruned(sum.clone()), hash: old.hash, depth: 0, }; @@ -243,9 +243,9 @@ where // Special-case the first element if self.root.is_none() { - self.root = Some(Node { + self.root = Some(NodeData { full: true, - data: NodeData::Leaf(elem), + node: Node::Leaf(elem_sum), hash: elem_hash, depth: 0, }); @@ -258,9 +258,9 @@ where let old_root = mem::replace(&mut self.root, None).unwrap(); // Insert into tree, compute new root - let new_node = Node { + let new_node = NodeData { full: true, - data: NodeData::Leaf(elem), + node: Node::Leaf(elem_sum), hash: elem_hash, depth: 0, }; @@ -272,16 +272,16 @@ where true } - fn replace_recurse(node: &mut Node, index: usize, new_elem: T) { + fn replace_recurse(node: &mut NodeData, index: usize, new_elem: T) { assert!(index < (1 << node.depth)); if node.depth == 0 { assert!(node.full); node.hash = (0u8, new_elem.sum(), Hashed::hash(&new_elem)).hash(); - node.data = NodeData::Leaf(new_elem); + node.node = Node::Leaf(new_elem.sum()); } else { - match node.data { - NodeData::Internal { + match node.node { + Node::Internal { ref mut lchild, ref mut rchild, ref mut sum, @@ -296,8 +296,8 @@ where node.hash = (node.depth, &*sum, lchild.hash, rchild.hash).hash(); } // Pruned data would not have been in the index - NodeData::Pruned(_) => unreachable!(), - NodeData::Leaf(_) => unreachable!(), + Node::Pruned(_) => unreachable!(), + Node::Leaf(_) => unreachable!(), } } } @@ -337,20 +337,20 @@ where self.index.get(&index_hash).map(|x| *x) } - fn prune_recurse(node: &mut Node, index: usize) { + fn prune_recurse(node: &mut NodeData, index: usize) { assert!(index < (1 << node.depth)); if node.depth == 0 { - let sum = if let NodeData::Leaf(ref elem) = node.data { - elem.sum() + let sum = if let Node::Leaf(ref sum) = node.node { + sum.clone() } else { unreachable!() }; - node.data = NodeData::Pruned(sum); + node.node = Node::Pruned(sum); } else { let mut prune_me = None; - match node.data { - NodeData::Internal { + match node.node { + Node::Internal { ref mut lchild, ref mut rchild, .. @@ -361,21 +361,21 @@ where } else { SumTree::prune_recurse(lchild, index); } - if let (&NodeData::Pruned(ref lsum), &NodeData::Pruned(ref rsum)) = - (&lchild.data, &rchild.data) + if let (&Node::Pruned(ref lsum), &Node::Pruned(ref rsum)) = + (&lchild.node, &rchild.node) { if node.full { prune_me = Some(lsum.clone() + rsum.clone()); } } } - NodeData::Pruned(_) => { + Node::Pruned(_) => { // Already pruned. Ok. } - NodeData::Leaf(_) => unreachable!(), + Node::Leaf(_) => unreachable!(), } if let Some(sum) = prune_me { - node.data = NodeData::Pruned(sum); + node.node = Node::Pruned(sum); } } } @@ -429,7 +429,7 @@ where } } -impl Writeable for Node +impl Writeable for NodeData where T: Summable + Writeable, { @@ -441,7 +441,7 @@ where if self.full { depth |= 0x80; } - if let NodeData::Pruned(_) = self.data { + if let Node::Pruned(_) = self.node { } else { depth |= 0xc0; } @@ -449,10 +449,10 @@ where // Encode node try!(writer.write_u8(depth)); try!(self.hash.write(writer)); - match self.data { - NodeData::Pruned(ref sum) => sum.write(writer), - NodeData::Leaf(ref data) => data.write(writer), - NodeData::Internal { + match self.node { + Node::Pruned(ref sum) => sum.write(writer), + Node::Leaf(ref sum) => sum.write(writer), + Node::Internal { ref lchild, ref rchild, ref sum, @@ -469,7 +469,7 @@ fn node_read_recurse( reader: &mut Reader, index: &mut HashMap, tree_index: &mut usize, -) -> Result, ser::Error> +) -> Result, ser::Error> where T: Summable + Readable + Hashed, { @@ -486,21 +486,19 @@ where // Read remainder of node let hash = try!(Readable::read(reader)); + let sum = try!(Readable::read(reader)); let data = match (depth, pruned) { (_, true) => { - let sum = try!(Readable::read(reader)); *tree_index += 1 << depth as usize; - NodeData::Pruned(sum) + Node::Pruned(sum) } (0, _) => { - let elem: T = try!(Readable::read(reader)); - index.insert(Hashed::hash(&elem), *tree_index); + index.insert(hash, *tree_index); *tree_index += 1; - NodeData::Leaf(elem) + Node::Leaf(sum) } (_, _) => { - let sum = try!(Readable::read(reader)); - NodeData::Internal { + Node::Internal { lchild: Box::new(try!(node_read_recurse(reader, index, tree_index))), rchild: Box::new(try!(node_read_recurse(reader, index, tree_index))), sum: sum, @@ -508,9 +506,9 @@ where } }; - Ok(Node { + Ok(NodeData { full: full, - data: data, + node: data, hash: hash, depth: depth, }) @@ -539,16 +537,13 @@ where let mut index = HashMap::new(); let hash = try!(Readable::read(reader)); + let sum = try!(Readable::read(reader)); let data = match (depth, pruned) { - (_, true) => { - let sum = try!(Readable::read(reader)); - NodeData::Pruned(sum) - } - (0, _) => NodeData::Leaf(try!(Readable::read(reader))), + (_, true) => Node::Pruned(sum), + (0, _) => Node::Leaf(sum), (_, _) => { - let sum = try!(Readable::read(reader)); let mut tree_index = 0; - NodeData::Internal { + Node::Internal { lchild: Box::new(try!(node_read_recurse(reader, &mut index, &mut tree_index))), rchild: Box::new(try!(node_read_recurse(reader, &mut index, &mut tree_index))), sum: sum, @@ -558,9 +553,9 @@ where Ok(SumTree { index: index, - root: Some(Node { + root: Some(NodeData { full: full, - data: data, + node: data, hash: hash, depth: depth, }), @@ -633,7 +628,7 @@ where // a couple functions that help debugging #[allow(dead_code)] -fn print_node(node: &Node, tab_level: usize) +fn print_node(node: &NodeData, tab_level: usize) where T: Summable + Writeable, T::Sum: std::fmt::Debug, @@ -642,10 +637,10 @@ where print!(" "); } print!("[{:03}] {} {:?}", node.depth, node.hash, node.sum()); - match node.data { - NodeData::Pruned(_) => println!(" X"), - NodeData::Leaf(_) => println!(" L"), - NodeData::Internal { + match node.node { + Node::Pruned(_) => println!(" X"), + Node::Leaf(_) => println!(" L"), + Node::Internal { ref lchild, ref rchild, ..