diff --git a/api/src/handlers.rs b/api/src/handlers.rs index 8e3bf499d..37c4b6487 100644 --- a/api/src/handlers.rs +++ b/api/src/handlers.rs @@ -879,22 +879,19 @@ pub fn build_router( let mut router = Router::new(); // example how we can use midlleware - router.add_route( - "/v1/", - Box::new(LoggingMiddleware::new(Box::new(index_handler))), - )?; - router.add_route("/v1/blocks/*", Box::new(block_handler))?; - router.add_route("/v1/headers/*", Box::new(header_handler))?; - router.add_route("/v1/chain", Box::new(chain_tip_handler))?; - router.add_route("/v1/chain/outputs/*", Box::new(output_handler))?; - router.add_route("/v1/chain/compact", Box::new(chain_compact_handler))?; - router.add_route("/v1/chain/validate", Box::new(chain_validation_handler))?; - router.add_route("/v1/txhashset/*", Box::new(txhashset_handler))?; - router.add_route("/v1/status", Box::new(status_handler))?; - router.add_route("/v1/pool", Box::new(pool_info_handler))?; - router.add_route("/v1/pool/push", Box::new(pool_push_handler))?; - router.add_route("/v1/peers/all", Box::new(peers_all_handler))?; - router.add_route("/v1/peers/connected", Box::new(peers_connected_handler))?; - router.add_route("/v1/peers/**", Box::new(peer_handler))?; + router.add_route("/v1/", Arc::new(index_handler))?; + router.add_route("/v1/blocks/*", Arc::new(block_handler))?; + router.add_route("/v1/headers/*", Arc::new(header_handler))?; + router.add_route("/v1/chain", Arc::new(chain_tip_handler))?; + router.add_route("/v1/chain/outputs/*", Arc::new(output_handler))?; + router.add_route("/v1/chain/compact", Arc::new(chain_compact_handler))?; + router.add_route("/v1/chain/validate", Arc::new(chain_validation_handler))?; + router.add_route("/v1/txhashset/*", Arc::new(txhashset_handler))?; + router.add_route("/v1/status", Arc::new(status_handler))?; + router.add_route("/v1/pool", Arc::new(pool_info_handler))?; + router.add_route("/v1/pool/push", Arc::new(pool_push_handler))?; + router.add_route("/v1/peers/all", Arc::new(peers_all_handler))?; + router.add_route("/v1/peers/connected", Arc::new(peers_connected_handler))?; + router.add_route("/v1/peers/**", Arc::new(peer_handler))?; Ok(router) } diff --git a/api/src/rest.rs b/api/src/rest.rs index 7633bf37f..75d4d6aa6 100644 --- a/api/src/rest.rs +++ b/api/src/rest.rs @@ -199,20 +199,15 @@ impl ApiServer { } } -// Simple example of middleware -pub struct LoggingMiddleware { - next: HandlerObj, -} - -impl LoggingMiddleware { - pub fn new(next: HandlerObj) -> LoggingMiddleware { - LoggingMiddleware { next } - } -} +pub struct LoggingMiddleware {} impl Handler for LoggingMiddleware { - fn call(&self, req: Request) -> ResponseFuture { + fn call( + &self, + req: Request, + mut handlers: Box>, + ) -> ResponseFuture { debug!(LOGGER, "REST call: {} {}", req.method(), req.uri().path()); - self.next.call(req) + handlers.next().unwrap().call(req, handlers) } } diff --git a/api/src/router.rs b/api/src/router.rs index 47555f0cf..edf1604ca 100644 --- a/api/src/router.rs +++ b/api/src/router.rs @@ -51,7 +51,11 @@ pub trait Handler { not_found() } - fn call(&self, req: Request) -> ResponseFuture { + fn call( + &self, + req: Request, + mut _handlers: Box>, + ) -> ResponseFuture { match req.method() { &Method::GET => self.get(req), &Method::POST => self.post(req), @@ -66,6 +70,7 @@ pub trait Handler { } } } + #[derive(Fail, Debug)] pub enum RouterError { #[fail(display = "Route already exists")] @@ -86,14 +91,15 @@ struct NodeId(usize); const MAX_CHILDREN: usize = 16; -pub type HandlerObj = Box; +pub type HandlerObj = Arc; #[derive(Clone)] -struct Node { +pub struct Node { key: u64, - value: Option>, + value: Option, children: [NodeId; MAX_CHILDREN], children_count: usize, + mws: Option>, } impl Router { @@ -104,6 +110,10 @@ impl Router { Router { nodes } } + pub fn add_middleware(&mut self, mw: HandlerObj) { + self.node_mut(NodeId(0)).add_middleware(mw); + } + fn root(&self) -> NodeId { NodeId(0) } @@ -134,7 +144,11 @@ impl Router { id } - pub fn add_route(&mut self, route: &'static str, value: HandlerObj) -> Result<(), RouterError> { + pub fn add_route( + &mut self, + route: &'static str, + value: HandlerObj, + ) -> Result<&mut Node, RouterError> { let keys = generate_path(route); let mut node_id = self.root(); for key in keys { @@ -144,23 +158,34 @@ impl Router { } match self.node(node_id).value() { None => { - self.node_mut(node_id).set_value(value); - Ok(()) + let node = self.node_mut(node_id); + node.set_value(value); + Ok(node) } Some(_) => Err(RouterError::RouteAlreadyExists), } } - pub fn get(&self, path: &str) -> Result, RouterError> { + pub fn get(&self, path: &str) -> Result, RouterError> { let keys = generate_path(path); + let mut handlers = vec![]; let mut node_id = self.root(); + collect_node_middleware(&mut handlers, self.node(node_id)); for key in keys { node_id = self.find(node_id, key).ok_or(RouterError::RouteNotFound)?; - if self.node(node_id).key == *WILDCARD_STOP_HASH { + let node = self.node(node_id); + collect_node_middleware(&mut handlers, self.node(node_id)); + if node.key == *WILDCARD_STOP_HASH { break; } } - self.node(node_id).value().ok_or(RouterError::NoValue) + + if let Some(h) = self.node(node_id).value() { + handlers.push(h); + Ok(handlers.into_iter()) + } else { + Err(RouterError::NoValue) + } } } @@ -173,7 +198,10 @@ impl Service for Router { fn call(&mut self, req: Request) -> Self::Future { match self.get(req.uri().path()) { Err(_) => not_found(), - Ok(h) => h.call(req), + Ok(mut handlers) => match handlers.next() { + None => not_found(), + Some(h) => h.call(req, Box::new(handlers)), + }, } } } @@ -191,16 +219,27 @@ impl NewService for Router { } impl Node { - fn new(key: u64, value: Option>) -> Node { + fn new(key: u64, value: Option) -> Node { Node { key, value, children: [NodeId(0); MAX_CHILDREN], children_count: 0, + mws: None, } } - fn value(&self) -> Option> { + pub fn add_middleware(&mut self, mw: HandlerObj) -> &mut Node { + if self.mws.is_none() { + self.mws = Some(vec![]); + } + if let Some(ref mut mws) = self.mws { + mws.push(mw.clone()); + } + self + } + + fn value(&self) -> Option { match &self.value { None => None, Some(v) => Some(v.clone()), @@ -208,7 +247,7 @@ impl Node { } fn set_value(&mut self, value: HandlerObj) { - self.value = Some(Arc::new(value)); + self.value = Some(value); } fn add_child(&mut self, child_id: NodeId) { @@ -240,6 +279,14 @@ fn generate_path(route: &str) -> Vec { .collect() } +fn collect_node_middleware(handlers: &mut Vec, node: &Node) { + if let Some(ref mws) = node.mws { + for mw in mws { + handlers.push(mw.clone()); + } + } +} + #[cfg(test)] mod tests { @@ -263,30 +310,17 @@ mod tests { #[test] fn test_add_route() { let mut routes = Router::new(); + let h1 = Arc::new(HandlerImpl(1)); + let h2 = Arc::new(HandlerImpl(2)); + let h3 = Arc::new(HandlerImpl(3)); + routes.add_route("/v1/users", h1.clone()).unwrap(); + assert!(routes.add_route("/v1/users", h2.clone()).is_err()); + routes.add_route("/v1/users/xxx", h3.clone()).unwrap(); + routes.add_route("/v1/users/xxx/yyy", h3.clone()).unwrap(); + routes.add_route("/v1/zzz/*", h3.clone()).unwrap(); + assert!(routes.add_route("/v1/zzz/ccc", h2.clone()).is_err()); routes - .add_route("/v1/users", Box::new(HandlerImpl(1))) - .unwrap(); - assert!( - routes - .add_route("/v1/users", Box::new(HandlerImpl(2))) - .is_err() - ); - routes - .add_route("/v1/users/xxx", Box::new(HandlerImpl(3))) - .unwrap(); - routes - .add_route("/v1/users/xxx/yyy", Box::new(HandlerImpl(3))) - .unwrap(); - routes - .add_route("/v1/zzz/*", Box::new(HandlerImpl(3))) - .unwrap(); - assert!( - routes - .add_route("/v1/zzz/ccc", Box::new(HandlerImpl(2))) - .is_err() - ); - routes - .add_route("/v1/zzz/*/zzz", Box::new(HandlerImpl(6))) + .add_route("/v1/zzz/*/zzz", Arc::new(HandlerImpl(6))) .unwrap(); } @@ -294,19 +328,19 @@ mod tests { fn test_get() { let mut routes = Router::new(); routes - .add_route("/v1/users", Box::new(HandlerImpl(101))) + .add_route("/v1/users", Arc::new(HandlerImpl(101))) .unwrap(); routes - .add_route("/v1/users/xxx", Box::new(HandlerImpl(103))) + .add_route("/v1/users/xxx", Arc::new(HandlerImpl(103))) .unwrap(); routes - .add_route("/v1/users/xxx/yyy", Box::new(HandlerImpl(103))) + .add_route("/v1/users/xxx/yyy", Arc::new(HandlerImpl(103))) .unwrap(); routes - .add_route("/v1/zzz/*", Box::new(HandlerImpl(103))) + .add_route("/v1/zzz/*", Arc::new(HandlerImpl(103))) .unwrap(); routes - .add_route("/v1/zzz/*/zzz", Box::new(HandlerImpl(106))) + .add_route("/v1/zzz/*/zzz", Arc::new(HandlerImpl(106))) .unwrap(); let call_handler = |url| { @@ -314,6 +348,8 @@ mod tests { let task = routes .get(url) .unwrap() + .next() + .unwrap() .get(Request::new(Body::default())) .and_then(|resp| ok(resp.status().as_u16())); event_loop.run(task).unwrap() diff --git a/api/tests/rest.rs b/api/tests/rest.rs index b6d9ca618..b3447d426 100644 --- a/api/tests/rest.rs +++ b/api/tests/rest.rs @@ -5,6 +5,8 @@ extern crate hyper; use api::*; use hyper::{Body, Request}; use std::net::SocketAddr; +use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}; +use std::sync::Arc; use std::{thread, time}; struct IndexHandler { @@ -19,13 +21,41 @@ impl Handler for IndexHandler { } } +pub struct CounterMiddleware { + counter: AtomicUsize, +} + +impl CounterMiddleware { + fn new() -> CounterMiddleware { + CounterMiddleware { + counter: ATOMIC_USIZE_INIT, + } + } + + fn value(&self) -> usize { + self.counter.load(Ordering::SeqCst) + } +} + +impl Handler for CounterMiddleware { + fn call( + &self, + req: Request, + mut handlers: Box>, + ) -> ResponseFuture { + self.counter.fetch_add(1, Ordering::SeqCst); + handlers.next().unwrap().call(req, handlers) + } +} + fn build_router() -> Router { let route_list = vec!["get blocks".to_string(), "get chain".to_string()]; let index_handler = IndexHandler { list: route_list }; let mut router = Router::new(); router - .add_route("/v1/*", Box::new(index_handler)) - .expect("add_route failed"); + .add_route("/v1/*", Arc::new(index_handler)) + .expect("add_route failed") + .add_middleware(Arc::new(LoggingMiddleware {})); router } @@ -33,13 +63,17 @@ fn build_router() -> Router { fn test_start_api() { util::init_test_logger(); let mut server = ApiServer::new(); - let router = build_router(); + let mut router = build_router(); + let counter = Arc::new(CounterMiddleware::new()); + // add middleware to the root + router.add_middleware(counter.clone()); let server_addr = "127.0.0.1:14434"; let addr: SocketAddr = server_addr.parse().expect("unable to parse server address"); assert!(server.start(addr, router)); let url = format!("http://{}/v1/", server_addr); let index = api::client::get::>(url.as_str()).unwrap(); assert_eq!(index.len(), 2); + assert_eq!(counter.value(), 1); assert!(server.stop()); thread::sleep(time::Duration::from_millis(1_000)); } diff --git a/wallet/src/libwallet/controller.rs b/wallet/src/libwallet/controller.rs index b5fa4c5e9..2ae2413f1 100644 --- a/wallet/src/libwallet/controller.rs +++ b/wallet/src/libwallet/controller.rs @@ -80,7 +80,7 @@ where let mut router = Router::new(); router - .add_route("/v1/wallet/owner/**", Box::new(api_handler)) + .add_route("/v1/wallet/owner/**", Arc::new(api_handler)) .map_err(|_| ErrorKind::GenericError("Router failed to add route".to_string()))?; let mut apis = ApiServer::new(); @@ -102,7 +102,7 @@ where let mut router = Router::new(); router - .add_route("/v1/wallet/foreign/**", Box::new(api_handler)) + .add_route("/v1/wallet/foreign/**", Arc::new(api_handler)) .map_err(|_| ErrorKind::GenericError("Router failed to add route".to_string()))?; let mut apis = ApiServer::new();