Sum-tree only stores hashes

Couple improvements to the sum tree implementation. First change
is to not store the data but only its hashes, assuming a higher
level layer will take care of the data (KV store).

Second minor improvement is renaming Node into NodeData and vice
versa.
This commit is contained in:
Ignotus Peverell 2017-07-20 14:07:57 +00:00
parent 01b66de437
commit 929943d8b1
No known key found for this signature in database
GPG key ID: 99CD25F39F8F8211

View file

@ -71,46 +71,46 @@ impl<T> Summable for NoSum<T> {
} }
#[derive(Clone)] #[derive(Clone)]
enum NodeData<T: Summable> { enum Node<T: Summable> {
/// Node with 2^n children which are not stored with the tree /// Node with 2^n children which are not stored with the tree
Pruned(T::Sum), Pruned(T::Sum),
/// Actual data /// Actual data
Leaf(T), Leaf(T::Sum),
/// Node with 2^n children /// Node with 2^n children
Internal { Internal {
lchild: Box<Node<T>>, lchild: Box<NodeData<T>>,
rchild: Box<Node<T>>, rchild: Box<NodeData<T>>,
sum: T::Sum, sum: T::Sum,
}, },
} }
impl<T: Summable> Summable for NodeData<T> {
type Sum = T::Sum;
fn sum(&self) -> T::Sum {
match *self {
NodeData::Pruned(ref sum) => sum.clone(),
NodeData::Leaf(ref data) => data.sum(),
NodeData::Internal { ref sum, .. } => sum.clone(),
}
}
}
#[derive(Clone)]
struct Node<T: Summable> {
full: bool,
data: NodeData<T>,
hash: Hash,
depth: u8,
}
impl<T: Summable> Summable for Node<T> { impl<T: Summable> Summable for Node<T> {
type Sum = T::Sum; type Sum = T::Sum;
fn sum(&self) -> T::Sum { fn sum(&self) -> T::Sum {
self.data.sum() match *self {
Node::Pruned(ref sum) => sum.clone(),
Node::Leaf(ref sum) => sum.clone(),
Node::Internal { ref sum, .. } => sum.clone(),
}
} }
} }
impl<T: Summable> Node<T> { #[derive(Clone)]
struct NodeData<T: Summable> {
full: bool,
node: Node<T>,
hash: Hash,
depth: u8,
}
impl<T: Summable> Summable for NodeData<T> {
type Sum = T::Sum;
fn sum(&self) -> T::Sum {
self.node.sum()
}
}
impl<T: Summable> NodeData<T> {
/// Get the root hash and sum of the node /// Get the root hash and sum of the node
fn root_sum(&self) -> (Hash, T::Sum) { fn root_sum(&self) -> (Hash, T::Sum) {
(self.hash, self.sum()) (self.hash, self.sum())
@ -120,11 +120,11 @@ impl<T: Summable> Node<T> {
if self.full { if self.full {
1 << self.depth 1 << self.depth
} else { } else {
if let NodeData::Internal { if let Node::Internal {
ref lchild, ref lchild,
ref rchild, ref rchild,
.. ..
} = self.data } = self.node
{ {
lchild.n_children() + rchild.n_children() lchild.n_children() + rchild.n_children()
} else { } else {
@ -140,7 +140,7 @@ pub struct SumTree<T: Summable + Writeable> {
/// Index mapping data to its index in the tree /// Index mapping data to its index in the tree
index: HashMap<Hash, usize>, index: HashMap<Hash, usize>,
/// Tree contents /// Tree contents
root: Option<Node<T>>, root: Option<NodeData<T>>,
} }
impl<T> SumTree<T> impl<T> SumTree<T>
@ -160,7 +160,7 @@ where
self.root.as_ref().map(|node| node.root_sum()) self.root.as_ref().map(|node| node.root_sum())
} }
fn insert_right_of(mut old: Node<T>, new: Node<T>) -> Node<T> { fn insert_right_of(mut old: NodeData<T>, new: NodeData<T>) -> NodeData<T> {
assert!(old.depth >= new.depth); assert!(old.depth >= new.depth);
// If we are inserting next to a full node, make a parent. If we're // If we are inserting next to a full node, make a parent. If we're
@ -172,15 +172,15 @@ where
let parent_sum = old.sum() + new.sum(); let parent_sum = old.sum() + new.sum();
let parent_hash = (parent_depth, &parent_sum, old.hash, new.hash).hash(); let parent_hash = (parent_depth, &parent_sum, old.hash, new.hash).hash();
let parent_full = old.depth == new.depth; let parent_full = old.depth == new.depth;
let parent_data = NodeData::Internal { let parent_node = Node::Internal {
lchild: Box::new(old), lchild: Box::new(old),
rchild: Box::new(new), rchild: Box::new(new),
sum: parent_sum, sum: parent_sum,
}; };
Node { NodeData {
full: parent_full, full: parent_full,
data: parent_data, node: parent_node,
hash: parent_hash, hash: parent_hash,
depth: parent_depth, depth: parent_depth,
} }
@ -188,16 +188,16 @@ where
// inserting under the node, so we recurse. The right child of a partial // inserting under the node, so we recurse. The right child of a partial
// node is always another partial node or a leaf. // node is always another partial node or a leaf.
} else { } else {
if let NodeData::Internal { if let Node::Internal {
ref lchild, ref lchild,
ref mut rchild, ref mut rchild,
ref mut sum, ref mut sum,
} = old.data } = old.node
{ {
// Recurse // Recurse
let dummy_child = Node { let dummy_child = NodeData {
full: true, full: true,
data: NodeData::Pruned(sum.clone()), node: Node::Pruned(sum.clone()),
hash: old.hash, hash: old.hash,
depth: 0, depth: 0,
}; };
@ -243,9 +243,9 @@ where
// Special-case the first element // Special-case the first element
if self.root.is_none() { if self.root.is_none() {
self.root = Some(Node { self.root = Some(NodeData {
full: true, full: true,
data: NodeData::Leaf(elem), node: Node::Leaf(elem_sum),
hash: elem_hash, hash: elem_hash,
depth: 0, depth: 0,
}); });
@ -258,9 +258,9 @@ where
let old_root = mem::replace(&mut self.root, None).unwrap(); let old_root = mem::replace(&mut self.root, None).unwrap();
// Insert into tree, compute new root // Insert into tree, compute new root
let new_node = Node { let new_node = NodeData {
full: true, full: true,
data: NodeData::Leaf(elem), node: Node::Leaf(elem_sum),
hash: elem_hash, hash: elem_hash,
depth: 0, depth: 0,
}; };
@ -272,16 +272,16 @@ where
true true
} }
fn replace_recurse(node: &mut Node<T>, index: usize, new_elem: T) { fn replace_recurse(node: &mut NodeData<T>, index: usize, new_elem: T) {
assert!(index < (1 << node.depth)); assert!(index < (1 << node.depth));
if node.depth == 0 { if node.depth == 0 {
assert!(node.full); assert!(node.full);
node.hash = (0u8, new_elem.sum(), Hashed::hash(&new_elem)).hash(); node.hash = (0u8, new_elem.sum(), Hashed::hash(&new_elem)).hash();
node.data = NodeData::Leaf(new_elem); node.node = Node::Leaf(new_elem.sum());
} else { } else {
match node.data { match node.node {
NodeData::Internal { Node::Internal {
ref mut lchild, ref mut lchild,
ref mut rchild, ref mut rchild,
ref mut sum, ref mut sum,
@ -296,8 +296,8 @@ where
node.hash = (node.depth, &*sum, lchild.hash, rchild.hash).hash(); node.hash = (node.depth, &*sum, lchild.hash, rchild.hash).hash();
} }
// Pruned data would not have been in the index // Pruned data would not have been in the index
NodeData::Pruned(_) => unreachable!(), Node::Pruned(_) => unreachable!(),
NodeData::Leaf(_) => unreachable!(), Node::Leaf(_) => unreachable!(),
} }
} }
} }
@ -337,20 +337,20 @@ where
self.index.get(&index_hash).map(|x| *x) self.index.get(&index_hash).map(|x| *x)
} }
fn prune_recurse(node: &mut Node<T>, index: usize) { fn prune_recurse(node: &mut NodeData<T>, index: usize) {
assert!(index < (1 << node.depth)); assert!(index < (1 << node.depth));
if node.depth == 0 { if node.depth == 0 {
let sum = if let NodeData::Leaf(ref elem) = node.data { let sum = if let Node::Leaf(ref sum) = node.node {
elem.sum() sum.clone()
} else { } else {
unreachable!() unreachable!()
}; };
node.data = NodeData::Pruned(sum); node.node = Node::Pruned(sum);
} else { } else {
let mut prune_me = None; let mut prune_me = None;
match node.data { match node.node {
NodeData::Internal { Node::Internal {
ref mut lchild, ref mut lchild,
ref mut rchild, ref mut rchild,
.. ..
@ -361,21 +361,21 @@ where
} else { } else {
SumTree::prune_recurse(lchild, index); SumTree::prune_recurse(lchild, index);
} }
if let (&NodeData::Pruned(ref lsum), &NodeData::Pruned(ref rsum)) = if let (&Node::Pruned(ref lsum), &Node::Pruned(ref rsum)) =
(&lchild.data, &rchild.data) (&lchild.node, &rchild.node)
{ {
if node.full { if node.full {
prune_me = Some(lsum.clone() + rsum.clone()); prune_me = Some(lsum.clone() + rsum.clone());
} }
} }
} }
NodeData::Pruned(_) => { Node::Pruned(_) => {
// Already pruned. Ok. // Already pruned. Ok.
} }
NodeData::Leaf(_) => unreachable!(), Node::Leaf(_) => unreachable!(),
} }
if let Some(sum) = prune_me { if let Some(sum) = prune_me {
node.data = NodeData::Pruned(sum); node.node = Node::Pruned(sum);
} }
} }
} }
@ -429,7 +429,7 @@ where
} }
} }
impl<T> Writeable for Node<T> impl<T> Writeable for NodeData<T>
where where
T: Summable + Writeable, T: Summable + Writeable,
{ {
@ -441,7 +441,7 @@ where
if self.full { if self.full {
depth |= 0x80; depth |= 0x80;
} }
if let NodeData::Pruned(_) = self.data { if let Node::Pruned(_) = self.node {
} else { } else {
depth |= 0xc0; depth |= 0xc0;
} }
@ -449,10 +449,10 @@ where
// Encode node // Encode node
try!(writer.write_u8(depth)); try!(writer.write_u8(depth));
try!(self.hash.write(writer)); try!(self.hash.write(writer));
match self.data { match self.node {
NodeData::Pruned(ref sum) => sum.write(writer), Node::Pruned(ref sum) => sum.write(writer),
NodeData::Leaf(ref data) => data.write(writer), Node::Leaf(ref sum) => sum.write(writer),
NodeData::Internal { Node::Internal {
ref lchild, ref lchild,
ref rchild, ref rchild,
ref sum, ref sum,
@ -469,7 +469,7 @@ fn node_read_recurse<T>(
reader: &mut Reader, reader: &mut Reader,
index: &mut HashMap<Hash, usize>, index: &mut HashMap<Hash, usize>,
tree_index: &mut usize, tree_index: &mut usize,
) -> Result<Node<T>, ser::Error> ) -> Result<NodeData<T>, ser::Error>
where where
T: Summable + Readable + Hashed, T: Summable + Readable + Hashed,
{ {
@ -486,21 +486,19 @@ where
// Read remainder of node // Read remainder of node
let hash = try!(Readable::read(reader)); let hash = try!(Readable::read(reader));
let sum = try!(Readable::read(reader));
let data = match (depth, pruned) { let data = match (depth, pruned) {
(_, true) => { (_, true) => {
let sum = try!(Readable::read(reader));
*tree_index += 1 << depth as usize; *tree_index += 1 << depth as usize;
NodeData::Pruned(sum) Node::Pruned(sum)
} }
(0, _) => { (0, _) => {
let elem: T = try!(Readable::read(reader)); index.insert(hash, *tree_index);
index.insert(Hashed::hash(&elem), *tree_index);
*tree_index += 1; *tree_index += 1;
NodeData::Leaf(elem) Node::Leaf(sum)
} }
(_, _) => { (_, _) => {
let sum = try!(Readable::read(reader)); Node::Internal {
NodeData::Internal {
lchild: Box::new(try!(node_read_recurse(reader, index, tree_index))), lchild: Box::new(try!(node_read_recurse(reader, index, tree_index))),
rchild: Box::new(try!(node_read_recurse(reader, index, tree_index))), rchild: Box::new(try!(node_read_recurse(reader, index, tree_index))),
sum: sum, sum: sum,
@ -508,9 +506,9 @@ where
} }
}; };
Ok(Node { Ok(NodeData {
full: full, full: full,
data: data, node: data,
hash: hash, hash: hash,
depth: depth, depth: depth,
}) })
@ -539,16 +537,13 @@ where
let mut index = HashMap::new(); let mut index = HashMap::new();
let hash = try!(Readable::read(reader)); let hash = try!(Readable::read(reader));
let sum = try!(Readable::read(reader));
let data = match (depth, pruned) { let data = match (depth, pruned) {
(_, true) => { (_, true) => Node::Pruned(sum),
let sum = try!(Readable::read(reader)); (0, _) => Node::Leaf(sum),
NodeData::Pruned(sum)
}
(0, _) => NodeData::Leaf(try!(Readable::read(reader))),
(_, _) => { (_, _) => {
let sum = try!(Readable::read(reader));
let mut tree_index = 0; let mut tree_index = 0;
NodeData::Internal { Node::Internal {
lchild: Box::new(try!(node_read_recurse(reader, &mut index, &mut tree_index))), lchild: Box::new(try!(node_read_recurse(reader, &mut index, &mut tree_index))),
rchild: Box::new(try!(node_read_recurse(reader, &mut index, &mut tree_index))), rchild: Box::new(try!(node_read_recurse(reader, &mut index, &mut tree_index))),
sum: sum, sum: sum,
@ -558,9 +553,9 @@ where
Ok(SumTree { Ok(SumTree {
index: index, index: index,
root: Some(Node { root: Some(NodeData {
full: full, full: full,
data: data, node: data,
hash: hash, hash: hash,
depth: depth, depth: depth,
}), }),
@ -633,7 +628,7 @@ where
// a couple functions that help debugging // a couple functions that help debugging
#[allow(dead_code)] #[allow(dead_code)]
fn print_node<T>(node: &Node<T>, tab_level: usize) fn print_node<T>(node: &NodeData<T>, tab_level: usize)
where where
T: Summable + Writeable, T: Summable + Writeable,
T::Sum: std::fmt::Debug, T::Sum: std::fmt::Debug,
@ -642,10 +637,10 @@ where
print!(" "); print!(" ");
} }
print!("[{:03}] {} {:?}", node.depth, node.hash, node.sum()); print!("[{:03}] {} {:?}", node.depth, node.hash, node.sum());
match node.data { match node.node {
NodeData::Pruned(_) => println!(" X"), Node::Pruned(_) => println!(" X"),
NodeData::Leaf(_) => println!(" L"), Node::Leaf(_) => println!(" L"),
NodeData::Internal { Node::Internal {
ref lchild, ref lchild,
ref rchild, ref rchild,
.. ..