cache key_id->derivation in the keychain (#253)

This commit is contained in:
AntiochP 2017-11-10 10:12:15 -05:00 committed by GitHub
parent b831192335
commit 8f33c7e0fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 24 deletions

View file

@ -14,6 +14,7 @@
use rand::{thread_rng, Rng};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use util::secp;
use util::secp::{Message, Secp256k1, Signature};
@ -49,6 +50,7 @@ pub struct Keychain {
secp: Secp256k1,
extkey: extkey::ExtendedKey,
key_overrides: HashMap<Identifier, SecretKey>,
key_derivation_cache: Arc<RwLock<HashMap<Identifier, u32>>>,
}
impl Keychain {
@ -57,7 +59,6 @@ impl Keychain {
}
// For tests and burn only, associate a key identifier with a known secret key.
//
pub fn burn_enabled(keychain: &Keychain, burn_key_id: &Identifier) -> Keychain {
let mut key_overrides = HashMap::new();
key_overrides.insert(
@ -77,6 +78,7 @@ impl Keychain {
secp: secp,
extkey: extkey,
key_overrides: HashMap::new(),
key_derivation_cache: Arc::new(RwLock::new(HashMap::new())),
};
Ok(keychain)
}
@ -94,35 +96,58 @@ impl Keychain {
Ok(key_id)
}
fn derived_key_search(&self, key_id: &Identifier, n_child: Option<u32>) -> Result<SecretKey, Error> {
fn derived_key(&self, key_id: &Identifier) -> Result<SecretKey, Error> {
trace!(LOGGER, "Derived Key by key_id: {}", key_id);
// first check our overrides and just return the key if we have one in there
if let Some(key) = self.key_overrides.get(key_id) {
trace!(LOGGER, "... Derived Key (using override) key_id: {}", key_id);
return Ok(*key);
}
trace!(LOGGER, "Derived Key key_id: {}", key_id);
// then check the derivation cache to see if we have previously derived this key
// if so use the derivation from the cache to derive the key
{
let cache = self.key_derivation_cache.read().unwrap();
if let Some(derivation) = cache.get(key_id) {
trace!(LOGGER, "... Derived Key (cache hit) key_id: {}, derivation: {}", key_id, derivation);
return Ok(self.derived_key_from_index(*derivation)?)
}
}
if let Some(n) = n_child{
let extkey = self.extkey.derive(&self.secp, n)?;
return Ok(extkey.key);
};
for i in 1..10000 {
// otherwise iterate over a large number of derivations looking for our key
// cache the resulting derivations by key_id for faster lookup later
// TODO - remove the 10k hard limit and be smarter about batching somehow
{
let mut cache = self.key_derivation_cache.write().unwrap();
for i in 1..10_000 {
let extkey = self.extkey.derive(&self.secp, i)?;
if extkey.identifier(&self.secp)? == *key_id {
let extkey_id = extkey.identifier(&self.secp)?;
if !cache.contains_key(&extkey_id) {
trace!(LOGGER, "... Derived Key (cache miss) key_id: {}, derivation: {}", extkey_id, extkey.n_child);
cache.insert(extkey_id.clone(), extkey.n_child);
}
if extkey_id == *key_id {
return Ok(extkey.key);
}
}
}
Err(Error::KeyDerivation(
format!("cannot find extkey for {:?}", key_id),
))
}
fn derived_key(&self, key_id: &Identifier) -> Result<SecretKey, Error> {
self.derived_key_search(key_id, None)
}
fn derived_key_from_index(&self, key_id: &Identifier, n_child:u32) -> Result<SecretKey, Error> {
self.derived_key_search(key_id, Some(n_child))
// if we know the derivation index we can just straight to deriving the key
fn derived_key_from_index(
&self,
derivation: u32,
) -> Result<SecretKey, Error> {
trace!(LOGGER, "Derived Key (fast) by derivation: {}", derivation);
let extkey = self.extkey.derive(&self.secp, derivation)?;
return Ok(extkey.key)
}
pub fn commit(&self, amount: u64, key_id: &Identifier) -> Result<Commitment, Error> {
@ -131,8 +156,12 @@ impl Keychain {
Ok(commit)
}
pub fn commit_with_key_index(&self, amount: u64, key_id: &Identifier, n_child: u32) -> Result<Commitment, Error> {
let skey = self.derived_key_from_index(key_id, n_child)?;
pub fn commit_with_key_index(
&self,
amount: u64,
derivation: u32,
) -> Result<Commitment, Error> {
let skey = self.derived_key_from_index(derivation)?;
let commit = self.secp.commit(amount, skey)?;
Ok(commit)
}

View file

@ -58,7 +58,7 @@ pub fn refresh_outputs(config: &WalletConfig, keychain: &Keychain) -> Result<(),
let mut commits: Vec<pedersen::Commitment> = vec![];
// build a local map of wallet outputs by commits
// and a list of outputs we wantot query the node for
// and a list of outputs we want to query the node for
let _ = WalletData::read_wallet(&config.data_file_dir, |wallet_data| {
for out in wallet_data
.outputs
@ -66,8 +66,7 @@ pub fn refresh_outputs(config: &WalletConfig, keychain: &Keychain) -> Result<(),
.filter(|out| out.root_key_id == keychain.root_key_id())
.filter(|out| out.status != OutputStatus::Spent)
{
let key_id = keychain.derive_key_id(out.n_child).unwrap();
let commit = keychain.commit_with_key_index(out.value, &key_id, out.n_child).unwrap();
let commit = keychain.commit_with_key_index(out.value, out.n_child).unwrap();
commits.push(commit);
wallet_outputs.insert(commit, out.key_id.clone());
}