diff --git a/core/src/core/transaction.rs b/core/src/core/transaction.rs index 0e51da452..38f198f54 100644 --- a/core/src/core/transaction.rs +++ b/core/src/core/transaction.rs @@ -1290,26 +1290,33 @@ impl Transaction { /// eliminating any input/output pairs with input spending output. /// Returns new slices with cut-through elements removed. /// Also returns slices of the cut-through elements themselves. +/// Note: Takes slices of _anything_ that is AsRef for greater flexibility. +/// So we can cut_through inputs and outputs but we can also cut_through inputs and output identifiers. +/// Or we can get crazy and cut_through inputs with other inputs to identify intersection and difference etc. /// /// Example: /// Inputs: [A, B, C] /// Outputs: [C, D, E] /// Returns: ([A, B], [D, E], [C], [C]) # element C is cut-through -pub fn cut_through<'a>( - inputs: &'a mut [Input], - outputs: &'a mut [Output], -) -> Result<(&'a [Input], &'a [Output], &'a [Input], &'a [Output]), Error> { - // Make sure inputs and outputs are sorted consistently by commitment. - inputs.sort_unstable_by_key(|x| x.commitment()); - outputs.sort_unstable_by_key(|x| x.commitment()); +pub fn cut_through<'a, 'b, T, U>( + inputs: &'a mut [T], + outputs: &'b mut [U], +) -> Result<(&'a [T], &'b [U], &'a [T], &'b [U]), Error> +where + T: AsRef + Ord, + U: AsRef + Ord, +{ + // Make sure inputs and outputs are sorted consistently as we will iterate over both. + inputs.sort_unstable_by_key(|x| *x.as_ref()); + outputs.sort_unstable_by_key(|x| *x.as_ref()); let mut inputs_idx = 0; let mut outputs_idx = 0; let mut ncut = 0; while inputs_idx < inputs.len() && outputs_idx < outputs.len() { match inputs[inputs_idx] - .partial_cmp(&outputs[outputs_idx]) - .expect("compare input to output") + .as_ref() + .cmp(&outputs[outputs_idx].as_ref()) { Ordering::Less => { inputs.swap(inputs_idx - ncut, inputs_idx); @@ -1370,15 +1377,16 @@ pub fn aggregate(txs: &[Transaction]) -> Result { } else if txs.len() == 1 { return Ok(txs[0].clone()); } - let mut n_inputs = 0; - let mut n_outputs = 0; - let mut n_kernels = 0; - for tx in txs.iter() { - n_inputs += tx.inputs().len(); - n_outputs += tx.outputs().len(); - n_kernels += tx.kernels().len(); - } + let (n_inputs, n_outputs, n_kernels) = + txs.iter() + .fold((0, 0, 0), |(inputs, outputs, kernels), tx| { + ( + inputs + tx.inputs().len(), + outputs + tx.outputs().len(), + kernels + tx.kernels().len(), + ) + }); let mut inputs: Vec = Vec::with_capacity(n_inputs); let mut outputs: Vec = Vec::with_capacity(n_outputs); let mut kernels: Vec = Vec::with_capacity(n_kernels); @@ -1498,17 +1506,9 @@ pub struct Input { impl DefaultHashable for Input {} hashable_ord!(Input); -// Inputs can be compared to outputs. -impl PartialEq for Input { - fn eq(&self, other: &Output) -> bool { - self.commitment() == other.commitment() - } -} - -// Inputs can be compared to outputs. -impl PartialOrd for Input { - fn partial_cmp(&self, other: &Output) -> Option { - Some(self.commitment().cmp(&other.commitment())) +impl AsRef for Input { + fn as_ref(&self) -> &Commitment { + &self.commit } } @@ -1714,6 +1714,12 @@ pub struct Output { impl DefaultHashable for Output {} hashable_ord!(Output); +impl AsRef for Output { + fn as_ref(&self) -> &Commitment { + &self.commit + } +} + impl ::std::hash::Hash for Output { fn hash(&self, state: &mut H) { let mut vec = Vec::new();