🚀. Socket Launch Week Day 3:Socket Firewall Now Blocks Malicious VS Code and Open VSX Extensions.Learn more
Sign In

rmcp

Package Overview
Dependencies
Maintainers
1
Versions
46
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

rmcp - cargo Package Compare versions

Comparing version
1.5.0
to
1.6.0
+69
src/transport/streamable_http_server/session/store.rs
use crate::model::InitializeRequestParams;
/// State persisted to an external store for cross-instance session recovery.
///
/// When a client reconnects to a different server instance, the new instance
/// loads this state to transparently replay the `initialize` handshake without
/// the client needing to re-initialize.
#[non_exhaustive]
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SessionState {
/// Parameters from the client's original `initialize` request.
pub initialize_params: InitializeRequestParams,
}
impl SessionState {
pub fn new(initialize_params: InitializeRequestParams) -> Self {
Self { initialize_params }
}
}
/// Type alias for boxed session store errors.
pub type SessionStoreError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// Pluggable external session store for cross-instance recovery.
///
/// Implement this trait to back sessions with Redis, a database, or any
/// key-value store. The simplest usage is to set
/// `StreamableHttpServerConfig::session_store` to an `Arc<impl SessionStore>`.
///
/// # Example (in-memory, for testing)
///
/// ```rust,ignore
/// use std::{collections::HashMap, sync::Arc};
/// use tokio::sync::RwLock;
/// use rmcp::transport::streamable_http_server::session::store::{
/// SessionState, SessionStore, SessionStoreError,
/// };
///
/// #[derive(Default)]
/// struct InMemoryStore(Arc<RwLock<HashMap<String, SessionState>>>);
///
/// #[async_trait::async_trait]
/// impl SessionStore for InMemoryStore {
/// async fn load(&self, id: &str) -> Result<Option<SessionState>, SessionStoreError> {
/// Ok(self.0.read().await.get(id).cloned())
/// }
/// async fn store(&self, id: &str, state: &SessionState) -> Result<(), SessionStoreError> {
/// self.0.write().await.insert(id.to_owned(), state.clone());
/// Ok(())
/// }
/// async fn delete(&self, id: &str) -> Result<(), SessionStoreError> {
/// self.0.write().await.remove(id);
/// Ok(())
/// }
/// }
/// ```
#[async_trait::async_trait]
pub trait SessionStore: Send + Sync + 'static {
/// Load session state for the given `session_id`.
///
/// Returns `Ok(None)` when no entry exists (i.e. session is unknown to the store).
async fn load(&self, session_id: &str) -> Result<Option<SessionState>, SessionStoreError>;
/// Persist session state for the given `session_id`.
async fn store(&self, session_id: &str, state: &SessionState) -> Result<(), SessionStoreError>;
/// Remove session state for the given `session_id`.
async fn delete(&self, session_id: &str) -> Result<(), SessionStoreError>;
}
#![cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))]
use std::time::Duration;
use rmcp::{
model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId},
transport::streamable_http_server::session::{SessionManager, local::LocalSessionManager},
};
#[tokio::test]
async fn test_init_timeout_terminates_pre_init_session() -> anyhow::Result<()> {
let mut manager = LocalSessionManager::default();
manager.session_config.init_timeout = Some(Duration::from_millis(200));
// Bind the transport so its drop-guard doesn't cancel the worker — we
// want termination via init_timeout, not via cancellation.
let (session_id, _transport) = manager.create_session().await?;
tokio::time::sleep(Duration::from_millis(500)).await;
let message = ClientJsonRpcMessage::request(
ClientRequest::PingRequest(PingRequest::default()),
RequestId::Number(1),
);
let result = manager.initialize_session(&session_id, message).await;
assert!(
result.is_err(),
"expected worker to be dead; got: {result:?}"
);
Ok(())
}
#[tokio::test]
async fn test_init_timeout_none_keeps_worker_alive() -> anyhow::Result<()> {
let mut manager = LocalSessionManager::default();
manager.session_config.init_timeout = None;
let (session_id, _transport) = manager.create_session().await?;
tokio::time::sleep(Duration::from_millis(500)).await;
let message = ClientJsonRpcMessage::request(
ClientRequest::PingRequest(PingRequest::default()),
RequestId::Number(1),
);
// Liveness probe: a live worker accepts the send then stalls waiting for
// a handler response (none is wired up), tripping the outer timeout. A
// dead worker would fail the send and return immediately.
let probe = tokio::time::timeout(
Duration::from_millis(200),
manager.initialize_session(&session_id, message),
)
.await;
assert!(
probe.is_err(),
"expected worker to be alive; got: {probe:?}"
);
Ok(())
}
#![cfg(all(
feature = "client",
feature = "server",
feature = "transport-streamable-http-client-reqwest",
feature = "transport-streamable-http-server",
not(feature = "local")
))]
use std::{collections::HashMap, sync::Arc};
use rmcp::{
ServiceExt,
transport::{
StreamableHttpClientTransport,
streamable_http_client::StreamableHttpClientTransportConfig,
streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService,
session::{SessionState, SessionStore, SessionStoreError, local::LocalSessionManager},
},
},
};
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
mod common;
use common::calculator::Calculator;
// ---------------------------------------------------------------------------
// Shared in-memory store used across tests
// ---------------------------------------------------------------------------
#[derive(Default, Clone)]
struct InMemorySessionStore(Arc<RwLock<HashMap<String, SessionState>>>);
impl InMemorySessionStore {
fn new() -> Self {
Self::default()
}
async fn len(&self) -> usize {
self.0.read().await.len()
}
}
#[async_trait::async_trait]
impl SessionStore for InMemorySessionStore {
async fn load(&self, session_id: &str) -> Result<Option<SessionState>, SessionStoreError> {
Ok(self.0.read().await.get(session_id).cloned())
}
async fn store(&self, session_id: &str, state: &SessionState) -> Result<(), SessionStoreError> {
self.0
.write()
.await
.insert(session_id.to_owned(), state.clone());
Ok(())
}
async fn delete(&self, session_id: &str) -> Result<(), SessionStoreError> {
self.0.write().await.remove(session_id);
Ok(())
}
}
// ---------------------------------------------------------------------------
// Helper: spin up a StreamableHttpService backed by the given store and
// return the bound address together with the cancellation token.
// ---------------------------------------------------------------------------
fn make_service(
session_store: Arc<dyn SessionStore>,
ct: &CancellationToken,
) -> StreamableHttpService<Calculator, LocalSessionManager> {
StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), {
let mut cfg = StreamableHttpServerConfig::default();
cfg.stateful_mode = true;
cfg.sse_keep_alive = None;
cfg.cancellation_token = ct.child_token();
cfg.session_store = Some(session_store);
cfg
})
}
// ---------------------------------------------------------------------------
// Test 1 — state is persisted to the store after a successful handshake
// ---------------------------------------------------------------------------
#[tokio::test]
async fn test_session_state_persisted_to_store() -> anyhow::Result<()> {
let store = Arc::new(InMemorySessionStore::new());
let ct = CancellationToken::new();
let service = make_service(store.clone(), &ct);
let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
// Connect a full client — this performs the initialize + initialized handshake.
let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
);
let client = ().serve(transport).await?;
// Make a real request so the session is fully active.
let _resources = client.list_all_resources().await?;
// The store should now contain exactly one session entry.
assert_eq!(
store.len().await,
1,
"session state should be persisted to the store after initialization"
);
// Verify the stored state contains the expected client info.
let entries = store.0.read().await;
let state = entries.values().next().expect("store entry should exist");
assert_eq!(
state.initialize_params.client_info.name, "rmcp",
"stored client_info.name should match the rmcp client"
);
let _ = client.cancel().await;
ct.cancel();
handle.await?;
Ok(())
}
// ---------------------------------------------------------------------------
// Test 2 — store entry is removed when the client sends HTTP DELETE
// ---------------------------------------------------------------------------
#[tokio::test]
async fn test_session_state_deleted_from_store_on_delete() -> anyhow::Result<()> {
let store = Arc::new(InMemorySessionStore::new());
let session_manager = Arc::new(LocalSessionManager::default());
let ct = CancellationToken::new();
let service = StreamableHttpService::new(|| Ok(Calculator::new()), session_manager.clone(), {
let mut cfg = StreamableHttpServerConfig::default();
cfg.stateful_mode = true;
cfg.sse_keep_alive = None;
cfg.cancellation_token = ct.child_token();
cfg.session_store = Some(store.clone());
cfg
});
let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;
let handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
);
let client = ().serve(transport).await?;
let _resources = client.list_all_resources().await?;
assert_eq!(store.len().await, 1, "store should have one entry");
// Get the session ID from the server's in-memory map.
let session_id = {
let sessions = session_manager.sessions.read().await;
sessions
.keys()
.next()
.cloned()
.expect("session should exist")
};
// Send an explicit HTTP DELETE — this is the signal to remove from store.
let http_client = reqwest::Client::new();
let response = http_client
.delete(format!("http://{addr}/mcp"))
.header("mcp-session-id", session_id.as_ref())
.send()
.await?;
assert_eq!(response.status(), 202);
assert_eq!(
store.len().await,
0,
"store entry should be removed after explicit DELETE"
);
let _ = client.cancel().await;
ct.cancel();
handle.await?;
Ok(())
}
// ---------------------------------------------------------------------------
// Helper: spin up a server on an ephemeral port and return its address and
// the join handle. The server shuts down when `ct` is cancelled.
// ---------------------------------------------------------------------------
fn spawn_server(
session_store: Option<Arc<dyn SessionStore>>,
session_manager: Arc<LocalSessionManager>,
ct: &CancellationToken,
) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) {
let svc = StreamableHttpService::new(|| Ok(Calculator::new()), session_manager, {
let mut cfg = StreamableHttpServerConfig::default();
cfg.stateful_mode = true;
cfg.sse_keep_alive = None;
cfg.cancellation_token = ct.child_token();
cfg.session_store = session_store;
cfg
});
// Use std::net::TcpListener so the port is bound synchronously before
// we return — avoids a race between returning the addr and the server
// actually starting to accept connections.
let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
std_listener.set_nonblocking(true).unwrap();
let addr = std_listener.local_addr().unwrap();
let listener = tokio::net::TcpListener::from_std(std_listener).unwrap();
let router = axum::Router::new().nest_service("/mcp", svc);
let handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});
(addr, handle)
}
// ---------------------------------------------------------------------------
// Test 3 — cross-instance session restore
//
// Both halves follow the same structure:
//
// Instance A initializes the session (session state may be saved to store)
// Instance A is fully shut down
// Instance B (fresh, no in-memory state) receives a request for the old ID
//
// Without a store → 404. With a shared store → transparent restore.
// ---------------------------------------------------------------------------
#[tokio::test]
async fn test_cross_instance_session_restore() -> anyhow::Result<()> {
let http = reqwest::Client::new();
// -----------------------------------------------------------------------
// Negative check: no session store → instance B returns 404.
// -----------------------------------------------------------------------
{
// --- Instance A (no store): initialize ---
let ct_a = CancellationToken::new();
let (addr_a, srv_a) = spawn_server(None, Arc::new(LocalSessionManager::default()), &ct_a);
let init_resp = http
.post(format!("http://{addr_a}/mcp"))
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}"#)
.send()
.await?;
assert_eq!(
init_resp.status(),
200,
"instance A: initialize should succeed"
);
let session_id = init_resp
.headers()
.get("mcp-session-id")
.expect("session ID header must be present")
.to_str()?
.to_owned();
// Shut down instance A completely.
ct_a.cancel();
srv_a.await?;
// --- Instance B (no store, fresh state): send request ---
let ct_b = CancellationToken::new();
let (addr_b, srv_b) = spawn_server(None, Arc::new(LocalSessionManager::default()), &ct_b);
let resp = http
.post(format!("http://{addr_b}/mcp"))
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.header("mcp-session-id", &session_id)
.body(r#"{"jsonrpc":"2.0","id":2,"method":"ping","params":{}}"#)
.send()
.await?;
assert_eq!(
resp.status(),
reqwest::StatusCode::NOT_FOUND,
"without a session store, instance B must return 404 for an unknown session ID"
);
ct_b.cancel();
srv_b.await?;
}
// -----------------------------------------------------------------------
// Positive check: shared session store → instance B restores transparently.
// -----------------------------------------------------------------------
{
let store: Arc<dyn SessionStore> = Arc::new(InMemorySessionStore::new());
// --- Instance A (with store): initialize ---
let ct_a = CancellationToken::new();
let sm_a = Arc::new(LocalSessionManager::default());
let (addr_a, srv_a) = spawn_server(Some(store.clone()), sm_a.clone(), &ct_a);
let init_resp = http
.post(format!("http://{addr_a}/mcp"))
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}"#)
.send()
.await?;
assert_eq!(
init_resp.status(),
200,
"instance A: initialize should succeed"
);
let original_session_id = init_resp
.headers()
.get("mcp-session-id")
.expect("session ID header must be present")
.to_str()?
.to_owned();
// Confirm the session was persisted.
let store_ref = store
.load(&original_session_id)
.await
.expect("store load should not error");
assert!(
store_ref.is_some(),
"store should hold the session after initialization"
);
// Shut down instance A completely — session lives only in the store now.
ct_a.cancel();
srv_a.await?;
// --- Instance B (same store, fresh in-memory state): send request ---
let ct_b = CancellationToken::new();
let sm_b = Arc::new(LocalSessionManager::default());
let (addr_b, srv_b) = spawn_server(Some(store.clone()), sm_b.clone(), &ct_b);
let resp = http
.post(format!("http://{addr_b}/mcp"))
.header("accept", "application/json, text/event-stream")
.header("content-type", "application/json")
.header("mcp-session-id", &original_session_id)
.body(r#"{"jsonrpc":"2.0","id":2,"method":"ping","params":{}}"#)
.send()
.await?;
assert_eq!(
resp.status(),
200,
"instance B: request must succeed after transparent restore"
);
// The session must be in instance B's memory under the ORIGINAL ID.
{
let sessions = sm_b.sessions.read().await;
let restored_id = sessions
.keys()
.next()
.expect("session should exist in instance B after restore");
assert_eq!(
restored_id.as_ref(),
original_session_id.as_str(),
"restored session must keep the original session ID"
);
}
ct_b.cancel();
srv_b.await?;
}
Ok(())
}
//! Integration tests for tool list change notifications.
#![cfg(all(feature = "client", not(feature = "local")))]
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use rmcp::{
ClientHandler, RoleClient, RoleServer, ServerHandler, ServiceExt,
handler::server::{router::tool::ToolRoute, tool::ToolCallContext},
model::{CallToolResult, ServerCapabilities, ServerInfo, Tool},
service::{MaybeSendFuture, NotificationContext},
};
use tokio::sync::{Notify, RwLock};
#[derive(Clone)]
struct TestToolServer {
router: Arc<RwLock<rmcp::handler::server::router::tool::ToolRouter<Self>>>,
trigger_disable: Arc<Notify>,
trigger_enable: Arc<Notify>,
}
impl TestToolServer {
fn new() -> Self {
let mut tool_router = rmcp::handler::server::router::tool::ToolRouter::<Self>::new();
tool_router.add_route(ToolRoute::new_dyn(
Tool::new("tool_a", "Tool A", Arc::new(Default::default())),
|_ctx| Box::pin(async { Ok(CallToolResult::default()) }),
));
tool_router.add_route(ToolRoute::new_dyn(
Tool::new("tool_b", "Tool B", Arc::new(Default::default())),
|_ctx| Box::pin(async { Ok(CallToolResult::default()) }),
));
Self {
router: Arc::new(RwLock::new(tool_router)),
trigger_disable: Arc::new(Notify::new()),
trigger_enable: Arc::new(Notify::new()),
}
}
}
impl ServerHandler for TestToolServer {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
}
fn call_tool(
&self,
request: rmcp::model::CallToolRequestParams,
context: rmcp::service::RequestContext<RoleServer>,
) -> impl std::future::Future<Output = Result<CallToolResult, rmcp::ErrorData>> + MaybeSendFuture + '_
{
async move {
let router = self.router.read().await;
let tcc = ToolCallContext::new(self, request, context);
router.call(tcc).await
}
}
fn list_tools(
&self,
_request: Option<rmcp::model::PaginatedRequestParams>,
_context: rmcp::service::RequestContext<RoleServer>,
) -> impl std::future::Future<Output = Result<rmcp::model::ListToolsResult, rmcp::ErrorData>>
+ MaybeSendFuture
+ '_ {
async move {
let router = self.router.read().await;
Ok(rmcp::model::ListToolsResult {
tools: router.list_all(),
..Default::default()
})
}
}
fn on_initialized(
&self,
context: NotificationContext<RoleServer>,
) -> impl std::future::Future<Output = ()> + MaybeSendFuture + '_ {
let router = self.router.clone();
let trigger_disable = self.trigger_disable.clone();
let trigger_enable = self.trigger_enable.clone();
let peer = context.peer.clone();
async move {
router.write().await.bind_peer_notifier(&peer);
let router = router.clone();
tokio::spawn(async move {
trigger_disable.notified().await;
{
let mut r = router.write().await;
r.disable_route("tool_a");
}
trigger_enable.notified().await;
{
let mut r = router.write().await;
r.enable_route("tool_a");
}
});
}
}
}
#[derive(Clone)]
struct TestToolClient {
notification_count: Arc<AtomicUsize>,
notify: Arc<Notify>,
}
impl TestToolClient {
fn new() -> Self {
Self {
notification_count: Arc::new(AtomicUsize::new(0)),
notify: Arc::new(Notify::new()),
}
}
}
impl ClientHandler for TestToolClient {
fn on_tool_list_changed(
&self,
_context: NotificationContext<RoleClient>,
) -> impl std::future::Future<Output = ()> + MaybeSendFuture + '_ {
self.notification_count.fetch_add(1, Ordering::SeqCst);
self.notify.notify_one();
std::future::ready(())
}
}
#[tokio::test]
async fn test_disable_enable_sends_tool_list_changed() {
let server = TestToolServer::new();
let trigger_disable = server.trigger_disable.clone();
let trigger_enable = server.trigger_enable.clone();
let client = TestToolClient::new();
let notification_count = client.notification_count.clone();
let client_notify = client.notify.clone();
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server_handle = tokio::spawn(async move { server.serve(server_transport).await });
let client_service = client.serve(client_transport).await.unwrap();
let tools = client_service.peer().list_tools(None).await.unwrap();
assert_eq!(tools.tools.len(), 2);
trigger_disable.notify_one();
tokio::time::timeout(std::time::Duration::from_secs(5), client_notify.notified())
.await
.expect("timed out waiting for tool_list_changed");
assert_eq!(notification_count.load(Ordering::SeqCst), 1);
let tools = client_service.peer().list_tools(None).await.unwrap();
assert_eq!(tools.tools.len(), 1);
assert_eq!(tools.tools[0].name, "tool_b");
trigger_enable.notify_one();
tokio::time::timeout(std::time::Duration::from_secs(5), client_notify.notified())
.await
.expect("timed out waiting for tool_list_changed");
assert_eq!(notification_count.load(Ordering::SeqCst), 2);
let tools = client_service.peer().list_tools(None).await.unwrap();
assert_eq!(tools.tools.len(), 2);
client_service.cancel().await.unwrap();
server_handle.abort();
}
+1
-1
{
"git": {
"sha1": "020a38b6ad3d0f26487c464250a484fad2a06b0e"
"sha1": "014fb2e6cd9faddbe86ae30b5cc9adf84a62edb9"
},
"path_in_vcs": "crates/rmcp"
}

@@ -15,3 +15,3 @@ # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO

name = "rmcp"
version = "1.5.0"
version = "1.6.0"
build = "build.rs"

@@ -354,2 +354,6 @@ autolib = false

[[test]]
name = "test_streamable_http_init_timeout"
path = "tests/test_streamable_http_init_timeout.rs"
[[test]]
name = "test_streamable_http_json_response"

@@ -375,2 +379,12 @@ path = "tests/test_streamable_http_json_response.rs"

[[test]]
name = "test_streamable_http_session_store"
path = "tests/test_streamable_http_session_store.rs"
required-features = [
"client",
"server",
"transport-streamable-http-client-reqwest",
"transport-streamable-http-server",
]
[[test]]
name = "test_streamable_http_stale_session"

@@ -408,2 +422,6 @@ path = "tests/test_streamable_http_stale_session.rs"

[[test]]
name = "test_tool_disable_notification"
path = "tests/test_tool_disable_notification.rs"
[[test]]
name = "test_tool_handler"

@@ -537,3 +555,3 @@ path = "tests/test_tool_handler.rs"

[dependencies.rmcp-macros]
version = "1.5.0"
version = "1.6.0"
optional = true

@@ -540,0 +558,0 @@

@@ -10,2 +10,21 @@ # Changelog

## [1.6.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v1.5.0...rmcp-v1.6.0) - 2026-05-01
### Added
- *(http)* log Host/Origin rejections ([#826](https://github.com/modelcontextprotocol/rust-sdk/pull/826))
- *(http)* add Origin header validation ([#823](https://github.com/modelcontextprotocol/rust-sdk/pull/823))
- *(router)* support runtime disabling of tools ([#809](https://github.com/modelcontextprotocol/rust-sdk/pull/809))
- optional session store (resumabillity support) ([#775](https://github.com/modelcontextprotocol/rust-sdk/pull/775))
### Fixed
- add init_timeout for streamable-http sessions ([#811](https://github.com/modelcontextprotocol/rust-sdk/pull/811))
- *(http)* fall back to :authority for HTTP/2 ([#827](https://github.com/modelcontextprotocol/rust-sdk/pull/827))
- *(docs)* use correct Parameters<T> syntax in tool examples ([#814](https://github.com/modelcontextprotocol/rust-sdk/pull/814))
### Other
- add systemprompt-template to Built with rmcp ([#820](https://github.com/modelcontextprotocol/rust-sdk/pull/820))
## [1.5.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v1.4.0...rmcp-v1.5.0) - 2026-04-16

@@ -12,0 +31,0 @@

@@ -9,3 +9,3 @@ use std::sync::Arc;

RoleServer, Service,
model::{ClientRequest, ListPromptsResult, ListToolsResult, ServerResult},
model::{ClientNotification, ClientRequest, ListPromptsResult, ListToolsResult, ServerResult},
service::NotificationContext,

@@ -22,2 +22,3 @@ };

pub service: Arc<S>,
peer_slot: Arc<std::sync::OnceLock<crate::service::Peer<RoleServer>>>,
}

@@ -30,6 +31,10 @@

pub fn new(service: S) -> Self {
let (notifier, peer_slot) = tool::ToolRouter::<S>::deferred_peer_notifier();
let mut tool_router = tool::ToolRouter::new();
tool_router.set_notifier(notifier);
Self {
tool_router: tool::ToolRouter::new(),
tool_router,
prompt_router: prompt::PromptRouter::new(),
service: Arc::new(service),
peer_slot,
}

@@ -78,2 +83,8 @@ }

) -> Result<(), crate::ErrorData> {
if matches!(
&notification,
ClientNotification::InitializedNotification(_)
) {
let _ = self.peer_slot.set(context.peer.clone());
}
self.service

@@ -90,3 +101,6 @@ .handle_notification(notification, context)

ClientRequest::CallToolRequest(request) => {
if self.tool_router.has_route(request.params.name.as_ref())
if self
.tool_router
.map
.contains_key(request.params.name.as_ref())
|| !self.tool_router.transparent_when_not_found

@@ -142,4 +156,79 @@ {

fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info {
ServerHandler::get_info(&self.service)
let mut info = ServerHandler::get_info(&self.service);
info.capabilities
.tools
.get_or_insert_with(Default::default)
.list_changed = Some(true);
info
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::{
model::{CallToolResult, ClientNotification, ServerNotification, Tool},
service::{AtomicU32RequestIdProvider, Peer, PeerSinkMessage, RequestIdProvider},
};
struct DummyHandler;
impl ServerHandler for DummyHandler {}
async fn recv_notification(
rx: &mut tokio::sync::mpsc::Receiver<PeerSinkMessage<RoleServer>>,
) -> ServerNotification {
let msg = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
.await
.expect("timed out")
.expect("channel closed");
match msg {
PeerSinkMessage::Notification {
notification,
responder,
} => {
let _ = responder.send(Ok(()));
notification
}
other => panic!("expected notification, got {other:?}"),
}
}
#[tokio::test]
async fn test_router_deferred_notifier_e2e() {
let mut router = Router::new(DummyHandler).with_tool(tool::ToolRoute::new_dyn(
Tool::new("my_tool", "test", Arc::new(Default::default())),
|_ctx| Box::pin(async { Ok(CallToolResult::default()) }),
));
let id_provider: Arc<dyn RequestIdProvider> =
Arc::new(AtomicU32RequestIdProvider::default());
let (peer, mut rx) = Peer::<RoleServer>::new(id_provider, None);
let context = crate::service::NotificationContext {
peer: peer.clone(),
meta: Default::default(),
extensions: Default::default(),
};
router
.handle_notification(
ClientNotification::InitializedNotification(Default::default()),
context,
)
.await
.unwrap();
router.tool_router.disable_route("my_tool");
assert!(matches!(
recv_notification(&mut rx).await,
ServerNotification::ToolListChangedNotification(_)
));
router.tool_router.enable_route("my_tool");
assert!(matches!(
recv_notification(&mut rx).await,
ServerNotification::ToolListChangedNotification(_)
));
}
}

@@ -301,3 +301,2 @@ //! Tools for MCP servers.

}
#[derive(Debug)]
#[non_exhaustive]

@@ -309,4 +308,22 @@ pub struct ToolRouter<S> {

pub transparent_when_not_found: bool,
disabled: std::collections::HashSet<Cow<'static, str>>,
notifier: Option<Arc<dyn Fn() + Send + Sync>>,
}
impl<S> std::fmt::Debug for ToolRouter<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ToolRouter")
.field("map", &self.map)
.field(
"transparent_when_not_found",
&self.transparent_when_not_found,
)
.field("disabled", &self.disabled)
.field("notifier", &self.notifier.as_ref().map(|_| "..."))
.finish()
}
}
impl<S> Default for ToolRouter<S> {

@@ -317,5 +334,8 @@ fn default() -> Self {

transparent_when_not_found: false,
disabled: std::collections::HashSet::new(),
notifier: None,
}
}
}
impl<S> Clone for ToolRouter<S> {

@@ -326,2 +346,4 @@ fn clone(&self) -> Self {

transparent_when_not_found: self.transparent_when_not_found,
disabled: self.disabled.clone(),
notifier: self.notifier.clone(),
}

@@ -336,3 +358,7 @@ }

fn into_iter(self) -> Self::IntoIter {
self.map.into_values()
let mut map = self.map;
for name in &self.disabled {
map.remove(name);
}
map.into_values()
}

@@ -346,6 +372,3 @@ }

pub fn new() -> Self {
Self {
map: std::collections::HashMap::new(),
transparent_when_not_found: false,
}
Self::default()
}

@@ -403,2 +426,3 @@ pub fn with_route<R, A>(mut self, route: R) -> Self

pub fn merge(&mut self, other: ToolRouter<S>) {
self.disabled.extend(other.disabled);
for item in other.map.into_values() {

@@ -409,8 +433,113 @@ self.add_route(item);

/// Remove a tool route from the router.
///
/// The disabled state is **preserved**: if the name was in the disabled
/// set, it stays there so that a future [`add_route`](Self::add_route)
/// or [`merge`](Self::merge) with the same name will inherit the
/// disabled state. To also clear the disabled marker, call
/// [`enable_route`](Self::enable_route) afterwards.
pub fn remove_route(&mut self, name: &str) {
self.map.remove(name);
}
/// Returns `true` if the tool is registered **and** not currently
/// disabled.
pub fn has_route(&self, name: &str) -> bool {
self.map.contains_key(name)
self.map.contains_key(name) && !self.disabled.contains(name)
}
/// Disable a tool by name. Hidden from `list_all`, `get`, rejected by
/// `call`. Re-enable with [`enable_route`](Self::enable_route).
///
/// Returns `true` if the name was newly added to the disabled set.
/// The name is recorded even if no matching route exists yet, so routes
/// added later will inherit the disabled state.
pub fn disable_route(&mut self, name: impl Into<Cow<'static, str>>) -> bool {
let name = name.into();
let was_visible = self.map.contains_key(&name) && !self.disabled.contains(&name);
if was_visible {
self.notify_if_visible(&name);
}
self.disabled.insert(name)
}
/// Re-enable a previously disabled tool. Returns `true` if the name
/// was in the disabled set.
pub fn enable_route(&mut self, name: &str) -> bool {
let removed = self.disabled.remove(name);
if removed {
self.notify_if_visible(name);
}
removed
}
/// Returns `true` if the tool exists in the router **and** is currently
/// disabled. Returns `false` if the tool does not exist or if the name
/// was pre-disabled without a matching route.
pub fn is_disabled(&self, name: &str) -> bool {
self.map.contains_key(name) && self.disabled.contains(name)
}
/// Builder-style variant of [`disable_route`](Self::disable_route).
///
/// The name is recorded even if no matching route has been added yet,
/// so it can be called before [`with_route`](Self::with_route) in a
/// builder chain.
pub fn with_disabled(mut self, name: impl Into<Cow<'static, str>>) -> Self {
self.disabled.insert(name.into());
self
}
/// Install a callback invoked when the visible tool list changes.
pub fn set_notifier(&mut self, f: impl Fn() + Send + Sync + 'static) {
self.notifier = Some(Arc::new(f));
}
pub fn clear_notifier(&mut self) {
self.notifier = None;
}
/// Install a notifier that sends `notifications/tools/list_changed`
/// via the given peer.
pub fn bind_peer_notifier(&mut self, peer: &crate::service::Peer<crate::RoleServer>) {
let peer = peer.clone();
self.set_notifier(move || {
let peer = peer.clone();
tokio::spawn(async move {
if let Err(e) = peer.notify_tool_list_changed().await {
tracing::warn!("failed to send tools/list_changed notification: {e}");
}
});
});
}
/// Deferred notifier: no-op until the peer slot is filled.
pub(crate) fn deferred_peer_notifier() -> (
impl Fn() + Send + Sync + 'static,
Arc<std::sync::OnceLock<crate::service::Peer<crate::RoleServer>>>,
) {
let peer_slot =
Arc::new(std::sync::OnceLock::<crate::service::Peer<crate::RoleServer>>::new());
let slot_clone = peer_slot.clone();
let notifier = move || {
if let Some(peer) = slot_clone.get() {
let peer = peer.clone();
tokio::spawn(async move {
if let Err(e) = peer.notify_tool_list_changed().await {
tracing::warn!("failed to send tools/list_changed notification: {e}");
}
});
}
};
(notifier, peer_slot)
}
fn notify_if_visible(&self, name: &str) {
if self.map.contains_key(name) {
if let Some(notifier) = &self.notifier {
notifier();
}
}
}
pub async fn call(

@@ -420,5 +549,9 @@ &self,

) -> Result<CallToolResult, crate::ErrorData> {
let name = context.name();
if self.disabled.contains(name) {
return Err(crate::ErrorData::invalid_params("tool not found", None));
}
let item = self
.map
.get(context.name())
.get(name)
.ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?;

@@ -432,3 +565,8 @@

pub fn list_all(&self) -> Vec<crate::model::Tool> {
let mut tools: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect();
let mut tools: Vec<_> = self
.map
.values()
.filter(|item| !self.disabled.contains(&item.attr.name))
.map(|item| item.attr.clone())
.collect();
tools.sort_by(|a, b| a.name.cmp(&b.name));

@@ -440,4 +578,8 @@ tools

///
/// Returns the tool if found, or `None` if no tool with the given name exists.
/// Returns the tool if found and enabled, or `None` if the tool does not
/// exist or is disabled.
pub fn get(&self, name: &str) -> Option<&crate::model::Tool> {
if self.disabled.contains(name) {
return None;
}
self.map.get(name).map(|r| &r.attr)

@@ -467,1 +609,47 @@ }

}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::{
RoleServer,
model::{CallToolRequestParams, ErrorCode, NumberOrString},
service::{AtomicU32RequestIdProvider, Peer, RequestContext},
};
struct DummyService;
impl crate::handler::server::ServerHandler for DummyService {}
#[tokio::test]
async fn test_call_disabled_tool_returns_error() {
let service = DummyService;
let mut router = ToolRouter::new().with_route(ToolRoute::new_dyn(
crate::model::Tool::new("test_tool", "a test tool", Arc::new(Default::default())),
|_ctx| Box::pin(async { Ok(CallToolResult::default()) }),
));
router.disable_route("test_tool");
let id_provider: Arc<dyn crate::service::RequestIdProvider> =
Arc::new(AtomicU32RequestIdProvider::default());
let (peer, _rx) = Peer::<RoleServer>::new(id_provider, None);
let ctx = crate::handler::server::tool::ToolCallContext::new(
&service,
CallToolRequestParams {
meta: None,
name: Cow::Borrowed("test_tool"),
arguments: None,
task: None,
},
RequestContext::new(NumberOrString::Number(1), peer),
);
let err = router
.call(ctx)
.await
.expect_err("disabled tool should reject");
assert_eq!(err.code, ErrorCode::INVALID_PARAMS);
assert_eq!(err.message, "tool not found");
}
}
pub mod session;
#[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))]
pub mod tower;
pub use session::{SessionId, SessionManager};
pub use session::{RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker};
#[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))]
pub use tower::{StreamableHttpServerConfig, StreamableHttpService};

@@ -33,3 +33,38 @@ //! Session management for the Streamable HTTP transport.

pub mod never;
pub mod store;
pub use store::{SessionState, SessionStore, SessionStoreError};
/// Extension marker inserted into the `initialize` request extensions during a
/// session restore replay. Handlers can check for its presence to distinguish a
/// cross-instance restore from a genuine client-initiated `initialize` request.
///
/// ```rust,ignore
/// if req.extensions().get::<SessionRestoreMarker>().is_some() {
/// // this is a restore replay, not a fresh client connection
/// }
/// ```
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SessionRestoreMarker {
pub id: SessionId,
}
/// The outcome of a [`SessionManager::restore_session`] call.
#[non_exhaustive]
#[derive(Debug)]
pub enum RestoreOutcome<T> {
/// The session was just re-created from external state; the caller must
/// spawn an MCP handler against the returned transport and replay the
/// `initialize` handshake.
Restored(T),
/// The session was already present in memory (e.g. a concurrent request
/// already restored it). The caller should proceed as if `has_session`
/// had returned `true` — no further action is required.
AlreadyPresent,
/// This session manager does not support external-store restore.
/// The caller should fall through to the normal 404 response.
NotSupported,
}
/// Controls how MCP sessions are created, validated, and closed.

@@ -102,2 +137,20 @@ ///

> + Send;
/// Attempt to restore a previously-known session from external state,
/// creating a fresh in-memory session worker with the given `id`.
///
/// See [`RestoreOutcome`] for the three possible results:
/// - [`RestoreOutcome::Restored`] — session re-created; caller must spawn
/// an MCP handler and replay the `initialize` handshake.
/// - [`RestoreOutcome::AlreadyPresent`] — session is already in memory
/// (e.g. a concurrent request restored it first); caller proceeds
/// normally.
/// - [`RestoreOutcome::NotSupported`] (default) — this session manager
/// does not support external-store restore; caller returns 404.
fn restore_session(
&self,
_id: SessionId,
) -> impl Future<Output = Result<RestoreOutcome<Self::Transport>, Self::Error>> + Send {
futures::future::ready(Ok(RestoreOutcome::NotSupported))
}
}

@@ -139,2 +139,16 @@ use std::{

}
async fn restore_session(
&self,
id: SessionId,
) -> Result<RestoreOutcome<Self::Transport>, Self::Error> {
let mut sessions = self.sessions.write().await;
if sessions.contains_key(&id) {
// A concurrent request already restored this session.
return Ok(RestoreOutcome::AlreadyPresent);
}
let (handle, worker) = create_local_session(id.clone(), self.session_config.clone());
sessions.insert(id, handle);
Ok(RestoreOutcome::Restored(WorkerTransport::spawn(worker)))
}
}

@@ -192,3 +206,3 @@

use super::{ServerSseMessage, SessionManager};
use super::{RestoreOutcome, ServerSseMessage, SessionManager};

@@ -921,2 +935,4 @@ struct CachedTx {

KeepAliveTimeout(Duration),
#[error("init timeout after {}ms", _0.as_millis())]
InitTimeout(Duration),
#[error("Transport closed")]

@@ -951,9 +967,20 @@ TransportClosed,

}
// waiting for initialize request
let evt = self.event_rx.recv().await.ok_or_else(|| {
WorkerQuitReason::fatal(
LocalSessionWorkerError::TransportTerminated,
"get initialize request",
)
})?;
let init_timeout = self.session_config.init_timeout.unwrap_or(Duration::MAX);
let evt = tokio::select! {
evt = self.event_rx.recv() => evt.ok_or_else(|| {
WorkerQuitReason::fatal(
LocalSessionWorkerError::TransportTerminated,
"get initialize request",
)
})?,
_ = context.cancellation_token.cancelled() => {
return Err(WorkerQuitReason::Cancelled);
}
_ = tokio::time::sleep(init_timeout) => {
return Err(WorkerQuitReason::fatal(
LocalSessionWorkerError::InitTimeout(init_timeout),
"waiting for initialize request",
));
}
};
let SessionEvent::InitializeRequest { request, responder } = evt else {

@@ -1115,2 +1142,6 @@ return Err(WorkerQuitReason::fatal(

pub completed_cache_ttl: Duration,
/// Maximum duration to wait for the `initialize` request after session
/// creation. If not received within this window, the session is
/// terminated. Default is 60 seconds. Set to `None` to disable.
pub init_timeout: Option<Duration>,
}

@@ -1123,2 +1154,3 @@

pub const DEFAULT_COMPLETED_CACHE_TTL: Duration = Duration::from_secs(60);
pub const DEFAULT_INIT_TIMEOUT: Duration = Duration::from_secs(60);
}

@@ -1133,2 +1165,3 @@

completed_cache_ttl: Self::DEFAULT_COMPLETED_CACHE_TTL,
init_timeout: Some(Self::DEFAULT_INIT_TIMEOUT),
}

@@ -1135,0 +1168,0 @@ }

@@ -1,2 +0,2 @@

use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration};

@@ -11,6 +11,11 @@ use bytes::Bytes;

use super::session::SessionManager;
use super::session::{
RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker, SessionState, SessionStore,
};
use crate::{
RoleServer,
model::{ClientJsonRpcMessage, ClientRequest, GetExtensions, ProtocolVersion},
model::{
ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest,
InitializedNotification, ProtocolVersion,
},
serve_server,

@@ -63,4 +68,41 @@ service::serve_directly,

pub allowed_hosts: Vec<String>,
/// Allowed browser origins for inbound `Origin` validation.
///
/// Defaults to an empty list, which disables Origin validation. When
/// non-empty, requests carrying an `Origin` header must match per RFC 6454
/// `(scheme, host, port)`; missing-`Origin` requests still pass. Entries
/// must include a scheme; `"null"` matches the browser's `Origin: null`.
/// examples:
/// allowed_origins = ["https://app.example.com", "http://localhost:8080"]
pub allowed_origins: Vec<String>,
/// Optional external session store for cross-instance recovery.
///
/// When set, [`SessionState`] (the client's `initialize` parameters) is
/// persisted after a successful handshake and deleted when the session
/// closes. On any subsequent request that arrives at an instance with no
/// in-memory session, the store is consulted: if an entry is found the
/// session is transparently restored so the client does not need to
/// re-initialize.
///
/// # Example
/// ```rust,ignore
/// use std::sync::Arc;
/// use rmcp::transport::streamable_http_server::{
/// StreamableHttpServerConfig, session::SessionStore,
/// };
///
/// let config = StreamableHttpServerConfig {
/// session_store: Some(Arc::new(MyRedisStore::new())),
/// ..Default::default()
/// };
/// ```
pub session_store: Option<Arc<dyn SessionStore>>,
}
impl std::fmt::Debug for dyn SessionStore {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("<SessionStore>")
}
}
impl Default for StreamableHttpServerConfig {

@@ -75,2 +117,4 @@ fn default() -> Self {

allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
allowed_origins: vec![],
session_store: None,
}

@@ -93,2 +137,14 @@ }

}
pub fn with_allowed_origins(
mut self,
allowed_origins: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.allowed_origins = allowed_origins.into_iter().map(Into::into).collect();
self
}
/// Disable Origin validation, reverting to the default ignore-Origin behavior.
pub fn disable_allowed_origins(mut self) -> Self {
self.allowed_origins.clear();
self
}
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {

@@ -216,2 +272,55 @@ self.sse_keep_alive = duration;

#[derive(Debug, Clone, PartialEq, Eq)]
enum NormalizedOrigin {
Null,
Tuple {
scheme: String,
host: String,
port: Option<u16>,
},
}
fn parse_origin_value(value: &str) -> Option<NormalizedOrigin> {
let value = value.trim();
if value.is_empty() {
return None;
}
if value.eq_ignore_ascii_case("null") {
return Some(NormalizedOrigin::Null);
}
let uri = http::Uri::try_from(value).ok()?;
let scheme = uri.scheme_str()?.to_ascii_lowercase();
let authority = uri.authority()?;
Some(NormalizedOrigin::Tuple {
scheme,
host: normalize_host(authority.host()),
port: authority.port_u16(),
})
}
fn origin_is_allowed(origin: &NormalizedOrigin, allowed_origins: &[String]) -> bool {
if allowed_origins.is_empty() {
return true;
}
allowed_origins
.iter()
.filter_map(|raw| parse_origin_value(raw))
.any(|allowed| match (&allowed, origin) {
(NormalizedOrigin::Null, NormalizedOrigin::Null) => true,
(
NormalizedOrigin::Tuple {
scheme: a_scheme,
host: a_host,
port: a_port,
},
NormalizedOrigin::Tuple {
scheme: o_scheme,
host: o_host,
port: o_port,
},
) => a_scheme == o_scheme && a_host == o_host && (a_port.is_none() || a_port == o_port),
_ => false,
})
}
fn bad_request_response(message: &str) -> BoxResponse {

@@ -227,12 +336,29 @@ let body = Full::from(message.to_string()).boxed();

fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
let Some(host) = headers.get(http::header::HOST) else {
return Err(bad_request_response("Bad Request: missing Host header"));
};
let host = host
.to_str()
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host)
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
fn parse_host_header(
uri: &http::Uri,
headers: &HeaderMap,
) -> Result<NormalizedAuthority, BoxResponse> {
if let Some(host) = headers.get(http::header::HOST) {
let host_str = host
.to_str()
.inspect_err(|_| {
tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header");
})
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host_str)
.inspect_err(|_| {
tracing::warn!(
host = host_str,
"rejected request with malformed Host header"
);
})
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
return Ok(normalize_authority(authority.host(), authority.port_u16()));
}
// HTTP/2 carries the host in `:authority`; middleware such as
// `axum::Router::nest` can drop the `Host` header hyper synthesizes from it.
let authority = uri.authority().ok_or_else(|| {
tracing::warn!("rejected request with missing Host header and no :authority");
bad_request_response("Bad Request: missing Host header")
})?;
Ok(normalize_authority(authority.host(), authority.port_u16()))

@@ -242,10 +368,50 @@ }

fn validate_dns_rebinding_headers(
uri: &http::Uri,
headers: &HeaderMap,
config: &StreamableHttpServerConfig,
) -> Result<(), BoxResponse> {
let host = parse_host_header(headers)?;
let host = parse_host_header(uri, headers)?;
if !host_is_allowed(&host, &config.allowed_hosts) {
tracing::warn!(
host = ?host,
"rejected request with disallowed Host header (possible DNS rebinding attempt)",
);
return Err(forbidden_response("Forbidden: Host header is not allowed"));
}
validate_origin_header(headers, &config.allowed_origins)?;
Ok(())
}
fn validate_origin_header(
headers: &HeaderMap,
allowed_origins: &[String],
) -> Result<(), BoxResponse> {
if allowed_origins.is_empty() {
return Ok(());
}
let Some(origin_header) = headers.get(http::header::ORIGIN) else {
return Ok(());
};
let origin_str = origin_header
.to_str()
.inspect_err(|_| {
tracing::warn!(origin = ?origin_header, "rejected request with non-UTF-8 Origin header");
})
.map_err(|_| bad_request_response("Bad Request: Invalid Origin header encoding"))?;
let origin = parse_origin_value(origin_str).ok_or_else(|| {
tracing::warn!(
origin = origin_str,
"rejected request with malformed Origin header",
);
bad_request_response("Bad Request: Invalid Origin header")
})?;
if !origin_is_allowed(&origin, allowed_origins) {
tracing::warn!(
origin = ?origin,
"rejected request with disallowed Origin header (possible cross-origin attack)",
);
return Err(forbidden_response(
"Forbidden: Origin header is not allowed",
));
}
Ok(())

@@ -341,2 +507,9 @@ }

service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>,
/// Tracks in-progress session restores so that concurrent requests for the
/// same unknown session ID wait for the first restore to complete rather
/// than racing to replay the initialize handshake. `None` when no external
/// session store is configured (avoids allocating the map).
pending_restores: Option<
Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
>,
}

@@ -350,2 +523,3 @@

service_factory: self.service_factory.clone(),
pending_restores: self.pending_restores.clone(),
}

@@ -381,2 +555,31 @@ }

/// Guard used inside [`StreamableHttpService::try_restore_from_store`].
///
/// Ensures the `pending_restores` map entry is always cleaned up — even when
/// the future is cancelled mid-await.
///
/// `result` defaults to `false` (failure / cancellation). Only the success path
/// needs to set it to `true` before returning.
struct PendingRestoreGuard {
pending_restores:
Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>,
session_id: SessionId,
watch_tx: tokio::sync::watch::Sender<Option<bool>>,
/// The value that will be broadcast to waiting tasks on drop.
result: bool,
}
impl Drop for PendingRestoreGuard {
fn drop(&mut self) {
// `send` is synchronous — unblocks waiters immediately, no lock needed.
let _ = self.watch_tx.send(Some(self.result));
// Remove the map entry asynchronously (requires the async write lock).
let pending_restores = self.pending_restores.clone();
let session_id = self.session_id.clone();
tokio::spawn(async move {
pending_restores.write().await.remove(&session_id);
});
}
}
impl<S, M> StreamableHttpService<S, M>

@@ -392,2 +595,8 @@ where

) -> Self {
let pending_restores = config.session_store.is_some().then(|| {
Arc::new(tokio::sync::RwLock::new(HashMap::<
SessionId,
tokio::sync::watch::Sender<Option<bool>>,
>::new()))
});
Self {

@@ -397,2 +606,3 @@ config,

service_factory: Arc::new(service_factory),
pending_restores,
}

@@ -403,2 +613,213 @@ }

}
/// Spawn a task that runs `serve_server` for the given session, waits for
/// it to finish, and then calls `close_session`.
///
/// `init_done_tx`: when `Some`, the sender is fired after `serve_server`
/// returns successfully, signalling to the caller that the MCP handshake
/// is complete. Used by `try_restore_from_store` to synchronise with the
/// restore `initialize` replay; `handle_post` passes `None`.
fn spawn_session_worker(
session_manager: Arc<M>,
session_id: SessionId,
service: S,
transport: M::Transport,
init_done_tx: Option<tokio::sync::oneshot::Sender<()>>,
) where
S: crate::Service<RoleServer> + Send + 'static,
M: SessionManager,
{
tokio::spawn(async move {
let svc =
serve_server::<S, M::Transport, _, TransportAdapterIdentity>(service, transport)
.await;
match svc {
Ok(svc) => {
if let Some(tx) = init_done_tx {
let _ = tx.send(());
}
let _ = svc.waiting().await;
}
Err(e) => {
tracing::error!("Failed to serve session: {e}");
// Dropping init_done_tx (if Some) signals failure to the caller.
}
}
let _ = session_manager
.close_session(&session_id)
.await
.inspect_err(|e| {
tracing::error!("Failed to close session {session_id}: {e}");
});
});
}
/// Attempt to restore a session from the external store.
///
/// Returns `true` when the session is available and ready to serve the
/// current request (either just restored or already in memory). Returns
/// `false` when no store is configured or the session ID is unknown.
///
/// Concurrent requests for the same unknown session ID are serialized: the
/// first caller performs the full restore and handshake replay while others
/// subscribe to a `watch` channel and wait, avoiding duplicate handshakes.
async fn try_restore_from_store(
&self,
session_id: &SessionId,
parts: &http::request::Parts,
) -> Result<bool, std::io::Error>
where
S: crate::Service<RoleServer> + Send + 'static,
M: SessionManager,
{
// Both fields are Some iff a session store is configured.
let (Some(pending_restores), Some(store)) =
(&self.pending_restores, &self.config.session_store)
else {
return Ok(false);
};
// Serialize concurrent restores for the same session ID.
// Write-lock once: if another task is already restoring, subscribe and wait;
// otherwise, register ourselves as the restoring task.
// Channel value: None = in progress, Some(true) = restored, Some(false) = not found/failed.
let (watch_tx, _watch_rx) = tokio::sync::watch::channel(None::<bool>);
{
let mut pending = pending_restores.write().await;
if let Some(tx) = pending.get(session_id) {
let mut rx = tx.subscribe();
drop(pending);
// Wait for the restore to finish, then propagate the outcome.
let result = rx
.wait_for(|r| r.is_some())
.await
.map(|r| r.unwrap_or(false))
.unwrap_or(false);
return Ok(result);
}
pending.insert(session_id.clone(), watch_tx.clone());
}
// Guard: signals waiters and cleans up the map entry on drop
let mut guard = PendingRestoreGuard {
pending_restores: pending_restores.clone(),
session_id: session_id.clone(),
watch_tx: watch_tx.clone(),
result: false,
};
// --- Step 3: load from external store ---
let state = match store.load(session_id.as_ref()).await {
Ok(Some(s)) => s,
Ok(None) => {
return Ok(false);
}
Err(e) => {
tracing::error!(
session_id = session_id.as_ref(),
error = %e,
"session store load failed during restore"
);
return Err(std::io::Error::other(e));
}
};
// --- Step 4: ask the session manager to allocate an in-memory worker ---
let transport = match self
.session_manager
.restore_session(session_id.clone())
.await
.map_err(|e| std::io::Error::other(e.to_string()))
{
Ok(RestoreOutcome::Restored(t)) => t,
Ok(RestoreOutcome::AlreadyPresent) => {
// Invariant violation: pending_restores ensures only one task can call
// restore_session per session ID, so AlreadyPresent is impossible here.
return Err(std::io::Error::other(
"restore_session returned AlreadyPresent unexpectedly; session manager might have modified the session store outside of the restore_session API",
));
}
Ok(RestoreOutcome::NotSupported) => {
return Ok(false);
}
Err(e) => {
return Err(e);
}
};
// --- Step 5: replay the MCP initialize handshake ---
let service = match self.get_service() {
Ok(s) => s,
Err(e) => {
return Err(e);
}
};
// `serve_server` requires both the `initialize` request and the
// `notifications/initialized` notification before transitioning to
// the running state — we must send both before returning.
let mut restore_init = ClientJsonRpcMessage::request(
ClientRequest::InitializeRequest(InitializeRequest {
params: state.initialize_params,
..Default::default()
}),
crate::model::NumberOrString::Number(0),
);
restore_init.insert_extension(parts.clone());
restore_init.insert_extension(SessionRestoreMarker {
id: session_id.clone(),
});
let mut restore_initialized = ClientJsonRpcMessage::notification(
ClientNotification::InitializedNotification(InitializedNotification {
..Default::default()
}),
);
restore_initialized.insert_extension(parts.clone());
restore_initialized.insert_extension(SessionRestoreMarker {
id: session_id.clone(),
});
// Signal from the spawned task once serve_server finishes initialising.
let (init_done_tx, init_done_rx) = tokio::sync::oneshot::channel::<()>();
Self::spawn_session_worker(
self.session_manager.clone(),
session_id.clone(),
service,
transport,
Some(init_done_tx),
);
if let Err(e) = self
.session_manager
.initialize_session(session_id, restore_init)
.await
.map_err(|e| std::io::Error::other(e.to_string()))
{
return Err(e);
}
if let Err(e) = self
.session_manager
.accept_message(session_id, restore_initialized)
.await
.map_err(|e| std::io::Error::other(e.to_string()))
{
return Err(e);
}
if init_done_rx.await.is_err() {
return Err(std::io::Error::other(
"serve_server initialization failed during restore",
));
}
// Restore complete — wake any waiting concurrent requests.
guard.result = true;
tracing::debug!(
session_id = session_id.as_ref(),
"session restored from external store"
);
Ok(true)
}
pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>>

@@ -409,3 +830,5 @@ where

{
if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
if let Err(response) =
validate_dns_rebinding_headers(request.uri(), request.headers(), &self.config)
{
return response;

@@ -479,14 +902,22 @@ }

.map_err(internal_error_response("check session"))?;
let (parts, _) = request.into_parts();
if !has_session {
// MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
return Ok(Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
// Attempt transparent cross-instance restore from external store.
let restored = self
.try_restore_from_store(&session_id, &parts)
.await
.map_err(internal_error_response("restore session"))?;
if !restored {
// MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
return Ok(Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
}
}
// Validate MCP-Protocol-Version header (per 2025-06-18 spec)
validate_protocol_version_header(request.headers())?;
validate_protocol_version_header(&parts.headers)?;
// check if last event id is provided
let last_event_id = request
.headers()
let last_event_id = parts
.headers
.get(HEADER_LAST_EVENT_ID)

@@ -603,7 +1034,14 @@ .and_then(|v| v.to_str().ok())

if !has_session {
// MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
return Ok(Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
// Attempt transparent cross-instance restore from external store.
let restored = self
.try_restore_from_store(&session_id, &part)
.await
.map_err(internal_error_response("restore session"))?;
if !restored {
// MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions
return Ok(Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body(Full::new(Bytes::from("Not Found: Session not found")).boxed())
.expect("valid response"));
}
}

@@ -660,2 +1098,17 @@

.map_err(internal_error_response("create session"))?;
// Capture init params for external store persistence before
// extensions are injected (which would require Clone).
let stored_init_params = if self.config.session_store.is_some() {
if let ClientJsonRpcMessage::Request(req) = &message {
if let ClientRequest::InitializeRequest(init_req) = &req.request {
Some(init_req.params.clone())
} else {
None
}
} else {
None
}
} else {
None
};
if let ClientJsonRpcMessage::Request(req) = &mut message {

@@ -674,27 +1127,9 @@ if !matches!(req.request, ClientRequest::InitializeRequest(_)) {

// spawn a task to serve the session
tokio::spawn({
let session_manager = self.session_manager.clone();
let session_id = session_id.clone();
async move {
let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>(
service, transport,
)
.await;
match service {
Ok(service) => {
// on service created
let _ = service.waiting().await;
}
Err(e) => {
tracing::error!("Failed to create service: {e}");
}
}
let _ = session_manager
.close_session(&session_id)
.await
.inspect_err(|e| {
tracing::error!("Failed to close session {session_id}: {e}");
});
}
});
Self::spawn_session_worker(
self.session_manager.clone(),
session_id.clone(),
service,
transport,
None,
);
// get initialize response

@@ -706,2 +1141,19 @@ let response = self

.map_err(internal_error_response("create stream"))?;
// Persist session state to external store after a successful handshake.
if let (Some(store), Some(params)) =
(&self.config.session_store, stored_init_params)
{
let state = SessionState {
initialize_params: params,
};
let _ = store
.store(session_id.as_ref(), &state)
.await
.inspect_err(|e| {
tracing::warn!(
"Failed to persist session {} to store: {e}",
session_id
);
});
}
let stream =

@@ -829,4 +1281,11 @@ futures::stream::once(async move { ServerSseMessage::from_message(response) });

.map_err(internal_error_response("close session"))?;
// Remove from external store: a DELETE means the client intentionally
// ends the session, so the store entry is no longer needed.
if let Some(store) = &self.config.session_store {
let _ = store.delete(session_id.as_ref()).await.inspect_err(|e| {
tracing::warn!("Failed to delete session {} from store: {e}", session_id);
});
}
Ok(accepted_response())
}
}

@@ -1033,1 +1033,198 @@ #![cfg(not(feature = "local"))]

}
/// Integration test: Verify the validator falls back to the URI authority when
/// the Host header is absent (HTTP/2 :authority pseudo-header scenario).
#[tokio::test]
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
async fn test_server_falls_back_to_uri_authority_when_host_header_missing() {
use std::sync::Arc;
use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;
#[derive(Clone)]
struct TestHandler;
impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}
let service = StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
);
let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
});
// Allowed authority via URI only — no Host header.
let allowed_request = Request::builder()
.method(Method::POST)
.uri("http://localhost:8080/")
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
assert!(allowed_request.headers().get("Host").is_none());
let response = service.handle(allowed_request).await;
assert_eq!(response.status(), http::StatusCode::OK);
// Disallowed authority via URI only — no Host header.
let bad_request = Request::builder()
.method(Method::POST)
.uri("http://attacker.example/")
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
assert!(bad_request.headers().get("Host").is_none());
let response = service.handle(bad_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
// Neither Host header nor URI authority — still a 400.
let missing_request = Request::builder()
.method(Method::POST)
.uri("/")
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
assert!(missing_request.headers().get("Host").is_none());
assert!(missing_request.uri().authority().is_none());
let response = service.handle(missing_request).await;
assert_eq!(response.status(), http::StatusCode::BAD_REQUEST);
}
#[cfg(all(feature = "transport-streamable-http-server", feature = "server"))]
mod origin_validation {
use std::sync::Arc;
use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;
#[derive(Clone)]
struct TestHandler;
impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}
fn service_with_allowed_origins(
origins: &[&str],
) -> StreamableHttpService<TestHandler, LocalSessionManager> {
StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default().with_allowed_origins(origins.iter().copied()),
)
}
fn init_request(origin: Option<&str>) -> Request<Full<Bytes>> {
let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"}
}
});
let mut builder = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080");
if let Some(origin) = origin {
builder = builder.header("Origin", origin);
}
builder
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap()
}
#[tokio::test]
async fn allowlisted_origin_is_allowed() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service
.handle(init_request(Some("http://localhost:8080")))
.await;
assert_eq!(response.status(), http::StatusCode::OK);
}
#[tokio::test]
async fn non_allowlisted_origin_is_forbidden() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service
.handle(init_request(Some("http://attacker.example")))
.await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn missing_origin_passes_through() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service.handle(init_request(None)).await;
assert_eq!(response.status(), http::StatusCode::OK);
}
#[tokio::test]
async fn scheme_mismatch_is_forbidden() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service
.handle(init_request(Some("https://localhost:8080")))
.await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}
#[tokio::test]
async fn null_origin_is_allowed_when_allowlisted() {
let service = service_with_allowed_origins(&["null"]);
let response = service.handle(init_request(Some("null"))).await;
assert_eq!(response.status(), http::StatusCode::OK);
}
#[tokio::test]
async fn null_origin_is_forbidden_when_not_allowlisted() {
let service = service_with_allowed_origins(&["http://localhost:8080"]);
let response = service.handle(init_request(Some("null"))).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}
}
#![cfg(not(feature = "local"))]
use std::collections::HashMap;
use std::{
collections::HashMap,
sync::atomic::{AtomicUsize, Ordering},
};

@@ -87,1 +90,284 @@ use futures::future::BoxFuture;

}
fn build_router() -> ToolRouter<TestHandler<()>> {
ToolRouter::<TestHandler<()>>::new()
.with_route((async_function_tool_attr(), async_function))
.with_route((async_function2_tool_attr(), async_function2))
+ TestHandler::<()>::test_router_1()
+ TestHandler::<()>::test_router_2()
}
#[test]
fn test_disable_route() {
let mut router = build_router();
assert_eq!(router.list_all().len(), 4);
assert!(router.has_route("async_function"));
assert!(router.get("async_function").is_some());
assert!(router.disable_route("async_function"));
assert_eq!(router.list_all().len(), 3);
assert!(!router.has_route("async_function"));
assert!(router.get("async_function").is_none());
assert!(router.is_disabled("async_function"));
// other tools unaffected
assert!(router.has_route("async_function2"));
assert!(router.get("async_function2").is_some());
assert!(!router.is_disabled("async_function2"));
}
#[test]
fn test_enable_route() {
let mut router = build_router();
assert!(router.disable_route("async_function"));
assert!(!router.has_route("async_function"));
assert!(router.enable_route("async_function"));
assert!(router.has_route("async_function"));
assert!(router.get("async_function").is_some());
assert!(!router.is_disabled("async_function"));
assert_eq!(router.list_all().len(), 4);
}
#[test]
fn test_with_disabled_builder() {
let router = build_router()
.with_disabled("async_function")
.with_disabled("sync_method");
assert_eq!(router.list_all().len(), 2);
assert!(!router.has_route("async_function"));
assert!(!router.has_route("sync_method"));
assert!(router.has_route("async_function2"));
assert!(router.has_route("async_method"));
}
#[test]
fn test_disabled_tools_survive_merge() {
let mut router_a = ToolRouter::<TestHandler<()>>::new()
.with_route((async_function_tool_attr(), async_function));
assert!(router_a.disable_route("async_function"));
let router_b = ToolRouter::<TestHandler<()>>::new()
.with_route((async_function2_tool_attr(), async_function2));
router_a.merge(router_b);
assert_eq!(router_a.list_all().len(), 1);
assert!(router_a.is_disabled("async_function"));
assert!(router_a.has_route("async_function2"));
}
#[test]
fn test_disable_nonexistent_tool() {
let mut router = build_router();
// should not panic; returns true because the name is newly added to disabled set
assert!(router.disable_route("does_not_exist"));
assert_eq!(router.list_all().len(), 4);
// is_disabled returns false for tools not in the map
assert!(!router.is_disabled("does_not_exist"));
}
#[test]
fn test_remove_route_preserves_disabled_state() {
let mut router = build_router();
assert!(router.disable_route("async_function"));
assert!(router.is_disabled("async_function"));
router.remove_route("async_function");
assert!(!router.has_route("async_function"));
// Disabled marker is preserved — is_disabled returns false (no route in map)
// but re-adding will inherit the disabled state (tested separately)
assert!(!router.is_disabled("async_function"));
}
#[test]
fn test_remove_route_then_readd_stays_disabled() {
let mut router = build_router();
assert!(router.disable_route("async_function"));
router.remove_route("async_function");
assert!(!router.has_route("async_function"));
// Re-add the route — it should inherit the disabled state
let other = ToolRouter::<TestHandler<()>>::new()
.with_route((async_function_tool_attr(), async_function));
router.merge(other);
assert!(!router.has_route("async_function"));
assert!(router.is_disabled("async_function"));
assert!(router.get("async_function").is_none());
}
#[test]
fn test_into_iter_skips_disabled() {
let router = build_router().with_disabled("async_function");
let names: Vec<_> = router
.into_iter()
.map(|r| r.attr.name.to_string())
.collect();
assert_eq!(names.len(), 3);
assert!(!names.contains(&"async_function".to_string()));
}
#[test]
fn test_pre_disable_before_add_route() {
// Disabling a name before adding a route with that name should
// result in the route being disabled once added.
let router = ToolRouter::<TestHandler<()>>::new()
.with_disabled("async_function")
.with_route((async_function_tool_attr(), async_function));
assert_eq!(router.list_all().len(), 0);
assert!(router.is_disabled("async_function"));
assert!(!router.has_route("async_function"));
}
#[test]
fn test_disabled_tool_invisible_across_all_queries() {
let router = build_router().with_disabled("async_function");
// Not listed
let names: Vec<_> = router.list_all().iter().map(|t| t.name.clone()).collect();
assert!(!names.contains(&"async_function".into()));
// Not retrievable
assert!(router.get("async_function").is_none());
// Not routable
assert!(!router.has_route("async_function"));
// But known as disabled
assert!(router.is_disabled("async_function"));
}
#[test]
fn test_disable_route_then_add_route_blocks_tool() {
// Full pre-disable lifecycle via runtime mutation (not builder)
let mut router = ToolRouter::<TestHandler<()>>::new();
router.disable_route("async_function");
// Add route after disabling — tool should be blocked
let other = ToolRouter::<TestHandler<()>>::new()
.with_route((async_function_tool_attr(), async_function));
router.merge(other);
assert!(router.is_disabled("async_function"));
assert!(!router.has_route("async_function"));
assert!(router.get("async_function").is_none());
assert_eq!(router.list_all().len(), 0);
}
#[test]
fn test_disable_enable_return_false_cases() {
let mut router = build_router();
// Repeated disable returns false
assert!(router.disable_route("async_function"));
assert!(!router.disable_route("async_function"));
// Enable returns true, then false on repeat
assert!(router.enable_route("async_function"));
assert!(!router.enable_route("async_function"));
// Enable on name never disabled returns false
assert!(!router.enable_route("async_function2"));
// Enable on unknown name returns false
assert!(!router.enable_route("unknown"));
}
// ── Notifier tests ──────────────────────────────────────────────────────
fn counter_notifier() -> (
impl Fn() + Send + Sync + 'static,
std::sync::Arc<AtomicUsize>,
) {
let counter = std::sync::Arc::new(AtomicUsize::new(0));
let c = counter.clone();
let notifier = move || {
c.fetch_add(1, Ordering::SeqCst);
};
(notifier, counter)
}
#[test]
fn test_notifier_fires_on_disable_and_enable() {
let (notifier, counter) = counter_notifier();
let mut router = build_router();
router.set_notifier(notifier);
assert!(router.disable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert!(!router.disable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert!(router.enable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 2);
assert!(!router.enable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
#[test]
fn test_notifier_skips_nonexistent_tools() {
let (notifier, counter) = counter_notifier();
let mut router = build_router();
router.set_notifier(notifier);
assert!(router.disable_route("does_not_exist"));
assert_eq!(counter.load(Ordering::SeqCst), 0);
assert!(router.enable_route("does_not_exist"));
assert_eq!(counter.load(Ordering::SeqCst), 0);
assert!(router.disable_route("future_tool"));
assert_eq!(counter.load(Ordering::SeqCst), 0);
assert!(router.enable_route("future_tool"));
assert_eq!(counter.load(Ordering::SeqCst), 0);
}
#[test]
fn test_no_notifier_no_panic() {
let mut router = build_router();
assert!(router.disable_route("async_function"));
assert!(router.enable_route("async_function"));
assert!(router.disable_route("async_function"));
assert!(!router.disable_route("async_function"));
}
#[test]
fn test_clone_shares_notifier() {
let (notifier, counter) = counter_notifier();
let mut router = build_router();
router.set_notifier(notifier);
let mut cloned = router.clone();
assert!(cloned.disable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert!(router.disable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 2);
cloned.clear_notifier();
assert!(cloned.enable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 2);
assert!(router.enable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
#[test]
fn test_pre_init_disable_silent_but_correct() {
let mut router = build_router();
assert!(router.disable_route("async_function"));
assert_eq!(router.list_all().len(), 3);
assert!(!router.has_route("async_function"));
let (notifier, counter) = counter_notifier();
router.set_notifier(notifier);
assert_eq!(counter.load(Ordering::SeqCst), 0);
assert!(router.enable_route("async_function"));
assert_eq!(counter.load(Ordering::SeqCst), 1);
}

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet