cache key_id->derivation in the keychain (#253)

This commit is contained in:
AntiochP 2017-11-10 10:12:15 -05:00 committed by GitHub
parent b831192335
commit 8f33c7e0fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 24 deletions

View file

@ -14,6 +14,7 @@
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use util::secp; use util::secp;
use util::secp::{Message, Secp256k1, Signature}; use util::secp::{Message, Secp256k1, Signature};
@ -49,6 +50,7 @@ pub struct Keychain {
secp: Secp256k1, secp: Secp256k1,
extkey: extkey::ExtendedKey, extkey: extkey::ExtendedKey,
key_overrides: HashMap<Identifier, SecretKey>, key_overrides: HashMap<Identifier, SecretKey>,
key_derivation_cache: Arc<RwLock<HashMap<Identifier, u32>>>,
} }
impl Keychain { impl Keychain {
@ -57,7 +59,6 @@ impl Keychain {
} }
// For tests and burn only, associate a key identifier with a known secret key. // For tests and burn only, associate a key identifier with a known secret key.
//
pub fn burn_enabled(keychain: &Keychain, burn_key_id: &Identifier) -> Keychain { pub fn burn_enabled(keychain: &Keychain, burn_key_id: &Identifier) -> Keychain {
let mut key_overrides = HashMap::new(); let mut key_overrides = HashMap::new();
key_overrides.insert( key_overrides.insert(
@ -77,6 +78,7 @@ impl Keychain {
secp: secp, secp: secp,
extkey: extkey, extkey: extkey,
key_overrides: HashMap::new(), key_overrides: HashMap::new(),
key_derivation_cache: Arc::new(RwLock::new(HashMap::new())),
}; };
Ok(keychain) Ok(keychain)
} }
@ -94,35 +96,58 @@ impl Keychain {
Ok(key_id) Ok(key_id)
} }
fn derived_key_search(&self, key_id: &Identifier, n_child: Option<u32>) -> Result<SecretKey, Error> { fn derived_key(&self, key_id: &Identifier) -> Result<SecretKey, Error> {
trace!(LOGGER, "Derived Key by key_id: {}", key_id);
// first check our overrides and just return the key if we have one in there
if let Some(key) = self.key_overrides.get(key_id) { if let Some(key) = self.key_overrides.get(key_id) {
trace!(LOGGER, "... Derived Key (using override) key_id: {}", key_id);
return Ok(*key); return Ok(*key);
} }
trace!(LOGGER, "Derived Key key_id: {}", key_id); // then check the derivation cache to see if we have previously derived this key
// if so use the derivation from the cache to derive the key
if let Some(n) = n_child{ {
let extkey = self.extkey.derive(&self.secp, n)?; let cache = self.key_derivation_cache.read().unwrap();
return Ok(extkey.key); if let Some(derivation) = cache.get(key_id) {
}; trace!(LOGGER, "... Derived Key (cache hit) key_id: {}, derivation: {}", key_id, derivation);
return Ok(self.derived_key_from_index(*derivation)?)
for i in 1..10000 {
let extkey = self.extkey.derive(&self.secp, i)?;
if extkey.identifier(&self.secp)? == *key_id {
return Ok(extkey.key);
} }
} }
// otherwise iterate over a large number of derivations looking for our key
// cache the resulting derivations by key_id for faster lookup later
// TODO - remove the 10k hard limit and be smarter about batching somehow
{
let mut cache = self.key_derivation_cache.write().unwrap();
for i in 1..10_000 {
let extkey = self.extkey.derive(&self.secp, i)?;
let extkey_id = extkey.identifier(&self.secp)?;
if !cache.contains_key(&extkey_id) {
trace!(LOGGER, "... Derived Key (cache miss) key_id: {}, derivation: {}", extkey_id, extkey.n_child);
cache.insert(extkey_id.clone(), extkey.n_child);
}
if extkey_id == *key_id {
return Ok(extkey.key);
}
}
}
Err(Error::KeyDerivation( Err(Error::KeyDerivation(
format!("cannot find extkey for {:?}", key_id), format!("cannot find extkey for {:?}", key_id),
)) ))
} }
fn derived_key(&self, key_id: &Identifier) -> Result<SecretKey, Error> { // if we know the derivation index we can just straight to deriving the key
self.derived_key_search(key_id, None) fn derived_key_from_index(
} &self,
derivation: u32,
fn derived_key_from_index(&self, key_id: &Identifier, n_child:u32) -> Result<SecretKey, Error> { ) -> Result<SecretKey, Error> {
self.derived_key_search(key_id, Some(n_child)) trace!(LOGGER, "Derived Key (fast) by derivation: {}", derivation);
let extkey = self.extkey.derive(&self.secp, derivation)?;
return Ok(extkey.key)
} }
pub fn commit(&self, amount: u64, key_id: &Identifier) -> Result<Commitment, Error> { pub fn commit(&self, amount: u64, key_id: &Identifier) -> Result<Commitment, Error> {
@ -131,8 +156,12 @@ impl Keychain {
Ok(commit) Ok(commit)
} }
pub fn commit_with_key_index(&self, amount: u64, key_id: &Identifier, n_child: u32) -> Result<Commitment, Error> { pub fn commit_with_key_index(
let skey = self.derived_key_from_index(key_id, n_child)?; &self,
amount: u64,
derivation: u32,
) -> Result<Commitment, Error> {
let skey = self.derived_key_from_index(derivation)?;
let commit = self.secp.commit(amount, skey)?; let commit = self.secp.commit(amount, skey)?;
Ok(commit) Ok(commit)
} }

View file

@ -58,7 +58,7 @@ pub fn refresh_outputs(config: &WalletConfig, keychain: &Keychain) -> Result<(),
let mut commits: Vec<pedersen::Commitment> = vec![]; let mut commits: Vec<pedersen::Commitment> = vec![];
// build a local map of wallet outputs by commits // build a local map of wallet outputs by commits
// and a list of outputs we wantot query the node for // and a list of outputs we want to query the node for
let _ = WalletData::read_wallet(&config.data_file_dir, |wallet_data| { let _ = WalletData::read_wallet(&config.data_file_dir, |wallet_data| {
for out in wallet_data for out in wallet_data
.outputs .outputs
@ -66,8 +66,7 @@ pub fn refresh_outputs(config: &WalletConfig, keychain: &Keychain) -> Result<(),
.filter(|out| out.root_key_id == keychain.root_key_id()) .filter(|out| out.root_key_id == keychain.root_key_id())
.filter(|out| out.status != OutputStatus::Spent) .filter(|out| out.status != OutputStatus::Spent)
{ {
let key_id = keychain.derive_key_id(out.n_child).unwrap(); let commit = keychain.commit_with_key_index(out.value, out.n_child).unwrap();
let commit = keychain.commit_with_key_index(out.value, &key_id, out.n_child).unwrap();
commits.push(commit); commits.push(commit);
wallet_outputs.insert(commit, out.key_id.clone()); wallet_outputs.insert(commit, out.key_id.clone());
} }