diff --git a/keychain/src/keychain.rs b/keychain/src/keychain.rs index aa50e8b09..aed65f23f 100644 --- a/keychain/src/keychain.rs +++ b/keychain/src/keychain.rs @@ -14,6 +14,7 @@ use rand::{thread_rng, Rng}; use std::collections::HashMap; +use std::sync::{Arc, RwLock}; use util::secp; use util::secp::{Message, Secp256k1, Signature}; @@ -49,6 +50,7 @@ pub struct Keychain { secp: Secp256k1, extkey: extkey::ExtendedKey, key_overrides: HashMap, + key_derivation_cache: Arc>>, } impl Keychain { @@ -57,7 +59,6 @@ impl Keychain { } // For tests and burn only, associate a key identifier with a known secret key. - // pub fn burn_enabled(keychain: &Keychain, burn_key_id: &Identifier) -> Keychain { let mut key_overrides = HashMap::new(); key_overrides.insert( @@ -77,6 +78,7 @@ impl Keychain { secp: secp, extkey: extkey, key_overrides: HashMap::new(), + key_derivation_cache: Arc::new(RwLock::new(HashMap::new())), }; Ok(keychain) } @@ -94,35 +96,58 @@ impl Keychain { Ok(key_id) } - fn derived_key_search(&self, key_id: &Identifier, n_child: Option) -> Result { + fn derived_key(&self, key_id: &Identifier) -> Result { + trace!(LOGGER, "Derived Key by key_id: {}", key_id); + + // first check our overrides and just return the key if we have one in there if let Some(key) = self.key_overrides.get(key_id) { + trace!(LOGGER, "... Derived Key (using override) key_id: {}", key_id); return Ok(*key); } - trace!(LOGGER, "Derived Key key_id: {}", key_id); - - if let Some(n) = n_child{ - let extkey = self.extkey.derive(&self.secp, n)?; - return Ok(extkey.key); - }; - - for i in 1..10000 { - let extkey = self.extkey.derive(&self.secp, i)?; - if extkey.identifier(&self.secp)? == *key_id { - return Ok(extkey.key); + // then check the derivation cache to see if we have previously derived this key + // if so use the derivation from the cache to derive the key + { + let cache = self.key_derivation_cache.read().unwrap(); + if let Some(derivation) = cache.get(key_id) { + trace!(LOGGER, "... Derived Key (cache hit) key_id: {}, derivation: {}", key_id, derivation); + return Ok(self.derived_key_from_index(*derivation)?) } } + + // otherwise iterate over a large number of derivations looking for our key + // cache the resulting derivations by key_id for faster lookup later + // TODO - remove the 10k hard limit and be smarter about batching somehow + { + let mut cache = self.key_derivation_cache.write().unwrap(); + for i in 1..10_000 { + let extkey = self.extkey.derive(&self.secp, i)?; + let extkey_id = extkey.identifier(&self.secp)?; + + if !cache.contains_key(&extkey_id) { + trace!(LOGGER, "... Derived Key (cache miss) key_id: {}, derivation: {}", extkey_id, extkey.n_child); + cache.insert(extkey_id.clone(), extkey.n_child); + } + + if extkey_id == *key_id { + return Ok(extkey.key); + } + } + } + Err(Error::KeyDerivation( format!("cannot find extkey for {:?}", key_id), )) } - fn derived_key(&self, key_id: &Identifier) -> Result { - self.derived_key_search(key_id, None) - } - - fn derived_key_from_index(&self, key_id: &Identifier, n_child:u32) -> Result { - self.derived_key_search(key_id, Some(n_child)) + // if we know the derivation index we can just straight to deriving the key + fn derived_key_from_index( + &self, + derivation: u32, + ) -> Result { + trace!(LOGGER, "Derived Key (fast) by derivation: {}", derivation); + let extkey = self.extkey.derive(&self.secp, derivation)?; + return Ok(extkey.key) } pub fn commit(&self, amount: u64, key_id: &Identifier) -> Result { @@ -131,8 +156,12 @@ impl Keychain { Ok(commit) } - pub fn commit_with_key_index(&self, amount: u64, key_id: &Identifier, n_child: u32) -> Result { - let skey = self.derived_key_from_index(key_id, n_child)?; + pub fn commit_with_key_index( + &self, + amount: u64, + derivation: u32, + ) -> Result { + let skey = self.derived_key_from_index(derivation)?; let commit = self.secp.commit(amount, skey)?; Ok(commit) } diff --git a/wallet/src/checker.rs b/wallet/src/checker.rs index bbbfecab1..9aeac1002 100644 --- a/wallet/src/checker.rs +++ b/wallet/src/checker.rs @@ -58,7 +58,7 @@ pub fn refresh_outputs(config: &WalletConfig, keychain: &Keychain) -> Result<(), let mut commits: Vec = vec![]; // build a local map of wallet outputs by commits - // and a list of outputs we wantot query the node for + // and a list of outputs we want to query the node for let _ = WalletData::read_wallet(&config.data_file_dir, |wallet_data| { for out in wallet_data .outputs @@ -66,8 +66,7 @@ pub fn refresh_outputs(config: &WalletConfig, keychain: &Keychain) -> Result<(), .filter(|out| out.root_key_id == keychain.root_key_id()) .filter(|out| out.status != OutputStatus::Spent) { - let key_id = keychain.derive_key_id(out.n_child).unwrap(); - let commit = keychain.commit_with_key_index(out.value, &key_id, out.n_child).unwrap(); + let commit = keychain.commit_with_key_index(out.value, out.n_child).unwrap(); commits.push(commit); wallet_outputs.insert(commit, out.key_id.clone()); }