basic reorg protection

(cherry picked from commit e3696ed73c2012f20205b98099f397ad799fcff4)
This commit is contained in:
scilio 2024-04-02 17:27:08 -04:00
parent 389581d759
commit d1ae6863f8
4 changed files with 1422 additions and 1272 deletions

View file

@ -16,9 +16,10 @@ use mwixnet::node::GrinNode;
use mwixnet::store::StoreError;
use rpassword;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::thread::{sleep, spawn};
use std::time::Duration;
use grin_core::core::Transaction;
#[macro_use]
extern crate clap;
@ -166,7 +167,7 @@ fn real_main() -> Result<(), Box<dyn std::error::Error>> {
.threaded_scheduler()
.enable_all()
.build()?;
if let Err(e) = rt.block_on(node.async_get_chain_height()) {
if let Err(e) = rt.block_on(node.async_get_chain_tip()) {
eprintln!("Node communication failure. Is node listening?");
return Err(e.into());
};
@ -262,6 +263,9 @@ fn real_main() -> Result<(), Box<dyn std::error::Error>> {
let close_handle = http_server.close_handle();
let round_handle = spawn(move || {
let mut secs = 0;
let prev_tx = Arc::new(Mutex::new(None));
let server = swap_server.clone();
loop {
if stop_state.is_stopped() {
close_handle.close();
@ -272,9 +276,34 @@ fn real_main() -> Result<(), Box<dyn std::error::Error>> {
secs = (secs + 1) % server_config.interval_s;
if secs == 0 {
let server = swap_server.clone();
rt.spawn(async move { server.lock().await.execute_round().await });
//let _ = swap_server.lock().unwrap().execute_round();
let prev_tx_clone = prev_tx.clone();
let server_clone = server.clone();
rt.spawn(async move {
let result = server_clone.lock().await.execute_round().await;
let mut prev_tx_lock = prev_tx_clone.lock().unwrap();
*prev_tx_lock = match result {
Ok(Some(tx)) => Some(tx),
_ => None,
};
});
} else if secs % 30 == 0 {
let prev_tx_clone = prev_tx.clone();
let server_clone = server.clone();
rt.spawn(async move {
let tx_option = {
let prev_tx_lock = prev_tx_clone.lock().unwrap();
prev_tx_lock.clone()
}; // Lock is dropped here
if let Some(tx) = tx_option {
let result = server_clone.lock().await.check_reorg(tx).await;
let mut prev_tx_lock = prev_tx_clone.lock().unwrap();
*prev_tx_lock = match result {
Ok(Some(tx)) => Some(tx),
_ => None,
};
}
});
}
}
});

View file

@ -4,13 +4,14 @@ use grin_api::json_rpc::{build_request, Request, Response};
use grin_api::{client, LocatedTxKernel};
use grin_api::{OutputPrintable, OutputType, Tip};
use grin_core::consensus::COINBASE_MATURITY;
use grin_core::core::{Input, OutputFeatures, Transaction};
use grin_core::core::{Committed, Input, OutputFeatures, Transaction};
use grin_util::ToHex;
use async_trait::async_trait;
use serde_json::json;
use std::net::SocketAddr;
use std::sync::Arc;
use grin_core::core::hash::Hash;
use thiserror::Error;
#[async_trait]
@ -21,8 +22,8 @@ pub trait GrinNode: Send + Sync {
output_commit: &Commitment,
) -> Result<Option<OutputPrintable>, NodeError>;
/// Gets the height of the chain tip
async fn async_get_chain_height(&self) -> Result<u64, NodeError>;
/// Gets the height and hash of the chain tip
async fn async_get_chain_tip(&self) -> Result<(u64, Hash), NodeError>;
/// Posts a transaction to the grin node
async fn async_post_tx(&self, tx: &Transaction) -> Result<(), NodeError>;
@ -108,6 +109,23 @@ pub async fn async_build_input(
Ok(None)
}
pub async fn async_is_tx_valid(node: &Arc<dyn GrinNode>, tx: &Transaction) -> Result<bool, NodeError> {
let next_block_height = node.async_get_chain_tip().await?.0 + 1;
for input_commit in &tx.inputs_committed() {
if !async_is_spendable(&node, &input_commit, next_block_height).await? {
return Ok(false);
}
}
for output_commit in &tx.outputs_committed() {
if async_is_unspent(&node, &output_commit).await? {
return Ok(false);
}
}
Ok(true)
}
/// HTTP (JSON-RPC) implementation of the 'GrinNode' trait
#[derive(Clone)]
pub struct HttpGrinNode {
@ -176,7 +194,7 @@ impl GrinNode for HttpGrinNode {
Ok(Some(outputs[0].clone()))
}
async fn async_get_chain_height(&self) -> Result<u64, NodeError> {
async fn async_get_chain_tip(&self) -> Result<(u64, Hash), NodeError> {
let params = json!([]);
let tip_json = self
.async_send_request::<serde_json::Value>("get_tip", &params)
@ -184,7 +202,7 @@ impl GrinNode for HttpGrinNode {
let tip =
serde_json::from_value::<Tip>(tip_json).map_err(NodeError::DecodeResponseError)?;
Ok(tip.height)
Ok((tip.height, Hash::from_hex(tip.last_block_pushed.as_str()).unwrap()))
}
async fn async_post_tx(&self, tx: &Transaction) -> Result<(), NodeError> {
@ -226,6 +244,7 @@ pub mod mock {
use grin_onion::crypto::secp::Commitment;
use std::collections::HashMap;
use std::sync::RwLock;
use grin_core::core::hash::Hash;
/// Implementation of 'GrinNode' trait that mocks a grin node instance.
/// Use only for testing purposes.
@ -299,8 +318,8 @@ pub mod mock {
Ok(None)
}
async fn async_get_chain_height(&self) -> Result<u64, NodeError> {
Ok(100)
async fn async_get_chain_tip(&self) -> Result<(u64, Hash), NodeError> {
Ok((100, Hash::default()))
}
async fn async_post_tx(&self, tx: &Transaction) -> Result<(), NodeError> {

View file

@ -1,12 +1,12 @@
use crate::client::MixClient;
use crate::config::ServerConfig;
use crate::node::{self, GrinNode};
use crate::store::{StoreError, SwapData, SwapStatus, SwapStore};
use crate::store::{StoreError, SwapData, SwapStatus, SwapStore, SwapTx};
use crate::tx;
use crate::wallet::Wallet;
use async_trait::async_trait;
use grin_core::core::{Input, Output, OutputFeatures, Transaction, TransactionBody};
use grin_core::core::{Committed, Input, Output, OutputFeatures, Transaction, TransactionBody};
use grin_core::global::DEFAULT_ACCEPT_FEE_BASE;
use grin_onion::crypto::comsig::ComSignature;
use grin_onion::crypto::secp::{Commitment, Secp256k1, SecretKey};
@ -76,6 +76,11 @@ pub trait SwapServer: Send + Sync {
/// Iterate through all saved submissions, filter out any inputs that are no longer spendable,
/// and assemble the coinswap transaction, posting the transaction to the configured node.
async fn execute_round(&self) -> Result<Option<Arc<Transaction>>, SwapError>;
/// Verify the previous swap transaction is in the active chain or mempool.
/// If it's not, rebroacast the transaction if it's still valid.
/// If the transaction is no longer valid, perform the swap again.
async fn check_reorg(&self, tx: Arc<Transaction>) -> Result<Option<Arc<Transaction>>, SwapError>;
}
/// The standard MWixnet server implementation
@ -134,6 +139,97 @@ impl SwapServerImpl {
false
}
async fn async_execute_round(&self, store: &SwapStore, mut swaps: Vec<SwapData>) -> Result<Option<Arc<Transaction>>, SwapError> {
swaps.sort_by(|a, b| a.output_commit.partial_cmp(&b.output_commit).unwrap());
if swaps.len() == 0 {
return Ok(None);
}
let (filtered, failed, offset, outputs, kernels) = if let Some(client) = &self.next_server {
// Call next mix server
let onions = swaps.iter().map(|s| s.onion.clone()).collect();
let mixed = client
.mix_outputs(&onions)
.await
.map_err(|e| SwapError::ClientError(e.to_string()))?;
// Filter out failed entries
let kept_indices = HashSet::<_>::from_iter(mixed.indices.clone());
let filtered = swaps
.iter()
.enumerate()
.filter(|(i, _)| kept_indices.contains(i))
.map(|(_, j)| j.clone())
.collect();
let failed = swaps
.iter()
.enumerate()
.filter(|(i, _)| !kept_indices.contains(i))
.map(|(_, j)| j.clone())
.collect();
(
filtered,
failed,
mixed.components.offset,
mixed.components.outputs,
mixed.components.kernels,
)
} else {
// Build plain outputs for each swap entry
let outputs: Vec<Output> = swaps
.iter()
.map(|s| {
Output::new(
OutputFeatures::Plain,
s.output_commit,
s.rangeproof.unwrap(),
)
})
.collect();
(swaps, Vec::new(), ZERO_KEY, outputs, Vec::new())
};
let fees_paid: u64 = filtered.iter().map(|s| s.fee).sum();
let inputs: Vec<Input> = filtered.iter().map(|s| s.input).collect();
let output_excesses: Vec<SecretKey> = filtered.iter().map(|s| s.excess.clone()).collect();
let tx = tx::async_assemble_tx(
&self.wallet,
&inputs,
&outputs,
&kernels,
self.get_fee_base(),
fees_paid,
&offset,
&output_excesses,
)
.await?;
let chain_tip = self.node.async_get_chain_tip().await?;
self.node.async_post_tx(&tx).await?;
store.save_swap_tx(&SwapTx { tx: tx.clone(), chain_tip })?;
// Update status to in process
let kernel_commit = tx.kernels().first().unwrap().excess;
for mut swap in filtered {
swap.status = SwapStatus::InProcess { kernel_commit };
store.save_swap(&swap, true)?;
}
// Update status of failed swaps
for mut swap in failed {
swap.status = SwapStatus::Failed;
store.save_swap(&swap, true)?;
}
Ok(Some(Arc::new(tx)))
}
}
#[async_trait]
@ -212,7 +308,7 @@ impl SwapServer for SwapServerImpl {
}
async fn execute_round(&self) -> Result<Option<Arc<Transaction>>, SwapError> {
let next_block_height = self.node.async_get_chain_height().await? + 1;
let next_block_height = self.node.async_get_chain_tip().await?.0 + 1;
let locked_store = self.store.lock().await;
let swaps: Vec<SwapData> = locked_store
@ -226,91 +322,39 @@ impl SwapServer for SwapServerImpl {
}
}
spendable.sort_by(|a, b| a.output_commit.partial_cmp(&b.output_commit).unwrap());
if spendable.len() == 0 {
return Ok(None);
self.async_execute_round(&locked_store, swaps).await
}
let (filtered, failed, offset, outputs, kernels) = if let Some(client) = &self.next_server {
// Call next mix server
let onions = spendable.iter().map(|s| s.onion.clone()).collect();
let mixed = client
.mix_outputs(&onions)
.await
.map_err(|e| SwapError::ClientError(e.to_string()))?;
// Filter out failed entries
let kept_indices = HashSet::<_>::from_iter(mixed.indices.clone());
let filtered = spendable
.iter()
.enumerate()
.filter(|(i, _)| kept_indices.contains(i))
.map(|(_, j)| j.clone())
.collect();
let failed = spendable
.iter()
.enumerate()
.filter(|(i, _)| !kept_indices.contains(i))
.map(|(_, j)| j.clone())
.collect();
(
filtered,
failed,
mixed.components.offset,
mixed.components.outputs,
mixed.components.kernels,
)
} else {
// Build plain outputs for each swap entry
let outputs: Vec<Output> = spendable
.iter()
.map(|s| {
Output::new(
OutputFeatures::Plain,
s.output_commit,
s.rangeproof.unwrap(),
)
})
.collect();
(spendable, Vec::new(), ZERO_KEY, outputs, Vec::new())
};
let fees_paid: u64 = filtered.iter().map(|s| s.fee).sum();
let inputs: Vec<Input> = filtered.iter().map(|s| s.input).collect();
let output_excesses: Vec<SecretKey> = filtered.iter().map(|s| s.excess.clone()).collect();
let tx = tx::async_assemble_tx(
&self.wallet,
&inputs,
&outputs,
&kernels,
self.get_fee_base(),
fees_paid,
&offset,
&output_excesses,
)
.await?;
async fn check_reorg(&self, tx: Arc<Transaction>) -> Result<Option<Arc<Transaction>>, SwapError> {
let excess = tx.kernels().first().unwrap().excess;
if let Ok(swap_tx) = self.store.lock().await.get_swap_tx(&excess) {
// If kernel is in active chain, return tx
if self.node.async_get_kernel(&excess, Some(swap_tx.chain_tip.0), None).await?.is_some() {
return Ok(Some(tx));
}
// If transaction is still valid, rebroadcast and return tx
if node::async_is_tx_valid(&self.node, &tx).await? {
self.node.async_post_tx(&tx).await?;
// Update status to in process
let kernel_commit = tx.kernels().first().unwrap().excess;
for mut swap in filtered {
swap.status = SwapStatus::InProcess { kernel_commit };
locked_store.save_swap(&swap, true)?;
return Ok(Some(tx));
}
// Update status of failed swaps
for mut swap in failed {
swap.status = SwapStatus::Failed;
locked_store.save_swap(&swap, true)?;
// Collect all swaps based on tx's inputs, and execute_round with those swaps
let next_block_height = self.node.async_get_chain_tip().await?.0 + 1;
let locked_store = self.store.lock().await;
let mut swaps = Vec::new();
for input_commit in &tx.inputs_committed() {
if let Ok(swap) = locked_store.get_swap(&input_commit) {
if self.async_is_spendable(next_block_height, &swap).await {
swaps.push(swap);
}
}
}
Ok(Some(Arc::new(tx)))
self.async_execute_round(&locked_store, swaps).await
} else {
Err(SwapError::UnknownError("Swap transaction not found".to_string())) // TODO: Create SwapError enum value
}
}
}
@ -323,6 +367,7 @@ pub mod mock {
use grin_onion::crypto::comsig::ComSignature;
use grin_onion::onion::Onion;
use std::collections::HashMap;
use std::sync::Arc;
pub struct MockSwapServer {
errors: HashMap<Onion, SwapError>,
@ -353,6 +398,10 @@ pub mod mock {
async fn execute_round(&self) -> Result<Option<std::sync::Arc<Transaction>>, SwapError> {
Ok(None)
}
async fn check_reorg(&self, tx: Arc<Transaction>) -> Result<Option<Arc<Transaction>>, SwapError> {
Ok(Some(tx))
}
}
}
@ -842,7 +891,7 @@ mod tests {
assert_eq!(
Err(SwapError::FeeTooLow {
minimum_fee: 12_500_000,
actual_fee: fee as u64
actual_fee: fee as u64,
}),
result
);

View file

@ -3,7 +3,7 @@ use grin_onion::crypto::secp::{self, Commitment, RangeProof, SecretKey};
use grin_onion::onion::Onion;
use grin_onion::util::{read_optional, write_optional};
use grin_core::core::Input;
use grin_core::core::{Input, Transaction};
use grin_core::ser::{
self, DeserializationMode, ProtocolVersion, Readable, Reader, Writeable, Writer,
};
@ -14,9 +14,12 @@ use thiserror::Error;
const DB_NAME: &str = "swap";
const STORE_SUBPATH: &str = "swaps";
const CURRENT_VERSION: u8 = 0;
const CURRENT_SWAP_VERSION: u8 = 0;
const SWAP_PREFIX: u8 = b'S';
const CURRENT_TX_VERSION: u8 = 0;
const TX_PREFIX: u8 = b'T';
/// Swap statuses
#[derive(Clone, Debug, PartialEq)]
pub enum SwapStatus {
@ -104,7 +107,7 @@ pub struct SwapData {
impl Writeable for SwapData {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ser::Error> {
writer.write_u8(CURRENT_VERSION)?;
writer.write_u8(CURRENT_SWAP_VERSION)?;
writer.write_fixed_bytes(&self.excess)?;
writer.write_fixed_bytes(&self.output_commit)?;
write_optional(writer, &self.rangeproof)?;
@ -120,7 +123,7 @@ impl Writeable for SwapData {
impl Readable for SwapData {
fn read<R: Reader>(reader: &mut R) -> Result<SwapData, ser::Error> {
let version = reader.read_u8()?;
if version != CURRENT_VERSION {
if version != CURRENT_SWAP_VERSION {
return Err(ser::Error::UnsupportedProtocolVersion);
}
@ -143,6 +146,41 @@ impl Readable for SwapData {
}
}
/// A transaction created as part of a swap round.
#[derive(Clone, Debug, PartialEq)]
pub struct SwapTx {
pub tx: Transaction,
pub chain_tip: (u64, Hash),
// TODO: Include status
}
impl Writeable for SwapTx {
fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ser::Error> {
writer.write_u8(CURRENT_TX_VERSION)?;
self.tx.write(writer)?;
writer.write_u64(self.chain_tip.0)?;
self.chain_tip.1.write(writer)?;
Ok(())
}
}
impl Readable for SwapTx {
fn read<R: Reader>(reader: &mut R) -> Result<SwapTx, ser::Error> {
let version = reader.read_u8()?;
if version != CURRENT_TX_VERSION {
return Err(ser::Error::UnsupportedProtocolVersion);
}
let tx = Transaction::read(reader)?;
let height = reader.read_u64()?;
let block_hash = Hash::read(reader)?;
Ok(SwapTx {
tx,
chain_tip: (height, block_hash),
})
}
}
/// Storage facility for swap data.
pub struct SwapStore {
db: Store,
@ -245,6 +283,21 @@ impl SwapStore {
pub fn get_swap(&self, input_commit: &Commitment) -> Result<SwapData, StoreError> {
self.read(SWAP_PREFIX, input_commit)
}
/// Saves a swap transaction to the database
pub fn save_swap_tx(&self, s: &SwapTx) -> Result<(), StoreError> {
let data = ser::ser_vec(&s, ProtocolVersion::local())?;
self
.write(TX_PREFIX, &s.tx.kernels().first().unwrap().excess, &data, true)
.map_err(StoreError::WriteError)?;
Ok(())
}
/// Reads a swap tx from the database
pub fn get_swap_tx(&self, kernel_excess: &Commitment) -> Result<SwapTx, StoreError> {
self.read(TX_PREFIX, kernel_excess)
}
}
#[cfg(test)]