diff --git a/api/src/handlers.rs b/api/src/handlers.rs
index 8e3bf499d..37c4b6487 100644
--- a/api/src/handlers.rs
+++ b/api/src/handlers.rs
@@ -879,22 +879,19 @@ pub fn build_router(
let mut router = Router::new();
// example how we can use midlleware
- router.add_route(
- "/v1/",
- Box::new(LoggingMiddleware::new(Box::new(index_handler))),
- )?;
- router.add_route("/v1/blocks/*", Box::new(block_handler))?;
- router.add_route("/v1/headers/*", Box::new(header_handler))?;
- router.add_route("/v1/chain", Box::new(chain_tip_handler))?;
- router.add_route("/v1/chain/outputs/*", Box::new(output_handler))?;
- router.add_route("/v1/chain/compact", Box::new(chain_compact_handler))?;
- router.add_route("/v1/chain/validate", Box::new(chain_validation_handler))?;
- router.add_route("/v1/txhashset/*", Box::new(txhashset_handler))?;
- router.add_route("/v1/status", Box::new(status_handler))?;
- router.add_route("/v1/pool", Box::new(pool_info_handler))?;
- router.add_route("/v1/pool/push", Box::new(pool_push_handler))?;
- router.add_route("/v1/peers/all", Box::new(peers_all_handler))?;
- router.add_route("/v1/peers/connected", Box::new(peers_connected_handler))?;
- router.add_route("/v1/peers/**", Box::new(peer_handler))?;
+ router.add_route("/v1/", Arc::new(index_handler))?;
+ router.add_route("/v1/blocks/*", Arc::new(block_handler))?;
+ router.add_route("/v1/headers/*", Arc::new(header_handler))?;
+ router.add_route("/v1/chain", Arc::new(chain_tip_handler))?;
+ router.add_route("/v1/chain/outputs/*", Arc::new(output_handler))?;
+ router.add_route("/v1/chain/compact", Arc::new(chain_compact_handler))?;
+ router.add_route("/v1/chain/validate", Arc::new(chain_validation_handler))?;
+ router.add_route("/v1/txhashset/*", Arc::new(txhashset_handler))?;
+ router.add_route("/v1/status", Arc::new(status_handler))?;
+ router.add_route("/v1/pool", Arc::new(pool_info_handler))?;
+ router.add_route("/v1/pool/push", Arc::new(pool_push_handler))?;
+ router.add_route("/v1/peers/all", Arc::new(peers_all_handler))?;
+ router.add_route("/v1/peers/connected", Arc::new(peers_connected_handler))?;
+ router.add_route("/v1/peers/**", Arc::new(peer_handler))?;
Ok(router)
}
diff --git a/api/src/rest.rs b/api/src/rest.rs
index 7633bf37f..75d4d6aa6 100644
--- a/api/src/rest.rs
+++ b/api/src/rest.rs
@@ -199,20 +199,15 @@ impl ApiServer {
}
}
-// Simple example of middleware
-pub struct LoggingMiddleware {
- next: HandlerObj,
-}
-
-impl LoggingMiddleware {
- pub fn new(next: HandlerObj) -> LoggingMiddleware {
- LoggingMiddleware { next }
- }
-}
+pub struct LoggingMiddleware {}
impl Handler for LoggingMiddleware {
- fn call(&self, req: Request
) -> ResponseFuture {
+ fn call(
+ &self,
+ req: Request,
+ mut handlers: Box>,
+ ) -> ResponseFuture {
debug!(LOGGER, "REST call: {} {}", req.method(), req.uri().path());
- self.next.call(req)
+ handlers.next().unwrap().call(req, handlers)
}
}
diff --git a/api/src/router.rs b/api/src/router.rs
index 47555f0cf..edf1604ca 100644
--- a/api/src/router.rs
+++ b/api/src/router.rs
@@ -51,7 +51,11 @@ pub trait Handler {
not_found()
}
- fn call(&self, req: Request) -> ResponseFuture {
+ fn call(
+ &self,
+ req: Request,
+ mut _handlers: Box>,
+ ) -> ResponseFuture {
match req.method() {
&Method::GET => self.get(req),
&Method::POST => self.post(req),
@@ -66,6 +70,7 @@ pub trait Handler {
}
}
}
+
#[derive(Fail, Debug)]
pub enum RouterError {
#[fail(display = "Route already exists")]
@@ -86,14 +91,15 @@ struct NodeId(usize);
const MAX_CHILDREN: usize = 16;
-pub type HandlerObj = Box;
+pub type HandlerObj = Arc;
#[derive(Clone)]
-struct Node {
+pub struct Node {
key: u64,
- value: Option>,
+ value: Option,
children: [NodeId; MAX_CHILDREN],
children_count: usize,
+ mws: Option>,
}
impl Router {
@@ -104,6 +110,10 @@ impl Router {
Router { nodes }
}
+ pub fn add_middleware(&mut self, mw: HandlerObj) {
+ self.node_mut(NodeId(0)).add_middleware(mw);
+ }
+
fn root(&self) -> NodeId {
NodeId(0)
}
@@ -134,7 +144,11 @@ impl Router {
id
}
- pub fn add_route(&mut self, route: &'static str, value: HandlerObj) -> Result<(), RouterError> {
+ pub fn add_route(
+ &mut self,
+ route: &'static str,
+ value: HandlerObj,
+ ) -> Result<&mut Node, RouterError> {
let keys = generate_path(route);
let mut node_id = self.root();
for key in keys {
@@ -144,23 +158,34 @@ impl Router {
}
match self.node(node_id).value() {
None => {
- self.node_mut(node_id).set_value(value);
- Ok(())
+ let node = self.node_mut(node_id);
+ node.set_value(value);
+ Ok(node)
}
Some(_) => Err(RouterError::RouteAlreadyExists),
}
}
- pub fn get(&self, path: &str) -> Result, RouterError> {
+ pub fn get(&self, path: &str) -> Result, RouterError> {
let keys = generate_path(path);
+ let mut handlers = vec![];
let mut node_id = self.root();
+ collect_node_middleware(&mut handlers, self.node(node_id));
for key in keys {
node_id = self.find(node_id, key).ok_or(RouterError::RouteNotFound)?;
- if self.node(node_id).key == *WILDCARD_STOP_HASH {
+ let node = self.node(node_id);
+ collect_node_middleware(&mut handlers, self.node(node_id));
+ if node.key == *WILDCARD_STOP_HASH {
break;
}
}
- self.node(node_id).value().ok_or(RouterError::NoValue)
+
+ if let Some(h) = self.node(node_id).value() {
+ handlers.push(h);
+ Ok(handlers.into_iter())
+ } else {
+ Err(RouterError::NoValue)
+ }
}
}
@@ -173,7 +198,10 @@ impl Service for Router {
fn call(&mut self, req: Request) -> Self::Future {
match self.get(req.uri().path()) {
Err(_) => not_found(),
- Ok(h) => h.call(req),
+ Ok(mut handlers) => match handlers.next() {
+ None => not_found(),
+ Some(h) => h.call(req, Box::new(handlers)),
+ },
}
}
}
@@ -191,16 +219,27 @@ impl NewService for Router {
}
impl Node {
- fn new(key: u64, value: Option>) -> Node {
+ fn new(key: u64, value: Option) -> Node {
Node {
key,
value,
children: [NodeId(0); MAX_CHILDREN],
children_count: 0,
+ mws: None,
}
}
- fn value(&self) -> Option> {
+ pub fn add_middleware(&mut self, mw: HandlerObj) -> &mut Node {
+ if self.mws.is_none() {
+ self.mws = Some(vec![]);
+ }
+ if let Some(ref mut mws) = self.mws {
+ mws.push(mw.clone());
+ }
+ self
+ }
+
+ fn value(&self) -> Option {
match &self.value {
None => None,
Some(v) => Some(v.clone()),
@@ -208,7 +247,7 @@ impl Node {
}
fn set_value(&mut self, value: HandlerObj) {
- self.value = Some(Arc::new(value));
+ self.value = Some(value);
}
fn add_child(&mut self, child_id: NodeId) {
@@ -240,6 +279,14 @@ fn generate_path(route: &str) -> Vec {
.collect()
}
+fn collect_node_middleware(handlers: &mut Vec, node: &Node) {
+ if let Some(ref mws) = node.mws {
+ for mw in mws {
+ handlers.push(mw.clone());
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
@@ -263,30 +310,17 @@ mod tests {
#[test]
fn test_add_route() {
let mut routes = Router::new();
+ let h1 = Arc::new(HandlerImpl(1));
+ let h2 = Arc::new(HandlerImpl(2));
+ let h3 = Arc::new(HandlerImpl(3));
+ routes.add_route("/v1/users", h1.clone()).unwrap();
+ assert!(routes.add_route("/v1/users", h2.clone()).is_err());
+ routes.add_route("/v1/users/xxx", h3.clone()).unwrap();
+ routes.add_route("/v1/users/xxx/yyy", h3.clone()).unwrap();
+ routes.add_route("/v1/zzz/*", h3.clone()).unwrap();
+ assert!(routes.add_route("/v1/zzz/ccc", h2.clone()).is_err());
routes
- .add_route("/v1/users", Box::new(HandlerImpl(1)))
- .unwrap();
- assert!(
- routes
- .add_route("/v1/users", Box::new(HandlerImpl(2)))
- .is_err()
- );
- routes
- .add_route("/v1/users/xxx", Box::new(HandlerImpl(3)))
- .unwrap();
- routes
- .add_route("/v1/users/xxx/yyy", Box::new(HandlerImpl(3)))
- .unwrap();
- routes
- .add_route("/v1/zzz/*", Box::new(HandlerImpl(3)))
- .unwrap();
- assert!(
- routes
- .add_route("/v1/zzz/ccc", Box::new(HandlerImpl(2)))
- .is_err()
- );
- routes
- .add_route("/v1/zzz/*/zzz", Box::new(HandlerImpl(6)))
+ .add_route("/v1/zzz/*/zzz", Arc::new(HandlerImpl(6)))
.unwrap();
}
@@ -294,19 +328,19 @@ mod tests {
fn test_get() {
let mut routes = Router::new();
routes
- .add_route("/v1/users", Box::new(HandlerImpl(101)))
+ .add_route("/v1/users", Arc::new(HandlerImpl(101)))
.unwrap();
routes
- .add_route("/v1/users/xxx", Box::new(HandlerImpl(103)))
+ .add_route("/v1/users/xxx", Arc::new(HandlerImpl(103)))
.unwrap();
routes
- .add_route("/v1/users/xxx/yyy", Box::new(HandlerImpl(103)))
+ .add_route("/v1/users/xxx/yyy", Arc::new(HandlerImpl(103)))
.unwrap();
routes
- .add_route("/v1/zzz/*", Box::new(HandlerImpl(103)))
+ .add_route("/v1/zzz/*", Arc::new(HandlerImpl(103)))
.unwrap();
routes
- .add_route("/v1/zzz/*/zzz", Box::new(HandlerImpl(106)))
+ .add_route("/v1/zzz/*/zzz", Arc::new(HandlerImpl(106)))
.unwrap();
let call_handler = |url| {
@@ -314,6 +348,8 @@ mod tests {
let task = routes
.get(url)
.unwrap()
+ .next()
+ .unwrap()
.get(Request::new(Body::default()))
.and_then(|resp| ok(resp.status().as_u16()));
event_loop.run(task).unwrap()
diff --git a/api/tests/rest.rs b/api/tests/rest.rs
index b6d9ca618..b3447d426 100644
--- a/api/tests/rest.rs
+++ b/api/tests/rest.rs
@@ -5,6 +5,8 @@ extern crate hyper;
use api::*;
use hyper::{Body, Request};
use std::net::SocketAddr;
+use std::sync::atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT};
+use std::sync::Arc;
use std::{thread, time};
struct IndexHandler {
@@ -19,13 +21,41 @@ impl Handler for IndexHandler {
}
}
+pub struct CounterMiddleware {
+ counter: AtomicUsize,
+}
+
+impl CounterMiddleware {
+ fn new() -> CounterMiddleware {
+ CounterMiddleware {
+ counter: ATOMIC_USIZE_INIT,
+ }
+ }
+
+ fn value(&self) -> usize {
+ self.counter.load(Ordering::SeqCst)
+ }
+}
+
+impl Handler for CounterMiddleware {
+ fn call(
+ &self,
+ req: Request,
+ mut handlers: Box>,
+ ) -> ResponseFuture {
+ self.counter.fetch_add(1, Ordering::SeqCst);
+ handlers.next().unwrap().call(req, handlers)
+ }
+}
+
fn build_router() -> Router {
let route_list = vec!["get blocks".to_string(), "get chain".to_string()];
let index_handler = IndexHandler { list: route_list };
let mut router = Router::new();
router
- .add_route("/v1/*", Box::new(index_handler))
- .expect("add_route failed");
+ .add_route("/v1/*", Arc::new(index_handler))
+ .expect("add_route failed")
+ .add_middleware(Arc::new(LoggingMiddleware {}));
router
}
@@ -33,13 +63,17 @@ fn build_router() -> Router {
fn test_start_api() {
util::init_test_logger();
let mut server = ApiServer::new();
- let router = build_router();
+ let mut router = build_router();
+ let counter = Arc::new(CounterMiddleware::new());
+ // add middleware to the root
+ router.add_middleware(counter.clone());
let server_addr = "127.0.0.1:14434";
let addr: SocketAddr = server_addr.parse().expect("unable to parse server address");
assert!(server.start(addr, router));
let url = format!("http://{}/v1/", server_addr);
let index = api::client::get::>(url.as_str()).unwrap();
assert_eq!(index.len(), 2);
+ assert_eq!(counter.value(), 1);
assert!(server.stop());
thread::sleep(time::Duration::from_millis(1_000));
}
diff --git a/wallet/src/libwallet/controller.rs b/wallet/src/libwallet/controller.rs
index b5fa4c5e9..2ae2413f1 100644
--- a/wallet/src/libwallet/controller.rs
+++ b/wallet/src/libwallet/controller.rs
@@ -80,7 +80,7 @@ where
let mut router = Router::new();
router
- .add_route("/v1/wallet/owner/**", Box::new(api_handler))
+ .add_route("/v1/wallet/owner/**", Arc::new(api_handler))
.map_err(|_| ErrorKind::GenericError("Router failed to add route".to_string()))?;
let mut apis = ApiServer::new();
@@ -102,7 +102,7 @@ where
let mut router = Router::new();
router
- .add_route("/v1/wallet/foreign/**", Box::new(api_handler))
+ .add_route("/v1/wallet/foreign/**", Arc::new(api_handler))
.map_err(|_| ErrorKind::GenericError("Router failed to add route".to_string()))?;
let mut apis = ApiServer::new();