| use crate::model::InitializeRequestParams; | ||
| /// State persisted to an external store for cross-instance session recovery. | ||
| /// | ||
| /// When a client reconnects to a different server instance, the new instance | ||
| /// loads this state to transparently replay the `initialize` handshake without | ||
| /// the client needing to re-initialize. | ||
| #[non_exhaustive] | ||
| #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] | ||
| pub struct SessionState { | ||
| /// Parameters from the client's original `initialize` request. | ||
| pub initialize_params: InitializeRequestParams, | ||
| } | ||
| impl SessionState { | ||
| pub fn new(initialize_params: InitializeRequestParams) -> Self { | ||
| Self { initialize_params } | ||
| } | ||
| } | ||
| /// Type alias for boxed session store errors. | ||
| pub type SessionStoreError = Box<dyn std::error::Error + Send + Sync + 'static>; | ||
| /// Pluggable external session store for cross-instance recovery. | ||
| /// | ||
| /// Implement this trait to back sessions with Redis, a database, or any | ||
| /// key-value store. The simplest usage is to set | ||
| /// `StreamableHttpServerConfig::session_store` to an `Arc<impl SessionStore>`. | ||
| /// | ||
| /// # Example (in-memory, for testing) | ||
| /// | ||
| /// ```rust,ignore | ||
| /// use std::{collections::HashMap, sync::Arc}; | ||
| /// use tokio::sync::RwLock; | ||
| /// use rmcp::transport::streamable_http_server::session::store::{ | ||
| /// SessionState, SessionStore, SessionStoreError, | ||
| /// }; | ||
| /// | ||
| /// #[derive(Default)] | ||
| /// struct InMemoryStore(Arc<RwLock<HashMap<String, SessionState>>>); | ||
| /// | ||
| /// #[async_trait::async_trait] | ||
| /// impl SessionStore for InMemoryStore { | ||
| /// async fn load(&self, id: &str) -> Result<Option<SessionState>, SessionStoreError> { | ||
| /// Ok(self.0.read().await.get(id).cloned()) | ||
| /// } | ||
| /// async fn store(&self, id: &str, state: &SessionState) -> Result<(), SessionStoreError> { | ||
| /// self.0.write().await.insert(id.to_owned(), state.clone()); | ||
| /// Ok(()) | ||
| /// } | ||
| /// async fn delete(&self, id: &str) -> Result<(), SessionStoreError> { | ||
| /// self.0.write().await.remove(id); | ||
| /// Ok(()) | ||
| /// } | ||
| /// } | ||
| /// ``` | ||
| #[async_trait::async_trait] | ||
| pub trait SessionStore: Send + Sync + 'static { | ||
| /// Load session state for the given `session_id`. | ||
| /// | ||
| /// Returns `Ok(None)` when no entry exists (i.e. session is unknown to the store). | ||
| async fn load(&self, session_id: &str) -> Result<Option<SessionState>, SessionStoreError>; | ||
| /// Persist session state for the given `session_id`. | ||
| async fn store(&self, session_id: &str, state: &SessionState) -> Result<(), SessionStoreError>; | ||
| /// Remove session state for the given `session_id`. | ||
| async fn delete(&self, session_id: &str) -> Result<(), SessionStoreError>; | ||
| } |
| #![cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))] | ||
| use std::time::Duration; | ||
| use rmcp::{ | ||
| model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, | ||
| transport::streamable_http_server::session::{SessionManager, local::LocalSessionManager}, | ||
| }; | ||
| #[tokio::test] | ||
| async fn test_init_timeout_terminates_pre_init_session() -> anyhow::Result<()> { | ||
| let mut manager = LocalSessionManager::default(); | ||
| manager.session_config.init_timeout = Some(Duration::from_millis(200)); | ||
| // Bind the transport so its drop-guard doesn't cancel the worker — we | ||
| // want termination via init_timeout, not via cancellation. | ||
| let (session_id, _transport) = manager.create_session().await?; | ||
| tokio::time::sleep(Duration::from_millis(500)).await; | ||
| let message = ClientJsonRpcMessage::request( | ||
| ClientRequest::PingRequest(PingRequest::default()), | ||
| RequestId::Number(1), | ||
| ); | ||
| let result = manager.initialize_session(&session_id, message).await; | ||
| assert!( | ||
| result.is_err(), | ||
| "expected worker to be dead; got: {result:?}" | ||
| ); | ||
| Ok(()) | ||
| } | ||
| #[tokio::test] | ||
| async fn test_init_timeout_none_keeps_worker_alive() -> anyhow::Result<()> { | ||
| let mut manager = LocalSessionManager::default(); | ||
| manager.session_config.init_timeout = None; | ||
| let (session_id, _transport) = manager.create_session().await?; | ||
| tokio::time::sleep(Duration::from_millis(500)).await; | ||
| let message = ClientJsonRpcMessage::request( | ||
| ClientRequest::PingRequest(PingRequest::default()), | ||
| RequestId::Number(1), | ||
| ); | ||
| // Liveness probe: a live worker accepts the send then stalls waiting for | ||
| // a handler response (none is wired up), tripping the outer timeout. A | ||
| // dead worker would fail the send and return immediately. | ||
| let probe = tokio::time::timeout( | ||
| Duration::from_millis(200), | ||
| manager.initialize_session(&session_id, message), | ||
| ) | ||
| .await; | ||
| assert!( | ||
| probe.is_err(), | ||
| "expected worker to be alive; got: {probe:?}" | ||
| ); | ||
| Ok(()) | ||
| } |
| #![cfg(all( | ||
| feature = "client", | ||
| feature = "server", | ||
| feature = "transport-streamable-http-client-reqwest", | ||
| feature = "transport-streamable-http-server", | ||
| not(feature = "local") | ||
| ))] | ||
| use std::{collections::HashMap, sync::Arc}; | ||
| use rmcp::{ | ||
| ServiceExt, | ||
| transport::{ | ||
| StreamableHttpClientTransport, | ||
| streamable_http_client::StreamableHttpClientTransportConfig, | ||
| streamable_http_server::{ | ||
| StreamableHttpServerConfig, StreamableHttpService, | ||
| session::{SessionState, SessionStore, SessionStoreError, local::LocalSessionManager}, | ||
| }, | ||
| }, | ||
| }; | ||
| use tokio::sync::RwLock; | ||
| use tokio_util::sync::CancellationToken; | ||
| mod common; | ||
| use common::calculator::Calculator; | ||
| // --------------------------------------------------------------------------- | ||
| // Shared in-memory store used across tests | ||
| // --------------------------------------------------------------------------- | ||
| #[derive(Default, Clone)] | ||
| struct InMemorySessionStore(Arc<RwLock<HashMap<String, SessionState>>>); | ||
| impl InMemorySessionStore { | ||
| fn new() -> Self { | ||
| Self::default() | ||
| } | ||
| async fn len(&self) -> usize { | ||
| self.0.read().await.len() | ||
| } | ||
| } | ||
| #[async_trait::async_trait] | ||
| impl SessionStore for InMemorySessionStore { | ||
| async fn load(&self, session_id: &str) -> Result<Option<SessionState>, SessionStoreError> { | ||
| Ok(self.0.read().await.get(session_id).cloned()) | ||
| } | ||
| async fn store(&self, session_id: &str, state: &SessionState) -> Result<(), SessionStoreError> { | ||
| self.0 | ||
| .write() | ||
| .await | ||
| .insert(session_id.to_owned(), state.clone()); | ||
| Ok(()) | ||
| } | ||
| async fn delete(&self, session_id: &str) -> Result<(), SessionStoreError> { | ||
| self.0.write().await.remove(session_id); | ||
| Ok(()) | ||
| } | ||
| } | ||
| // --------------------------------------------------------------------------- | ||
| // Helper: spin up a StreamableHttpService backed by the given store and | ||
| // return the bound address together with the cancellation token. | ||
| // --------------------------------------------------------------------------- | ||
| fn make_service( | ||
| session_store: Arc<dyn SessionStore>, | ||
| ct: &CancellationToken, | ||
| ) -> StreamableHttpService<Calculator, LocalSessionManager> { | ||
| StreamableHttpService::new(|| Ok(Calculator::new()), Default::default(), { | ||
| let mut cfg = StreamableHttpServerConfig::default(); | ||
| cfg.stateful_mode = true; | ||
| cfg.sse_keep_alive = None; | ||
| cfg.cancellation_token = ct.child_token(); | ||
| cfg.session_store = Some(session_store); | ||
| cfg | ||
| }) | ||
| } | ||
| // --------------------------------------------------------------------------- | ||
| // Test 1 — state is persisted to the store after a successful handshake | ||
| // --------------------------------------------------------------------------- | ||
| #[tokio::test] | ||
| async fn test_session_state_persisted_to_store() -> anyhow::Result<()> { | ||
| let store = Arc::new(InMemorySessionStore::new()); | ||
| let ct = CancellationToken::new(); | ||
| let service = make_service(store.clone(), &ct); | ||
| let router = axum::Router::new().nest_service("/mcp", service); | ||
| let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; | ||
| let addr = listener.local_addr()?; | ||
| let handle = tokio::spawn({ | ||
| let ct = ct.clone(); | ||
| async move { | ||
| let _ = axum::serve(listener, router) | ||
| .with_graceful_shutdown(async move { ct.cancelled_owned().await }) | ||
| .await; | ||
| } | ||
| }); | ||
| // Connect a full client — this performs the initialize + initialized handshake. | ||
| let transport = StreamableHttpClientTransport::from_config( | ||
| StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), | ||
| ); | ||
| let client = ().serve(transport).await?; | ||
| // Make a real request so the session is fully active. | ||
| let _resources = client.list_all_resources().await?; | ||
| // The store should now contain exactly one session entry. | ||
| assert_eq!( | ||
| store.len().await, | ||
| 1, | ||
| "session state should be persisted to the store after initialization" | ||
| ); | ||
| // Verify the stored state contains the expected client info. | ||
| let entries = store.0.read().await; | ||
| let state = entries.values().next().expect("store entry should exist"); | ||
| assert_eq!( | ||
| state.initialize_params.client_info.name, "rmcp", | ||
| "stored client_info.name should match the rmcp client" | ||
| ); | ||
| let _ = client.cancel().await; | ||
| ct.cancel(); | ||
| handle.await?; | ||
| Ok(()) | ||
| } | ||
| // --------------------------------------------------------------------------- | ||
| // Test 2 — store entry is removed when the client sends HTTP DELETE | ||
| // --------------------------------------------------------------------------- | ||
| #[tokio::test] | ||
| async fn test_session_state_deleted_from_store_on_delete() -> anyhow::Result<()> { | ||
| let store = Arc::new(InMemorySessionStore::new()); | ||
| let session_manager = Arc::new(LocalSessionManager::default()); | ||
| let ct = CancellationToken::new(); | ||
| let service = StreamableHttpService::new(|| Ok(Calculator::new()), session_manager.clone(), { | ||
| let mut cfg = StreamableHttpServerConfig::default(); | ||
| cfg.stateful_mode = true; | ||
| cfg.sse_keep_alive = None; | ||
| cfg.cancellation_token = ct.child_token(); | ||
| cfg.session_store = Some(store.clone()); | ||
| cfg | ||
| }); | ||
| let router = axum::Router::new().nest_service("/mcp", service); | ||
| let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; | ||
| let addr = listener.local_addr()?; | ||
| let handle = tokio::spawn({ | ||
| let ct = ct.clone(); | ||
| async move { | ||
| let _ = axum::serve(listener, router) | ||
| .with_graceful_shutdown(async move { ct.cancelled_owned().await }) | ||
| .await; | ||
| } | ||
| }); | ||
| let transport = StreamableHttpClientTransport::from_config( | ||
| StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), | ||
| ); | ||
| let client = ().serve(transport).await?; | ||
| let _resources = client.list_all_resources().await?; | ||
| assert_eq!(store.len().await, 1, "store should have one entry"); | ||
| // Get the session ID from the server's in-memory map. | ||
| let session_id = { | ||
| let sessions = session_manager.sessions.read().await; | ||
| sessions | ||
| .keys() | ||
| .next() | ||
| .cloned() | ||
| .expect("session should exist") | ||
| }; | ||
| // Send an explicit HTTP DELETE — this is the signal to remove from store. | ||
| let http_client = reqwest::Client::new(); | ||
| let response = http_client | ||
| .delete(format!("http://{addr}/mcp")) | ||
| .header("mcp-session-id", session_id.as_ref()) | ||
| .send() | ||
| .await?; | ||
| assert_eq!(response.status(), 202); | ||
| assert_eq!( | ||
| store.len().await, | ||
| 0, | ||
| "store entry should be removed after explicit DELETE" | ||
| ); | ||
| let _ = client.cancel().await; | ||
| ct.cancel(); | ||
| handle.await?; | ||
| Ok(()) | ||
| } | ||
| // --------------------------------------------------------------------------- | ||
| // Helper: spin up a server on an ephemeral port and return its address and | ||
| // the join handle. The server shuts down when `ct` is cancelled. | ||
| // --------------------------------------------------------------------------- | ||
| fn spawn_server( | ||
| session_store: Option<Arc<dyn SessionStore>>, | ||
| session_manager: Arc<LocalSessionManager>, | ||
| ct: &CancellationToken, | ||
| ) -> (std::net::SocketAddr, tokio::task::JoinHandle<()>) { | ||
| let svc = StreamableHttpService::new(|| Ok(Calculator::new()), session_manager, { | ||
| let mut cfg = StreamableHttpServerConfig::default(); | ||
| cfg.stateful_mode = true; | ||
| cfg.sse_keep_alive = None; | ||
| cfg.cancellation_token = ct.child_token(); | ||
| cfg.session_store = session_store; | ||
| cfg | ||
| }); | ||
| // Use std::net::TcpListener so the port is bound synchronously before | ||
| // we return — avoids a race between returning the addr and the server | ||
| // actually starting to accept connections. | ||
| let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); | ||
| std_listener.set_nonblocking(true).unwrap(); | ||
| let addr = std_listener.local_addr().unwrap(); | ||
| let listener = tokio::net::TcpListener::from_std(std_listener).unwrap(); | ||
| let router = axum::Router::new().nest_service("/mcp", svc); | ||
| let handle = tokio::spawn({ | ||
| let ct = ct.clone(); | ||
| async move { | ||
| let _ = axum::serve(listener, router) | ||
| .with_graceful_shutdown(async move { ct.cancelled_owned().await }) | ||
| .await; | ||
| } | ||
| }); | ||
| (addr, handle) | ||
| } | ||
| // --------------------------------------------------------------------------- | ||
| // Test 3 — cross-instance session restore | ||
| // | ||
| // Both halves follow the same structure: | ||
| // | ||
| // Instance A initializes the session (session state may be saved to store) | ||
| // Instance A is fully shut down | ||
| // Instance B (fresh, no in-memory state) receives a request for the old ID | ||
| // | ||
| // Without a store → 404. With a shared store → transparent restore. | ||
| // --------------------------------------------------------------------------- | ||
| #[tokio::test] | ||
| async fn test_cross_instance_session_restore() -> anyhow::Result<()> { | ||
| let http = reqwest::Client::new(); | ||
| // ----------------------------------------------------------------------- | ||
| // Negative check: no session store → instance B returns 404. | ||
| // ----------------------------------------------------------------------- | ||
| { | ||
| // --- Instance A (no store): initialize --- | ||
| let ct_a = CancellationToken::new(); | ||
| let (addr_a, srv_a) = spawn_server(None, Arc::new(LocalSessionManager::default()), &ct_a); | ||
| let init_resp = http | ||
| .post(format!("http://{addr_a}/mcp")) | ||
| .header("accept", "application/json, text/event-stream") | ||
| .header("content-type", "application/json") | ||
| .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}"#) | ||
| .send() | ||
| .await?; | ||
| assert_eq!( | ||
| init_resp.status(), | ||
| 200, | ||
| "instance A: initialize should succeed" | ||
| ); | ||
| let session_id = init_resp | ||
| .headers() | ||
| .get("mcp-session-id") | ||
| .expect("session ID header must be present") | ||
| .to_str()? | ||
| .to_owned(); | ||
| // Shut down instance A completely. | ||
| ct_a.cancel(); | ||
| srv_a.await?; | ||
| // --- Instance B (no store, fresh state): send request --- | ||
| let ct_b = CancellationToken::new(); | ||
| let (addr_b, srv_b) = spawn_server(None, Arc::new(LocalSessionManager::default()), &ct_b); | ||
| let resp = http | ||
| .post(format!("http://{addr_b}/mcp")) | ||
| .header("accept", "application/json, text/event-stream") | ||
| .header("content-type", "application/json") | ||
| .header("mcp-session-id", &session_id) | ||
| .body(r#"{"jsonrpc":"2.0","id":2,"method":"ping","params":{}}"#) | ||
| .send() | ||
| .await?; | ||
| assert_eq!( | ||
| resp.status(), | ||
| reqwest::StatusCode::NOT_FOUND, | ||
| "without a session store, instance B must return 404 for an unknown session ID" | ||
| ); | ||
| ct_b.cancel(); | ||
| srv_b.await?; | ||
| } | ||
| // ----------------------------------------------------------------------- | ||
| // Positive check: shared session store → instance B restores transparently. | ||
| // ----------------------------------------------------------------------- | ||
| { | ||
| let store: Arc<dyn SessionStore> = Arc::new(InMemorySessionStore::new()); | ||
| // --- Instance A (with store): initialize --- | ||
| let ct_a = CancellationToken::new(); | ||
| let sm_a = Arc::new(LocalSessionManager::default()); | ||
| let (addr_a, srv_a) = spawn_server(Some(store.clone()), sm_a.clone(), &ct_a); | ||
| let init_resp = http | ||
| .post(format!("http://{addr_a}/mcp")) | ||
| .header("accept", "application/json, text/event-stream") | ||
| .header("content-type", "application/json") | ||
| .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"test","version":"0"}}}"#) | ||
| .send() | ||
| .await?; | ||
| assert_eq!( | ||
| init_resp.status(), | ||
| 200, | ||
| "instance A: initialize should succeed" | ||
| ); | ||
| let original_session_id = init_resp | ||
| .headers() | ||
| .get("mcp-session-id") | ||
| .expect("session ID header must be present") | ||
| .to_str()? | ||
| .to_owned(); | ||
| // Confirm the session was persisted. | ||
| let store_ref = store | ||
| .load(&original_session_id) | ||
| .await | ||
| .expect("store load should not error"); | ||
| assert!( | ||
| store_ref.is_some(), | ||
| "store should hold the session after initialization" | ||
| ); | ||
| // Shut down instance A completely — session lives only in the store now. | ||
| ct_a.cancel(); | ||
| srv_a.await?; | ||
| // --- Instance B (same store, fresh in-memory state): send request --- | ||
| let ct_b = CancellationToken::new(); | ||
| let sm_b = Arc::new(LocalSessionManager::default()); | ||
| let (addr_b, srv_b) = spawn_server(Some(store.clone()), sm_b.clone(), &ct_b); | ||
| let resp = http | ||
| .post(format!("http://{addr_b}/mcp")) | ||
| .header("accept", "application/json, text/event-stream") | ||
| .header("content-type", "application/json") | ||
| .header("mcp-session-id", &original_session_id) | ||
| .body(r#"{"jsonrpc":"2.0","id":2,"method":"ping","params":{}}"#) | ||
| .send() | ||
| .await?; | ||
| assert_eq!( | ||
| resp.status(), | ||
| 200, | ||
| "instance B: request must succeed after transparent restore" | ||
| ); | ||
| // The session must be in instance B's memory under the ORIGINAL ID. | ||
| { | ||
| let sessions = sm_b.sessions.read().await; | ||
| let restored_id = sessions | ||
| .keys() | ||
| .next() | ||
| .expect("session should exist in instance B after restore"); | ||
| assert_eq!( | ||
| restored_id.as_ref(), | ||
| original_session_id.as_str(), | ||
| "restored session must keep the original session ID" | ||
| ); | ||
| } | ||
| ct_b.cancel(); | ||
| srv_b.await?; | ||
| } | ||
| Ok(()) | ||
| } |
| //! Integration tests for tool list change notifications. | ||
| #![cfg(all(feature = "client", not(feature = "local")))] | ||
| use std::sync::{ | ||
| Arc, | ||
| atomic::{AtomicUsize, Ordering}, | ||
| }; | ||
| use rmcp::{ | ||
| ClientHandler, RoleClient, RoleServer, ServerHandler, ServiceExt, | ||
| handler::server::{router::tool::ToolRoute, tool::ToolCallContext}, | ||
| model::{CallToolResult, ServerCapabilities, ServerInfo, Tool}, | ||
| service::{MaybeSendFuture, NotificationContext}, | ||
| }; | ||
| use tokio::sync::{Notify, RwLock}; | ||
| #[derive(Clone)] | ||
| struct TestToolServer { | ||
| router: Arc<RwLock<rmcp::handler::server::router::tool::ToolRouter<Self>>>, | ||
| trigger_disable: Arc<Notify>, | ||
| trigger_enable: Arc<Notify>, | ||
| } | ||
| impl TestToolServer { | ||
| fn new() -> Self { | ||
| let mut tool_router = rmcp::handler::server::router::tool::ToolRouter::<Self>::new(); | ||
| tool_router.add_route(ToolRoute::new_dyn( | ||
| Tool::new("tool_a", "Tool A", Arc::new(Default::default())), | ||
| |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), | ||
| )); | ||
| tool_router.add_route(ToolRoute::new_dyn( | ||
| Tool::new("tool_b", "Tool B", Arc::new(Default::default())), | ||
| |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), | ||
| )); | ||
| Self { | ||
| router: Arc::new(RwLock::new(tool_router)), | ||
| trigger_disable: Arc::new(Notify::new()), | ||
| trigger_enable: Arc::new(Notify::new()), | ||
| } | ||
| } | ||
| } | ||
| impl ServerHandler for TestToolServer { | ||
| fn get_info(&self) -> ServerInfo { | ||
| ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) | ||
| } | ||
| fn call_tool( | ||
| &self, | ||
| request: rmcp::model::CallToolRequestParams, | ||
| context: rmcp::service::RequestContext<RoleServer>, | ||
| ) -> impl std::future::Future<Output = Result<CallToolResult, rmcp::ErrorData>> + MaybeSendFuture + '_ | ||
| { | ||
| async move { | ||
| let router = self.router.read().await; | ||
| let tcc = ToolCallContext::new(self, request, context); | ||
| router.call(tcc).await | ||
| } | ||
| } | ||
| fn list_tools( | ||
| &self, | ||
| _request: Option<rmcp::model::PaginatedRequestParams>, | ||
| _context: rmcp::service::RequestContext<RoleServer>, | ||
| ) -> impl std::future::Future<Output = Result<rmcp::model::ListToolsResult, rmcp::ErrorData>> | ||
| + MaybeSendFuture | ||
| + '_ { | ||
| async move { | ||
| let router = self.router.read().await; | ||
| Ok(rmcp::model::ListToolsResult { | ||
| tools: router.list_all(), | ||
| ..Default::default() | ||
| }) | ||
| } | ||
| } | ||
| fn on_initialized( | ||
| &self, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl std::future::Future<Output = ()> + MaybeSendFuture + '_ { | ||
| let router = self.router.clone(); | ||
| let trigger_disable = self.trigger_disable.clone(); | ||
| let trigger_enable = self.trigger_enable.clone(); | ||
| let peer = context.peer.clone(); | ||
| async move { | ||
| router.write().await.bind_peer_notifier(&peer); | ||
| let router = router.clone(); | ||
| tokio::spawn(async move { | ||
| trigger_disable.notified().await; | ||
| { | ||
| let mut r = router.write().await; | ||
| r.disable_route("tool_a"); | ||
| } | ||
| trigger_enable.notified().await; | ||
| { | ||
| let mut r = router.write().await; | ||
| r.enable_route("tool_a"); | ||
| } | ||
| }); | ||
| } | ||
| } | ||
| } | ||
| #[derive(Clone)] | ||
| struct TestToolClient { | ||
| notification_count: Arc<AtomicUsize>, | ||
| notify: Arc<Notify>, | ||
| } | ||
| impl TestToolClient { | ||
| fn new() -> Self { | ||
| Self { | ||
| notification_count: Arc::new(AtomicUsize::new(0)), | ||
| notify: Arc::new(Notify::new()), | ||
| } | ||
| } | ||
| } | ||
| impl ClientHandler for TestToolClient { | ||
| fn on_tool_list_changed( | ||
| &self, | ||
| _context: NotificationContext<RoleClient>, | ||
| ) -> impl std::future::Future<Output = ()> + MaybeSendFuture + '_ { | ||
| self.notification_count.fetch_add(1, Ordering::SeqCst); | ||
| self.notify.notify_one(); | ||
| std::future::ready(()) | ||
| } | ||
| } | ||
| #[tokio::test] | ||
| async fn test_disable_enable_sends_tool_list_changed() { | ||
| let server = TestToolServer::new(); | ||
| let trigger_disable = server.trigger_disable.clone(); | ||
| let trigger_enable = server.trigger_enable.clone(); | ||
| let client = TestToolClient::new(); | ||
| let notification_count = client.notification_count.clone(); | ||
| let client_notify = client.notify.clone(); | ||
| let (server_transport, client_transport) = tokio::io::duplex(4096); | ||
| let server_handle = tokio::spawn(async move { server.serve(server_transport).await }); | ||
| let client_service = client.serve(client_transport).await.unwrap(); | ||
| let tools = client_service.peer().list_tools(None).await.unwrap(); | ||
| assert_eq!(tools.tools.len(), 2); | ||
| trigger_disable.notify_one(); | ||
| tokio::time::timeout(std::time::Duration::from_secs(5), client_notify.notified()) | ||
| .await | ||
| .expect("timed out waiting for tool_list_changed"); | ||
| assert_eq!(notification_count.load(Ordering::SeqCst), 1); | ||
| let tools = client_service.peer().list_tools(None).await.unwrap(); | ||
| assert_eq!(tools.tools.len(), 1); | ||
| assert_eq!(tools.tools[0].name, "tool_b"); | ||
| trigger_enable.notify_one(); | ||
| tokio::time::timeout(std::time::Duration::from_secs(5), client_notify.notified()) | ||
| .await | ||
| .expect("timed out waiting for tool_list_changed"); | ||
| assert_eq!(notification_count.load(Ordering::SeqCst), 2); | ||
| let tools = client_service.peer().list_tools(None).await.unwrap(); | ||
| assert_eq!(tools.tools.len(), 2); | ||
| client_service.cancel().await.unwrap(); | ||
| server_handle.abort(); | ||
| } |
| { | ||
| "git": { | ||
| "sha1": "020a38b6ad3d0f26487c464250a484fad2a06b0e" | ||
| "sha1": "014fb2e6cd9faddbe86ae30b5cc9adf84a62edb9" | ||
| }, | ||
| "path_in_vcs": "crates/rmcp" | ||
| } |
+20
-2
@@ -15,3 +15,3 @@ # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO | ||
| name = "rmcp" | ||
| version = "1.5.0" | ||
| version = "1.6.0" | ||
| build = "build.rs" | ||
@@ -354,2 +354,6 @@ autolib = false | ||
| [[test]] | ||
| name = "test_streamable_http_init_timeout" | ||
| path = "tests/test_streamable_http_init_timeout.rs" | ||
| [[test]] | ||
| name = "test_streamable_http_json_response" | ||
@@ -375,2 +379,12 @@ path = "tests/test_streamable_http_json_response.rs" | ||
| [[test]] | ||
| name = "test_streamable_http_session_store" | ||
| path = "tests/test_streamable_http_session_store.rs" | ||
| required-features = [ | ||
| "client", | ||
| "server", | ||
| "transport-streamable-http-client-reqwest", | ||
| "transport-streamable-http-server", | ||
| ] | ||
| [[test]] | ||
| name = "test_streamable_http_stale_session" | ||
@@ -408,2 +422,6 @@ path = "tests/test_streamable_http_stale_session.rs" | ||
| [[test]] | ||
| name = "test_tool_disable_notification" | ||
| path = "tests/test_tool_disable_notification.rs" | ||
| [[test]] | ||
| name = "test_tool_handler" | ||
@@ -537,3 +555,3 @@ path = "tests/test_tool_handler.rs" | ||
| [dependencies.rmcp-macros] | ||
| version = "1.5.0" | ||
| version = "1.6.0" | ||
| optional = true | ||
@@ -540,0 +558,0 @@ |
+19
-0
@@ -10,2 +10,21 @@ # Changelog | ||
| ## [1.6.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v1.5.0...rmcp-v1.6.0) - 2026-05-01 | ||
| ### Added | ||
| - *(http)* log Host/Origin rejections ([#826](https://github.com/modelcontextprotocol/rust-sdk/pull/826)) | ||
| - *(http)* add Origin header validation ([#823](https://github.com/modelcontextprotocol/rust-sdk/pull/823)) | ||
| - *(router)* support runtime disabling of tools ([#809](https://github.com/modelcontextprotocol/rust-sdk/pull/809)) | ||
| - optional session store (resumabillity support) ([#775](https://github.com/modelcontextprotocol/rust-sdk/pull/775)) | ||
| ### Fixed | ||
| - add init_timeout for streamable-http sessions ([#811](https://github.com/modelcontextprotocol/rust-sdk/pull/811)) | ||
| - *(http)* fall back to :authority for HTTP/2 ([#827](https://github.com/modelcontextprotocol/rust-sdk/pull/827)) | ||
| - *(docs)* use correct Parameters<T> syntax in tool examples ([#814](https://github.com/modelcontextprotocol/rust-sdk/pull/814)) | ||
| ### Other | ||
| - add systemprompt-template to Built with rmcp ([#820](https://github.com/modelcontextprotocol/rust-sdk/pull/820)) | ||
| ## [1.5.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v1.4.0...rmcp-v1.5.0) - 2026-04-16 | ||
@@ -12,0 +31,0 @@ |
@@ -9,3 +9,3 @@ use std::sync::Arc; | ||
| RoleServer, Service, | ||
| model::{ClientRequest, ListPromptsResult, ListToolsResult, ServerResult}, | ||
| model::{ClientNotification, ClientRequest, ListPromptsResult, ListToolsResult, ServerResult}, | ||
| service::NotificationContext, | ||
@@ -22,2 +22,3 @@ }; | ||
| pub service: Arc<S>, | ||
| peer_slot: Arc<std::sync::OnceLock<crate::service::Peer<RoleServer>>>, | ||
| } | ||
@@ -30,6 +31,10 @@ | ||
| pub fn new(service: S) -> Self { | ||
| let (notifier, peer_slot) = tool::ToolRouter::<S>::deferred_peer_notifier(); | ||
| let mut tool_router = tool::ToolRouter::new(); | ||
| tool_router.set_notifier(notifier); | ||
| Self { | ||
| tool_router: tool::ToolRouter::new(), | ||
| tool_router, | ||
| prompt_router: prompt::PromptRouter::new(), | ||
| service: Arc::new(service), | ||
| peer_slot, | ||
| } | ||
@@ -78,2 +83,8 @@ } | ||
| ) -> Result<(), crate::ErrorData> { | ||
| if matches!( | ||
| ¬ification, | ||
| ClientNotification::InitializedNotification(_) | ||
| ) { | ||
| let _ = self.peer_slot.set(context.peer.clone()); | ||
| } | ||
| self.service | ||
@@ -90,3 +101,6 @@ .handle_notification(notification, context) | ||
| ClientRequest::CallToolRequest(request) => { | ||
| if self.tool_router.has_route(request.params.name.as_ref()) | ||
| if self | ||
| .tool_router | ||
| .map | ||
| .contains_key(request.params.name.as_ref()) | ||
| || !self.tool_router.transparent_when_not_found | ||
@@ -142,4 +156,79 @@ { | ||
| fn get_info(&self) -> <RoleServer as crate::service::ServiceRole>::Info { | ||
| ServerHandler::get_info(&self.service) | ||
| let mut info = ServerHandler::get_info(&self.service); | ||
| info.capabilities | ||
| .tools | ||
| .get_or_insert_with(Default::default) | ||
| .list_changed = Some(true); | ||
| info | ||
| } | ||
| } | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use std::sync::Arc; | ||
| use super::*; | ||
| use crate::{ | ||
| model::{CallToolResult, ClientNotification, ServerNotification, Tool}, | ||
| service::{AtomicU32RequestIdProvider, Peer, PeerSinkMessage, RequestIdProvider}, | ||
| }; | ||
| struct DummyHandler; | ||
| impl ServerHandler for DummyHandler {} | ||
| async fn recv_notification( | ||
| rx: &mut tokio::sync::mpsc::Receiver<PeerSinkMessage<RoleServer>>, | ||
| ) -> ServerNotification { | ||
| let msg = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv()) | ||
| .await | ||
| .expect("timed out") | ||
| .expect("channel closed"); | ||
| match msg { | ||
| PeerSinkMessage::Notification { | ||
| notification, | ||
| responder, | ||
| } => { | ||
| let _ = responder.send(Ok(())); | ||
| notification | ||
| } | ||
| other => panic!("expected notification, got {other:?}"), | ||
| } | ||
| } | ||
| #[tokio::test] | ||
| async fn test_router_deferred_notifier_e2e() { | ||
| let mut router = Router::new(DummyHandler).with_tool(tool::ToolRoute::new_dyn( | ||
| Tool::new("my_tool", "test", Arc::new(Default::default())), | ||
| |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), | ||
| )); | ||
| let id_provider: Arc<dyn RequestIdProvider> = | ||
| Arc::new(AtomicU32RequestIdProvider::default()); | ||
| let (peer, mut rx) = Peer::<RoleServer>::new(id_provider, None); | ||
| let context = crate::service::NotificationContext { | ||
| peer: peer.clone(), | ||
| meta: Default::default(), | ||
| extensions: Default::default(), | ||
| }; | ||
| router | ||
| .handle_notification( | ||
| ClientNotification::InitializedNotification(Default::default()), | ||
| context, | ||
| ) | ||
| .await | ||
| .unwrap(); | ||
| router.tool_router.disable_route("my_tool"); | ||
| assert!(matches!( | ||
| recv_notification(&mut rx).await, | ||
| ServerNotification::ToolListChangedNotification(_) | ||
| )); | ||
| router.tool_router.enable_route("my_tool"); | ||
| assert!(matches!( | ||
| recv_notification(&mut rx).await, | ||
| ServerNotification::ToolListChangedNotification(_) | ||
| )); | ||
| } | ||
| } |
@@ -301,3 +301,2 @@ //! Tools for MCP servers. | ||
| } | ||
| #[derive(Debug)] | ||
| #[non_exhaustive] | ||
@@ -309,4 +308,22 @@ pub struct ToolRouter<S> { | ||
| pub transparent_when_not_found: bool, | ||
| disabled: std::collections::HashSet<Cow<'static, str>>, | ||
| notifier: Option<Arc<dyn Fn() + Send + Sync>>, | ||
| } | ||
| impl<S> std::fmt::Debug for ToolRouter<S> { | ||
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| f.debug_struct("ToolRouter") | ||
| .field("map", &self.map) | ||
| .field( | ||
| "transparent_when_not_found", | ||
| &self.transparent_when_not_found, | ||
| ) | ||
| .field("disabled", &self.disabled) | ||
| .field("notifier", &self.notifier.as_ref().map(|_| "...")) | ||
| .finish() | ||
| } | ||
| } | ||
| impl<S> Default for ToolRouter<S> { | ||
@@ -317,5 +334,8 @@ fn default() -> Self { | ||
| transparent_when_not_found: false, | ||
| disabled: std::collections::HashSet::new(), | ||
| notifier: None, | ||
| } | ||
| } | ||
| } | ||
| impl<S> Clone for ToolRouter<S> { | ||
@@ -326,2 +346,4 @@ fn clone(&self) -> Self { | ||
| transparent_when_not_found: self.transparent_when_not_found, | ||
| disabled: self.disabled.clone(), | ||
| notifier: self.notifier.clone(), | ||
| } | ||
@@ -336,3 +358,7 @@ } | ||
| fn into_iter(self) -> Self::IntoIter { | ||
| self.map.into_values() | ||
| let mut map = self.map; | ||
| for name in &self.disabled { | ||
| map.remove(name); | ||
| } | ||
| map.into_values() | ||
| } | ||
@@ -346,6 +372,3 @@ } | ||
| pub fn new() -> Self { | ||
| Self { | ||
| map: std::collections::HashMap::new(), | ||
| transparent_when_not_found: false, | ||
| } | ||
| Self::default() | ||
| } | ||
@@ -403,2 +426,3 @@ pub fn with_route<R, A>(mut self, route: R) -> Self | ||
| pub fn merge(&mut self, other: ToolRouter<S>) { | ||
| self.disabled.extend(other.disabled); | ||
| for item in other.map.into_values() { | ||
@@ -409,8 +433,113 @@ self.add_route(item); | ||
| /// Remove a tool route from the router. | ||
| /// | ||
| /// The disabled state is **preserved**: if the name was in the disabled | ||
| /// set, it stays there so that a future [`add_route`](Self::add_route) | ||
| /// or [`merge`](Self::merge) with the same name will inherit the | ||
| /// disabled state. To also clear the disabled marker, call | ||
| /// [`enable_route`](Self::enable_route) afterwards. | ||
| pub fn remove_route(&mut self, name: &str) { | ||
| self.map.remove(name); | ||
| } | ||
| /// Returns `true` if the tool is registered **and** not currently | ||
| /// disabled. | ||
| pub fn has_route(&self, name: &str) -> bool { | ||
| self.map.contains_key(name) | ||
| self.map.contains_key(name) && !self.disabled.contains(name) | ||
| } | ||
| /// Disable a tool by name. Hidden from `list_all`, `get`, rejected by | ||
| /// `call`. Re-enable with [`enable_route`](Self::enable_route). | ||
| /// | ||
| /// Returns `true` if the name was newly added to the disabled set. | ||
| /// The name is recorded even if no matching route exists yet, so routes | ||
| /// added later will inherit the disabled state. | ||
| pub fn disable_route(&mut self, name: impl Into<Cow<'static, str>>) -> bool { | ||
| let name = name.into(); | ||
| let was_visible = self.map.contains_key(&name) && !self.disabled.contains(&name); | ||
| if was_visible { | ||
| self.notify_if_visible(&name); | ||
| } | ||
| self.disabled.insert(name) | ||
| } | ||
| /// Re-enable a previously disabled tool. Returns `true` if the name | ||
| /// was in the disabled set. | ||
| pub fn enable_route(&mut self, name: &str) -> bool { | ||
| let removed = self.disabled.remove(name); | ||
| if removed { | ||
| self.notify_if_visible(name); | ||
| } | ||
| removed | ||
| } | ||
| /// Returns `true` if the tool exists in the router **and** is currently | ||
| /// disabled. Returns `false` if the tool does not exist or if the name | ||
| /// was pre-disabled without a matching route. | ||
| pub fn is_disabled(&self, name: &str) -> bool { | ||
| self.map.contains_key(name) && self.disabled.contains(name) | ||
| } | ||
| /// Builder-style variant of [`disable_route`](Self::disable_route). | ||
| /// | ||
| /// The name is recorded even if no matching route has been added yet, | ||
| /// so it can be called before [`with_route`](Self::with_route) in a | ||
| /// builder chain. | ||
| pub fn with_disabled(mut self, name: impl Into<Cow<'static, str>>) -> Self { | ||
| self.disabled.insert(name.into()); | ||
| self | ||
| } | ||
| /// Install a callback invoked when the visible tool list changes. | ||
| pub fn set_notifier(&mut self, f: impl Fn() + Send + Sync + 'static) { | ||
| self.notifier = Some(Arc::new(f)); | ||
| } | ||
| pub fn clear_notifier(&mut self) { | ||
| self.notifier = None; | ||
| } | ||
| /// Install a notifier that sends `notifications/tools/list_changed` | ||
| /// via the given peer. | ||
| pub fn bind_peer_notifier(&mut self, peer: &crate::service::Peer<crate::RoleServer>) { | ||
| let peer = peer.clone(); | ||
| self.set_notifier(move || { | ||
| let peer = peer.clone(); | ||
| tokio::spawn(async move { | ||
| if let Err(e) = peer.notify_tool_list_changed().await { | ||
| tracing::warn!("failed to send tools/list_changed notification: {e}"); | ||
| } | ||
| }); | ||
| }); | ||
| } | ||
| /// Deferred notifier: no-op until the peer slot is filled. | ||
| pub(crate) fn deferred_peer_notifier() -> ( | ||
| impl Fn() + Send + Sync + 'static, | ||
| Arc<std::sync::OnceLock<crate::service::Peer<crate::RoleServer>>>, | ||
| ) { | ||
| let peer_slot = | ||
| Arc::new(std::sync::OnceLock::<crate::service::Peer<crate::RoleServer>>::new()); | ||
| let slot_clone = peer_slot.clone(); | ||
| let notifier = move || { | ||
| if let Some(peer) = slot_clone.get() { | ||
| let peer = peer.clone(); | ||
| tokio::spawn(async move { | ||
| if let Err(e) = peer.notify_tool_list_changed().await { | ||
| tracing::warn!("failed to send tools/list_changed notification: {e}"); | ||
| } | ||
| }); | ||
| } | ||
| }; | ||
| (notifier, peer_slot) | ||
| } | ||
| fn notify_if_visible(&self, name: &str) { | ||
| if self.map.contains_key(name) { | ||
| if let Some(notifier) = &self.notifier { | ||
| notifier(); | ||
| } | ||
| } | ||
| } | ||
| pub async fn call( | ||
@@ -420,5 +549,9 @@ &self, | ||
| ) -> Result<CallToolResult, crate::ErrorData> { | ||
| let name = context.name(); | ||
| if self.disabled.contains(name) { | ||
| return Err(crate::ErrorData::invalid_params("tool not found", None)); | ||
| } | ||
| let item = self | ||
| .map | ||
| .get(context.name()) | ||
| .get(name) | ||
| .ok_or_else(|| crate::ErrorData::invalid_params("tool not found", None))?; | ||
@@ -432,3 +565,8 @@ | ||
| pub fn list_all(&self) -> Vec<crate::model::Tool> { | ||
| let mut tools: Vec<_> = self.map.values().map(|item| item.attr.clone()).collect(); | ||
| let mut tools: Vec<_> = self | ||
| .map | ||
| .values() | ||
| .filter(|item| !self.disabled.contains(&item.attr.name)) | ||
| .map(|item| item.attr.clone()) | ||
| .collect(); | ||
| tools.sort_by(|a, b| a.name.cmp(&b.name)); | ||
@@ -440,4 +578,8 @@ tools | ||
| /// | ||
| /// Returns the tool if found, or `None` if no tool with the given name exists. | ||
| /// Returns the tool if found and enabled, or `None` if the tool does not | ||
| /// exist or is disabled. | ||
| pub fn get(&self, name: &str) -> Option<&crate::model::Tool> { | ||
| if self.disabled.contains(name) { | ||
| return None; | ||
| } | ||
| self.map.get(name).map(|r| &r.attr) | ||
@@ -467,1 +609,47 @@ } | ||
| } | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use std::sync::Arc; | ||
| use super::*; | ||
| use crate::{ | ||
| RoleServer, | ||
| model::{CallToolRequestParams, ErrorCode, NumberOrString}, | ||
| service::{AtomicU32RequestIdProvider, Peer, RequestContext}, | ||
| }; | ||
| struct DummyService; | ||
| impl crate::handler::server::ServerHandler for DummyService {} | ||
| #[tokio::test] | ||
| async fn test_call_disabled_tool_returns_error() { | ||
| let service = DummyService; | ||
| let mut router = ToolRouter::new().with_route(ToolRoute::new_dyn( | ||
| crate::model::Tool::new("test_tool", "a test tool", Arc::new(Default::default())), | ||
| |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), | ||
| )); | ||
| router.disable_route("test_tool"); | ||
| let id_provider: Arc<dyn crate::service::RequestIdProvider> = | ||
| Arc::new(AtomicU32RequestIdProvider::default()); | ||
| let (peer, _rx) = Peer::<RoleServer>::new(id_provider, None); | ||
| let ctx = crate::handler::server::tool::ToolCallContext::new( | ||
| &service, | ||
| CallToolRequestParams { | ||
| meta: None, | ||
| name: Cow::Borrowed("test_tool"), | ||
| arguments: None, | ||
| task: None, | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(1), peer), | ||
| ); | ||
| let err = router | ||
| .call(ctx) | ||
| .await | ||
| .expect_err("disabled tool should reject"); | ||
| assert_eq!(err.code, ErrorCode::INVALID_PARAMS); | ||
| assert_eq!(err.message, "tool not found"); | ||
| } | ||
| } |
| pub mod session; | ||
| #[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))] | ||
| pub mod tower; | ||
| pub use session::{SessionId, SessionManager}; | ||
| pub use session::{RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker}; | ||
| #[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))] | ||
| pub use tower::{StreamableHttpServerConfig, StreamableHttpService}; |
@@ -33,3 +33,38 @@ //! Session management for the Streamable HTTP transport. | ||
| pub mod never; | ||
| pub mod store; | ||
| pub use store::{SessionState, SessionStore, SessionStoreError}; | ||
| /// Extension marker inserted into the `initialize` request extensions during a | ||
| /// session restore replay. Handlers can check for its presence to distinguish a | ||
| /// cross-instance restore from a genuine client-initiated `initialize` request. | ||
| /// | ||
| /// ```rust,ignore | ||
| /// if req.extensions().get::<SessionRestoreMarker>().is_some() { | ||
| /// // this is a restore replay, not a fresh client connection | ||
| /// } | ||
| /// ``` | ||
| #[non_exhaustive] | ||
| #[derive(Debug, Clone)] | ||
| pub struct SessionRestoreMarker { | ||
| pub id: SessionId, | ||
| } | ||
| /// The outcome of a [`SessionManager::restore_session`] call. | ||
| #[non_exhaustive] | ||
| #[derive(Debug)] | ||
| pub enum RestoreOutcome<T> { | ||
| /// The session was just re-created from external state; the caller must | ||
| /// spawn an MCP handler against the returned transport and replay the | ||
| /// `initialize` handshake. | ||
| Restored(T), | ||
| /// The session was already present in memory (e.g. a concurrent request | ||
| /// already restored it). The caller should proceed as if `has_session` | ||
| /// had returned `true` — no further action is required. | ||
| AlreadyPresent, | ||
| /// This session manager does not support external-store restore. | ||
| /// The caller should fall through to the normal 404 response. | ||
| NotSupported, | ||
| } | ||
| /// Controls how MCP sessions are created, validated, and closed. | ||
@@ -102,2 +137,20 @@ /// | ||
| > + Send; | ||
| /// Attempt to restore a previously-known session from external state, | ||
| /// creating a fresh in-memory session worker with the given `id`. | ||
| /// | ||
| /// See [`RestoreOutcome`] for the three possible results: | ||
| /// - [`RestoreOutcome::Restored`] — session re-created; caller must spawn | ||
| /// an MCP handler and replay the `initialize` handshake. | ||
| /// - [`RestoreOutcome::AlreadyPresent`] — session is already in memory | ||
| /// (e.g. a concurrent request restored it first); caller proceeds | ||
| /// normally. | ||
| /// - [`RestoreOutcome::NotSupported`] (default) — this session manager | ||
| /// does not support external-store restore; caller returns 404. | ||
| fn restore_session( | ||
| &self, | ||
| _id: SessionId, | ||
| ) -> impl Future<Output = Result<RestoreOutcome<Self::Transport>, Self::Error>> + Send { | ||
| futures::future::ready(Ok(RestoreOutcome::NotSupported)) | ||
| } | ||
| } |
@@ -139,2 +139,16 @@ use std::{ | ||
| } | ||
| async fn restore_session( | ||
| &self, | ||
| id: SessionId, | ||
| ) -> Result<RestoreOutcome<Self::Transport>, Self::Error> { | ||
| let mut sessions = self.sessions.write().await; | ||
| if sessions.contains_key(&id) { | ||
| // A concurrent request already restored this session. | ||
| return Ok(RestoreOutcome::AlreadyPresent); | ||
| } | ||
| let (handle, worker) = create_local_session(id.clone(), self.session_config.clone()); | ||
| sessions.insert(id, handle); | ||
| Ok(RestoreOutcome::Restored(WorkerTransport::spawn(worker))) | ||
| } | ||
| } | ||
@@ -192,3 +206,3 @@ | ||
| use super::{ServerSseMessage, SessionManager}; | ||
| use super::{RestoreOutcome, ServerSseMessage, SessionManager}; | ||
@@ -921,2 +935,4 @@ struct CachedTx { | ||
| KeepAliveTimeout(Duration), | ||
| #[error("init timeout after {}ms", _0.as_millis())] | ||
| InitTimeout(Duration), | ||
| #[error("Transport closed")] | ||
@@ -951,9 +967,20 @@ TransportClosed, | ||
| } | ||
| // waiting for initialize request | ||
| let evt = self.event_rx.recv().await.ok_or_else(|| { | ||
| WorkerQuitReason::fatal( | ||
| LocalSessionWorkerError::TransportTerminated, | ||
| "get initialize request", | ||
| ) | ||
| })?; | ||
| let init_timeout = self.session_config.init_timeout.unwrap_or(Duration::MAX); | ||
| let evt = tokio::select! { | ||
| evt = self.event_rx.recv() => evt.ok_or_else(|| { | ||
| WorkerQuitReason::fatal( | ||
| LocalSessionWorkerError::TransportTerminated, | ||
| "get initialize request", | ||
| ) | ||
| })?, | ||
| _ = context.cancellation_token.cancelled() => { | ||
| return Err(WorkerQuitReason::Cancelled); | ||
| } | ||
| _ = tokio::time::sleep(init_timeout) => { | ||
| return Err(WorkerQuitReason::fatal( | ||
| LocalSessionWorkerError::InitTimeout(init_timeout), | ||
| "waiting for initialize request", | ||
| )); | ||
| } | ||
| }; | ||
| let SessionEvent::InitializeRequest { request, responder } = evt else { | ||
@@ -1115,2 +1142,6 @@ return Err(WorkerQuitReason::fatal( | ||
| pub completed_cache_ttl: Duration, | ||
| /// Maximum duration to wait for the `initialize` request after session | ||
| /// creation. If not received within this window, the session is | ||
| /// terminated. Default is 60 seconds. Set to `None` to disable. | ||
| pub init_timeout: Option<Duration>, | ||
| } | ||
@@ -1123,2 +1154,3 @@ | ||
| pub const DEFAULT_COMPLETED_CACHE_TTL: Duration = Duration::from_secs(60); | ||
| pub const DEFAULT_INIT_TIMEOUT: Duration = Duration::from_secs(60); | ||
| } | ||
@@ -1133,2 +1165,3 @@ | ||
| completed_cache_ttl: Self::DEFAULT_COMPLETED_CACHE_TTL, | ||
| init_timeout: Some(Self::DEFAULT_INIT_TIMEOUT), | ||
| } | ||
@@ -1135,0 +1168,0 @@ } |
@@ -1,2 +0,2 @@ | ||
| use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; | ||
| use std::{collections::HashMap, convert::Infallible, fmt::Display, sync::Arc, time::Duration}; | ||
@@ -11,6 +11,11 @@ use bytes::Bytes; | ||
| use super::session::SessionManager; | ||
| use super::session::{ | ||
| RestoreOutcome, SessionId, SessionManager, SessionRestoreMarker, SessionState, SessionStore, | ||
| }; | ||
| use crate::{ | ||
| RoleServer, | ||
| model::{ClientJsonRpcMessage, ClientRequest, GetExtensions, ProtocolVersion}, | ||
| model::{ | ||
| ClientJsonRpcMessage, ClientNotification, ClientRequest, GetExtensions, InitializeRequest, | ||
| InitializedNotification, ProtocolVersion, | ||
| }, | ||
| serve_server, | ||
@@ -63,4 +68,41 @@ service::serve_directly, | ||
| pub allowed_hosts: Vec<String>, | ||
| /// Allowed browser origins for inbound `Origin` validation. | ||
| /// | ||
| /// Defaults to an empty list, which disables Origin validation. When | ||
| /// non-empty, requests carrying an `Origin` header must match per RFC 6454 | ||
| /// `(scheme, host, port)`; missing-`Origin` requests still pass. Entries | ||
| /// must include a scheme; `"null"` matches the browser's `Origin: null`. | ||
| /// examples: | ||
| /// allowed_origins = ["https://app.example.com", "http://localhost:8080"] | ||
| pub allowed_origins: Vec<String>, | ||
| /// Optional external session store for cross-instance recovery. | ||
| /// | ||
| /// When set, [`SessionState`] (the client's `initialize` parameters) is | ||
| /// persisted after a successful handshake and deleted when the session | ||
| /// closes. On any subsequent request that arrives at an instance with no | ||
| /// in-memory session, the store is consulted: if an entry is found the | ||
| /// session is transparently restored so the client does not need to | ||
| /// re-initialize. | ||
| /// | ||
| /// # Example | ||
| /// ```rust,ignore | ||
| /// use std::sync::Arc; | ||
| /// use rmcp::transport::streamable_http_server::{ | ||
| /// StreamableHttpServerConfig, session::SessionStore, | ||
| /// }; | ||
| /// | ||
| /// let config = StreamableHttpServerConfig { | ||
| /// session_store: Some(Arc::new(MyRedisStore::new())), | ||
| /// ..Default::default() | ||
| /// }; | ||
| /// ``` | ||
| pub session_store: Option<Arc<dyn SessionStore>>, | ||
| } | ||
| impl std::fmt::Debug for dyn SessionStore { | ||
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||
| f.write_str("<SessionStore>") | ||
| } | ||
| } | ||
| impl Default for StreamableHttpServerConfig { | ||
@@ -75,2 +117,4 @@ fn default() -> Self { | ||
| allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()], | ||
| allowed_origins: vec![], | ||
| session_store: None, | ||
| } | ||
@@ -93,2 +137,14 @@ } | ||
| } | ||
| pub fn with_allowed_origins( | ||
| mut self, | ||
| allowed_origins: impl IntoIterator<Item = impl Into<String>>, | ||
| ) -> Self { | ||
| self.allowed_origins = allowed_origins.into_iter().map(Into::into).collect(); | ||
| self | ||
| } | ||
| /// Disable Origin validation, reverting to the default ignore-Origin behavior. | ||
| pub fn disable_allowed_origins(mut self) -> Self { | ||
| self.allowed_origins.clear(); | ||
| self | ||
| } | ||
| pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self { | ||
@@ -216,2 +272,55 @@ self.sse_keep_alive = duration; | ||
| #[derive(Debug, Clone, PartialEq, Eq)] | ||
| enum NormalizedOrigin { | ||
| Null, | ||
| Tuple { | ||
| scheme: String, | ||
| host: String, | ||
| port: Option<u16>, | ||
| }, | ||
| } | ||
| fn parse_origin_value(value: &str) -> Option<NormalizedOrigin> { | ||
| let value = value.trim(); | ||
| if value.is_empty() { | ||
| return None; | ||
| } | ||
| if value.eq_ignore_ascii_case("null") { | ||
| return Some(NormalizedOrigin::Null); | ||
| } | ||
| let uri = http::Uri::try_from(value).ok()?; | ||
| let scheme = uri.scheme_str()?.to_ascii_lowercase(); | ||
| let authority = uri.authority()?; | ||
| Some(NormalizedOrigin::Tuple { | ||
| scheme, | ||
| host: normalize_host(authority.host()), | ||
| port: authority.port_u16(), | ||
| }) | ||
| } | ||
| fn origin_is_allowed(origin: &NormalizedOrigin, allowed_origins: &[String]) -> bool { | ||
| if allowed_origins.is_empty() { | ||
| return true; | ||
| } | ||
| allowed_origins | ||
| .iter() | ||
| .filter_map(|raw| parse_origin_value(raw)) | ||
| .any(|allowed| match (&allowed, origin) { | ||
| (NormalizedOrigin::Null, NormalizedOrigin::Null) => true, | ||
| ( | ||
| NormalizedOrigin::Tuple { | ||
| scheme: a_scheme, | ||
| host: a_host, | ||
| port: a_port, | ||
| }, | ||
| NormalizedOrigin::Tuple { | ||
| scheme: o_scheme, | ||
| host: o_host, | ||
| port: o_port, | ||
| }, | ||
| ) => a_scheme == o_scheme && a_host == o_host && (a_port.is_none() || a_port == o_port), | ||
| _ => false, | ||
| }) | ||
| } | ||
| fn bad_request_response(message: &str) -> BoxResponse { | ||
@@ -227,12 +336,29 @@ let body = Full::from(message.to_string()).boxed(); | ||
| fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> { | ||
| let Some(host) = headers.get(http::header::HOST) else { | ||
| return Err(bad_request_response("Bad Request: missing Host header")); | ||
| }; | ||
| let host = host | ||
| .to_str() | ||
| .map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?; | ||
| let authority = http::uri::Authority::try_from(host) | ||
| .map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?; | ||
| fn parse_host_header( | ||
| uri: &http::Uri, | ||
| headers: &HeaderMap, | ||
| ) -> Result<NormalizedAuthority, BoxResponse> { | ||
| if let Some(host) = headers.get(http::header::HOST) { | ||
| let host_str = host | ||
| .to_str() | ||
| .inspect_err(|_| { | ||
| tracing::warn!(host = ?host, "rejected request with non-UTF-8 Host header"); | ||
| }) | ||
| .map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?; | ||
| let authority = http::uri::Authority::try_from(host_str) | ||
| .inspect_err(|_| { | ||
| tracing::warn!( | ||
| host = host_str, | ||
| "rejected request with malformed Host header" | ||
| ); | ||
| }) | ||
| .map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?; | ||
| return Ok(normalize_authority(authority.host(), authority.port_u16())); | ||
| } | ||
| // HTTP/2 carries the host in `:authority`; middleware such as | ||
| // `axum::Router::nest` can drop the `Host` header hyper synthesizes from it. | ||
| let authority = uri.authority().ok_or_else(|| { | ||
| tracing::warn!("rejected request with missing Host header and no :authority"); | ||
| bad_request_response("Bad Request: missing Host header") | ||
| })?; | ||
| Ok(normalize_authority(authority.host(), authority.port_u16())) | ||
@@ -242,10 +368,50 @@ } | ||
| fn validate_dns_rebinding_headers( | ||
| uri: &http::Uri, | ||
| headers: &HeaderMap, | ||
| config: &StreamableHttpServerConfig, | ||
| ) -> Result<(), BoxResponse> { | ||
| let host = parse_host_header(headers)?; | ||
| let host = parse_host_header(uri, headers)?; | ||
| if !host_is_allowed(&host, &config.allowed_hosts) { | ||
| tracing::warn!( | ||
| host = ?host, | ||
| "rejected request with disallowed Host header (possible DNS rebinding attempt)", | ||
| ); | ||
| return Err(forbidden_response("Forbidden: Host header is not allowed")); | ||
| } | ||
| validate_origin_header(headers, &config.allowed_origins)?; | ||
| Ok(()) | ||
| } | ||
| fn validate_origin_header( | ||
| headers: &HeaderMap, | ||
| allowed_origins: &[String], | ||
| ) -> Result<(), BoxResponse> { | ||
| if allowed_origins.is_empty() { | ||
| return Ok(()); | ||
| } | ||
| let Some(origin_header) = headers.get(http::header::ORIGIN) else { | ||
| return Ok(()); | ||
| }; | ||
| let origin_str = origin_header | ||
| .to_str() | ||
| .inspect_err(|_| { | ||
| tracing::warn!(origin = ?origin_header, "rejected request with non-UTF-8 Origin header"); | ||
| }) | ||
| .map_err(|_| bad_request_response("Bad Request: Invalid Origin header encoding"))?; | ||
| let origin = parse_origin_value(origin_str).ok_or_else(|| { | ||
| tracing::warn!( | ||
| origin = origin_str, | ||
| "rejected request with malformed Origin header", | ||
| ); | ||
| bad_request_response("Bad Request: Invalid Origin header") | ||
| })?; | ||
| if !origin_is_allowed(&origin, allowed_origins) { | ||
| tracing::warn!( | ||
| origin = ?origin, | ||
| "rejected request with disallowed Origin header (possible cross-origin attack)", | ||
| ); | ||
| return Err(forbidden_response( | ||
| "Forbidden: Origin header is not allowed", | ||
| )); | ||
| } | ||
| Ok(()) | ||
@@ -341,2 +507,9 @@ } | ||
| service_factory: Arc<dyn Fn() -> Result<S, std::io::Error> + Send + Sync>, | ||
| /// Tracks in-progress session restores so that concurrent requests for the | ||
| /// same unknown session ID wait for the first restore to complete rather | ||
| /// than racing to replay the initialize handshake. `None` when no external | ||
| /// session store is configured (avoids allocating the map). | ||
| pending_restores: Option< | ||
| Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>, | ||
| >, | ||
| } | ||
@@ -350,2 +523,3 @@ | ||
| service_factory: self.service_factory.clone(), | ||
| pending_restores: self.pending_restores.clone(), | ||
| } | ||
@@ -381,2 +555,31 @@ } | ||
| /// Guard used inside [`StreamableHttpService::try_restore_from_store`]. | ||
| /// | ||
| /// Ensures the `pending_restores` map entry is always cleaned up — even when | ||
| /// the future is cancelled mid-await. | ||
| /// | ||
| /// `result` defaults to `false` (failure / cancellation). Only the success path | ||
| /// needs to set it to `true` before returning. | ||
| struct PendingRestoreGuard { | ||
| pending_restores: | ||
| Arc<tokio::sync::RwLock<HashMap<SessionId, tokio::sync::watch::Sender<Option<bool>>>>>, | ||
| session_id: SessionId, | ||
| watch_tx: tokio::sync::watch::Sender<Option<bool>>, | ||
| /// The value that will be broadcast to waiting tasks on drop. | ||
| result: bool, | ||
| } | ||
| impl Drop for PendingRestoreGuard { | ||
| fn drop(&mut self) { | ||
| // `send` is synchronous — unblocks waiters immediately, no lock needed. | ||
| let _ = self.watch_tx.send(Some(self.result)); | ||
| // Remove the map entry asynchronously (requires the async write lock). | ||
| let pending_restores = self.pending_restores.clone(); | ||
| let session_id = self.session_id.clone(); | ||
| tokio::spawn(async move { | ||
| pending_restores.write().await.remove(&session_id); | ||
| }); | ||
| } | ||
| } | ||
| impl<S, M> StreamableHttpService<S, M> | ||
@@ -392,2 +595,8 @@ where | ||
| ) -> Self { | ||
| let pending_restores = config.session_store.is_some().then(|| { | ||
| Arc::new(tokio::sync::RwLock::new(HashMap::< | ||
| SessionId, | ||
| tokio::sync::watch::Sender<Option<bool>>, | ||
| >::new())) | ||
| }); | ||
| Self { | ||
@@ -397,2 +606,3 @@ config, | ||
| service_factory: Arc::new(service_factory), | ||
| pending_restores, | ||
| } | ||
@@ -403,2 +613,213 @@ } | ||
| } | ||
| /// Spawn a task that runs `serve_server` for the given session, waits for | ||
| /// it to finish, and then calls `close_session`. | ||
| /// | ||
| /// `init_done_tx`: when `Some`, the sender is fired after `serve_server` | ||
| /// returns successfully, signalling to the caller that the MCP handshake | ||
| /// is complete. Used by `try_restore_from_store` to synchronise with the | ||
| /// restore `initialize` replay; `handle_post` passes `None`. | ||
| fn spawn_session_worker( | ||
| session_manager: Arc<M>, | ||
| session_id: SessionId, | ||
| service: S, | ||
| transport: M::Transport, | ||
| init_done_tx: Option<tokio::sync::oneshot::Sender<()>>, | ||
| ) where | ||
| S: crate::Service<RoleServer> + Send + 'static, | ||
| M: SessionManager, | ||
| { | ||
| tokio::spawn(async move { | ||
| let svc = | ||
| serve_server::<S, M::Transport, _, TransportAdapterIdentity>(service, transport) | ||
| .await; | ||
| match svc { | ||
| Ok(svc) => { | ||
| if let Some(tx) = init_done_tx { | ||
| let _ = tx.send(()); | ||
| } | ||
| let _ = svc.waiting().await; | ||
| } | ||
| Err(e) => { | ||
| tracing::error!("Failed to serve session: {e}"); | ||
| // Dropping init_done_tx (if Some) signals failure to the caller. | ||
| } | ||
| } | ||
| let _ = session_manager | ||
| .close_session(&session_id) | ||
| .await | ||
| .inspect_err(|e| { | ||
| tracing::error!("Failed to close session {session_id}: {e}"); | ||
| }); | ||
| }); | ||
| } | ||
| /// Attempt to restore a session from the external store. | ||
| /// | ||
| /// Returns `true` when the session is available and ready to serve the | ||
| /// current request (either just restored or already in memory). Returns | ||
| /// `false` when no store is configured or the session ID is unknown. | ||
| /// | ||
| /// Concurrent requests for the same unknown session ID are serialized: the | ||
| /// first caller performs the full restore and handshake replay while others | ||
| /// subscribe to a `watch` channel and wait, avoiding duplicate handshakes. | ||
| async fn try_restore_from_store( | ||
| &self, | ||
| session_id: &SessionId, | ||
| parts: &http::request::Parts, | ||
| ) -> Result<bool, std::io::Error> | ||
| where | ||
| S: crate::Service<RoleServer> + Send + 'static, | ||
| M: SessionManager, | ||
| { | ||
| // Both fields are Some iff a session store is configured. | ||
| let (Some(pending_restores), Some(store)) = | ||
| (&self.pending_restores, &self.config.session_store) | ||
| else { | ||
| return Ok(false); | ||
| }; | ||
| // Serialize concurrent restores for the same session ID. | ||
| // Write-lock once: if another task is already restoring, subscribe and wait; | ||
| // otherwise, register ourselves as the restoring task. | ||
| // Channel value: None = in progress, Some(true) = restored, Some(false) = not found/failed. | ||
| let (watch_tx, _watch_rx) = tokio::sync::watch::channel(None::<bool>); | ||
| { | ||
| let mut pending = pending_restores.write().await; | ||
| if let Some(tx) = pending.get(session_id) { | ||
| let mut rx = tx.subscribe(); | ||
| drop(pending); | ||
| // Wait for the restore to finish, then propagate the outcome. | ||
| let result = rx | ||
| .wait_for(|r| r.is_some()) | ||
| .await | ||
| .map(|r| r.unwrap_or(false)) | ||
| .unwrap_or(false); | ||
| return Ok(result); | ||
| } | ||
| pending.insert(session_id.clone(), watch_tx.clone()); | ||
| } | ||
| // Guard: signals waiters and cleans up the map entry on drop | ||
| let mut guard = PendingRestoreGuard { | ||
| pending_restores: pending_restores.clone(), | ||
| session_id: session_id.clone(), | ||
| watch_tx: watch_tx.clone(), | ||
| result: false, | ||
| }; | ||
| // --- Step 3: load from external store --- | ||
| let state = match store.load(session_id.as_ref()).await { | ||
| Ok(Some(s)) => s, | ||
| Ok(None) => { | ||
| return Ok(false); | ||
| } | ||
| Err(e) => { | ||
| tracing::error!( | ||
| session_id = session_id.as_ref(), | ||
| error = %e, | ||
| "session store load failed during restore" | ||
| ); | ||
| return Err(std::io::Error::other(e)); | ||
| } | ||
| }; | ||
| // --- Step 4: ask the session manager to allocate an in-memory worker --- | ||
| let transport = match self | ||
| .session_manager | ||
| .restore_session(session_id.clone()) | ||
| .await | ||
| .map_err(|e| std::io::Error::other(e.to_string())) | ||
| { | ||
| Ok(RestoreOutcome::Restored(t)) => t, | ||
| Ok(RestoreOutcome::AlreadyPresent) => { | ||
| // Invariant violation: pending_restores ensures only one task can call | ||
| // restore_session per session ID, so AlreadyPresent is impossible here. | ||
| return Err(std::io::Error::other( | ||
| "restore_session returned AlreadyPresent unexpectedly; session manager might have modified the session store outside of the restore_session API", | ||
| )); | ||
| } | ||
| Ok(RestoreOutcome::NotSupported) => { | ||
| return Ok(false); | ||
| } | ||
| Err(e) => { | ||
| return Err(e); | ||
| } | ||
| }; | ||
| // --- Step 5: replay the MCP initialize handshake --- | ||
| let service = match self.get_service() { | ||
| Ok(s) => s, | ||
| Err(e) => { | ||
| return Err(e); | ||
| } | ||
| }; | ||
| // `serve_server` requires both the `initialize` request and the | ||
| // `notifications/initialized` notification before transitioning to | ||
| // the running state — we must send both before returning. | ||
| let mut restore_init = ClientJsonRpcMessage::request( | ||
| ClientRequest::InitializeRequest(InitializeRequest { | ||
| params: state.initialize_params, | ||
| ..Default::default() | ||
| }), | ||
| crate::model::NumberOrString::Number(0), | ||
| ); | ||
| restore_init.insert_extension(parts.clone()); | ||
| restore_init.insert_extension(SessionRestoreMarker { | ||
| id: session_id.clone(), | ||
| }); | ||
| let mut restore_initialized = ClientJsonRpcMessage::notification( | ||
| ClientNotification::InitializedNotification(InitializedNotification { | ||
| ..Default::default() | ||
| }), | ||
| ); | ||
| restore_initialized.insert_extension(parts.clone()); | ||
| restore_initialized.insert_extension(SessionRestoreMarker { | ||
| id: session_id.clone(), | ||
| }); | ||
| // Signal from the spawned task once serve_server finishes initialising. | ||
| let (init_done_tx, init_done_rx) = tokio::sync::oneshot::channel::<()>(); | ||
| Self::spawn_session_worker( | ||
| self.session_manager.clone(), | ||
| session_id.clone(), | ||
| service, | ||
| transport, | ||
| Some(init_done_tx), | ||
| ); | ||
| if let Err(e) = self | ||
| .session_manager | ||
| .initialize_session(session_id, restore_init) | ||
| .await | ||
| .map_err(|e| std::io::Error::other(e.to_string())) | ||
| { | ||
| return Err(e); | ||
| } | ||
| if let Err(e) = self | ||
| .session_manager | ||
| .accept_message(session_id, restore_initialized) | ||
| .await | ||
| .map_err(|e| std::io::Error::other(e.to_string())) | ||
| { | ||
| return Err(e); | ||
| } | ||
| if init_done_rx.await.is_err() { | ||
| return Err(std::io::Error::other( | ||
| "serve_server initialization failed during restore", | ||
| )); | ||
| } | ||
| // Restore complete — wake any waiting concurrent requests. | ||
| guard.result = true; | ||
| tracing::debug!( | ||
| session_id = session_id.as_ref(), | ||
| "session restored from external store" | ||
| ); | ||
| Ok(true) | ||
| } | ||
| pub async fn handle<B>(&self, request: Request<B>) -> Response<BoxBody<Bytes, Infallible>> | ||
@@ -409,3 +830,5 @@ where | ||
| { | ||
| if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) { | ||
| if let Err(response) = | ||
| validate_dns_rebinding_headers(request.uri(), request.headers(), &self.config) | ||
| { | ||
| return response; | ||
@@ -479,14 +902,22 @@ } | ||
| .map_err(internal_error_response("check session"))?; | ||
| let (parts, _) = request.into_parts(); | ||
| if !has_session { | ||
| // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions | ||
| return Ok(Response::builder() | ||
| .status(http::StatusCode::NOT_FOUND) | ||
| .body(Full::new(Bytes::from("Not Found: Session not found")).boxed()) | ||
| .expect("valid response")); | ||
| // Attempt transparent cross-instance restore from external store. | ||
| let restored = self | ||
| .try_restore_from_store(&session_id, &parts) | ||
| .await | ||
| .map_err(internal_error_response("restore session"))?; | ||
| if !restored { | ||
| // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions | ||
| return Ok(Response::builder() | ||
| .status(http::StatusCode::NOT_FOUND) | ||
| .body(Full::new(Bytes::from("Not Found: Session not found")).boxed()) | ||
| .expect("valid response")); | ||
| } | ||
| } | ||
| // Validate MCP-Protocol-Version header (per 2025-06-18 spec) | ||
| validate_protocol_version_header(request.headers())?; | ||
| validate_protocol_version_header(&parts.headers)?; | ||
| // check if last event id is provided | ||
| let last_event_id = request | ||
| .headers() | ||
| let last_event_id = parts | ||
| .headers | ||
| .get(HEADER_LAST_EVENT_ID) | ||
@@ -603,7 +1034,14 @@ .and_then(|v| v.to_str().ok()) | ||
| if !has_session { | ||
| // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions | ||
| return Ok(Response::builder() | ||
| .status(http::StatusCode::NOT_FOUND) | ||
| .body(Full::new(Bytes::from("Not Found: Session not found")).boxed()) | ||
| .expect("valid response")); | ||
| // Attempt transparent cross-instance restore from external store. | ||
| let restored = self | ||
| .try_restore_from_store(&session_id, &part) | ||
| .await | ||
| .map_err(internal_error_response("restore session"))?; | ||
| if !restored { | ||
| // MCP spec: server MUST respond with 404 Not Found for terminated/unknown sessions | ||
| return Ok(Response::builder() | ||
| .status(http::StatusCode::NOT_FOUND) | ||
| .body(Full::new(Bytes::from("Not Found: Session not found")).boxed()) | ||
| .expect("valid response")); | ||
| } | ||
| } | ||
@@ -660,2 +1098,17 @@ | ||
| .map_err(internal_error_response("create session"))?; | ||
| // Capture init params for external store persistence before | ||
| // extensions are injected (which would require Clone). | ||
| let stored_init_params = if self.config.session_store.is_some() { | ||
| if let ClientJsonRpcMessage::Request(req) = &message { | ||
| if let ClientRequest::InitializeRequest(init_req) = &req.request { | ||
| Some(init_req.params.clone()) | ||
| } else { | ||
| None | ||
| } | ||
| } else { | ||
| None | ||
| } | ||
| } else { | ||
| None | ||
| }; | ||
| if let ClientJsonRpcMessage::Request(req) = &mut message { | ||
@@ -674,27 +1127,9 @@ if !matches!(req.request, ClientRequest::InitializeRequest(_)) { | ||
| // spawn a task to serve the session | ||
| tokio::spawn({ | ||
| let session_manager = self.session_manager.clone(); | ||
| let session_id = session_id.clone(); | ||
| async move { | ||
| let service = serve_server::<S, M::Transport, _, TransportAdapterIdentity>( | ||
| service, transport, | ||
| ) | ||
| .await; | ||
| match service { | ||
| Ok(service) => { | ||
| // on service created | ||
| let _ = service.waiting().await; | ||
| } | ||
| Err(e) => { | ||
| tracing::error!("Failed to create service: {e}"); | ||
| } | ||
| } | ||
| let _ = session_manager | ||
| .close_session(&session_id) | ||
| .await | ||
| .inspect_err(|e| { | ||
| tracing::error!("Failed to close session {session_id}: {e}"); | ||
| }); | ||
| } | ||
| }); | ||
| Self::spawn_session_worker( | ||
| self.session_manager.clone(), | ||
| session_id.clone(), | ||
| service, | ||
| transport, | ||
| None, | ||
| ); | ||
| // get initialize response | ||
@@ -706,2 +1141,19 @@ let response = self | ||
| .map_err(internal_error_response("create stream"))?; | ||
| // Persist session state to external store after a successful handshake. | ||
| if let (Some(store), Some(params)) = | ||
| (&self.config.session_store, stored_init_params) | ||
| { | ||
| let state = SessionState { | ||
| initialize_params: params, | ||
| }; | ||
| let _ = store | ||
| .store(session_id.as_ref(), &state) | ||
| .await | ||
| .inspect_err(|e| { | ||
| tracing::warn!( | ||
| "Failed to persist session {} to store: {e}", | ||
| session_id | ||
| ); | ||
| }); | ||
| } | ||
| let stream = | ||
@@ -829,4 +1281,11 @@ futures::stream::once(async move { ServerSseMessage::from_message(response) }); | ||
| .map_err(internal_error_response("close session"))?; | ||
| // Remove from external store: a DELETE means the client intentionally | ||
| // ends the session, so the store entry is no longer needed. | ||
| if let Some(store) = &self.config.session_store { | ||
| let _ = store.delete(session_id.as_ref()).await.inspect_err(|e| { | ||
| tracing::warn!("Failed to delete session {} from store: {e}", session_id); | ||
| }); | ||
| } | ||
| Ok(accepted_response()) | ||
| } | ||
| } |
@@ -1033,1 +1033,198 @@ #![cfg(not(feature = "local"))] | ||
| } | ||
| /// Integration test: Verify the validator falls back to the URI authority when | ||
| /// the Host header is absent (HTTP/2 :authority pseudo-header scenario). | ||
| #[tokio::test] | ||
| #[cfg(all(feature = "transport-streamable-http-server", feature = "server",))] | ||
| async fn test_server_falls_back_to_uri_authority_when_host_header_missing() { | ||
| use std::sync::Arc; | ||
| use bytes::Bytes; | ||
| use http::{Method, Request, header::CONTENT_TYPE}; | ||
| use http_body_util::Full; | ||
| use rmcp::{ | ||
| handler::server::ServerHandler, | ||
| model::{ServerCapabilities, ServerInfo}, | ||
| transport::streamable_http_server::{ | ||
| StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, | ||
| }, | ||
| }; | ||
| use serde_json::json; | ||
| #[derive(Clone)] | ||
| struct TestHandler; | ||
| impl ServerHandler for TestHandler { | ||
| fn get_info(&self) -> ServerInfo { | ||
| ServerInfo::new(ServerCapabilities::builder().build()) | ||
| } | ||
| } | ||
| let service = StreamableHttpService::new( | ||
| || Ok(TestHandler), | ||
| Arc::new(LocalSessionManager::default()), | ||
| StreamableHttpServerConfig::default(), | ||
| ); | ||
| let init_body = json!({ | ||
| "jsonrpc": "2.0", | ||
| "id": 1, | ||
| "method": "initialize", | ||
| "params": { | ||
| "protocolVersion": "2025-03-26", | ||
| "capabilities": {}, | ||
| "clientInfo": { | ||
| "name": "test-client", | ||
| "version": "1.0.0" | ||
| } | ||
| } | ||
| }); | ||
| // Allowed authority via URI only — no Host header. | ||
| let allowed_request = Request::builder() | ||
| .method(Method::POST) | ||
| .uri("http://localhost:8080/") | ||
| .header("Accept", "application/json, text/event-stream") | ||
| .header(CONTENT_TYPE, "application/json") | ||
| .body(Full::new(Bytes::from(init_body.to_string()))) | ||
| .unwrap(); | ||
| assert!(allowed_request.headers().get("Host").is_none()); | ||
| let response = service.handle(allowed_request).await; | ||
| assert_eq!(response.status(), http::StatusCode::OK); | ||
| // Disallowed authority via URI only — no Host header. | ||
| let bad_request = Request::builder() | ||
| .method(Method::POST) | ||
| .uri("http://attacker.example/") | ||
| .header("Accept", "application/json, text/event-stream") | ||
| .header(CONTENT_TYPE, "application/json") | ||
| .body(Full::new(Bytes::from(init_body.to_string()))) | ||
| .unwrap(); | ||
| assert!(bad_request.headers().get("Host").is_none()); | ||
| let response = service.handle(bad_request).await; | ||
| assert_eq!(response.status(), http::StatusCode::FORBIDDEN); | ||
| // Neither Host header nor URI authority — still a 400. | ||
| let missing_request = Request::builder() | ||
| .method(Method::POST) | ||
| .uri("/") | ||
| .header("Accept", "application/json, text/event-stream") | ||
| .header(CONTENT_TYPE, "application/json") | ||
| .body(Full::new(Bytes::from(init_body.to_string()))) | ||
| .unwrap(); | ||
| assert!(missing_request.headers().get("Host").is_none()); | ||
| assert!(missing_request.uri().authority().is_none()); | ||
| let response = service.handle(missing_request).await; | ||
| assert_eq!(response.status(), http::StatusCode::BAD_REQUEST); | ||
| } | ||
| #[cfg(all(feature = "transport-streamable-http-server", feature = "server"))] | ||
| mod origin_validation { | ||
| use std::sync::Arc; | ||
| use bytes::Bytes; | ||
| use http::{Method, Request, header::CONTENT_TYPE}; | ||
| use http_body_util::Full; | ||
| use rmcp::{ | ||
| handler::server::ServerHandler, | ||
| model::{ServerCapabilities, ServerInfo}, | ||
| transport::streamable_http_server::{ | ||
| StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, | ||
| }, | ||
| }; | ||
| use serde_json::json; | ||
| #[derive(Clone)] | ||
| struct TestHandler; | ||
| impl ServerHandler for TestHandler { | ||
| fn get_info(&self) -> ServerInfo { | ||
| ServerInfo::new(ServerCapabilities::builder().build()) | ||
| } | ||
| } | ||
| fn service_with_allowed_origins( | ||
| origins: &[&str], | ||
| ) -> StreamableHttpService<TestHandler, LocalSessionManager> { | ||
| StreamableHttpService::new( | ||
| || Ok(TestHandler), | ||
| Arc::new(LocalSessionManager::default()), | ||
| StreamableHttpServerConfig::default().with_allowed_origins(origins.iter().copied()), | ||
| ) | ||
| } | ||
| fn init_request(origin: Option<&str>) -> Request<Full<Bytes>> { | ||
| let init_body = json!({ | ||
| "jsonrpc": "2.0", | ||
| "id": 1, | ||
| "method": "initialize", | ||
| "params": { | ||
| "protocolVersion": "2025-03-26", | ||
| "capabilities": {}, | ||
| "clientInfo": {"name": "test-client", "version": "1.0.0"} | ||
| } | ||
| }); | ||
| let mut builder = Request::builder() | ||
| .method(Method::POST) | ||
| .header("Accept", "application/json, text/event-stream") | ||
| .header(CONTENT_TYPE, "application/json") | ||
| .header("Host", "localhost:8080"); | ||
| if let Some(origin) = origin { | ||
| builder = builder.header("Origin", origin); | ||
| } | ||
| builder | ||
| .body(Full::new(Bytes::from(init_body.to_string()))) | ||
| .unwrap() | ||
| } | ||
| #[tokio::test] | ||
| async fn allowlisted_origin_is_allowed() { | ||
| let service = service_with_allowed_origins(&["http://localhost:8080"]); | ||
| let response = service | ||
| .handle(init_request(Some("http://localhost:8080"))) | ||
| .await; | ||
| assert_eq!(response.status(), http::StatusCode::OK); | ||
| } | ||
| #[tokio::test] | ||
| async fn non_allowlisted_origin_is_forbidden() { | ||
| let service = service_with_allowed_origins(&["http://localhost:8080"]); | ||
| let response = service | ||
| .handle(init_request(Some("http://attacker.example"))) | ||
| .await; | ||
| assert_eq!(response.status(), http::StatusCode::FORBIDDEN); | ||
| } | ||
| #[tokio::test] | ||
| async fn missing_origin_passes_through() { | ||
| let service = service_with_allowed_origins(&["http://localhost:8080"]); | ||
| let response = service.handle(init_request(None)).await; | ||
| assert_eq!(response.status(), http::StatusCode::OK); | ||
| } | ||
| #[tokio::test] | ||
| async fn scheme_mismatch_is_forbidden() { | ||
| let service = service_with_allowed_origins(&["http://localhost:8080"]); | ||
| let response = service | ||
| .handle(init_request(Some("https://localhost:8080"))) | ||
| .await; | ||
| assert_eq!(response.status(), http::StatusCode::FORBIDDEN); | ||
| } | ||
| #[tokio::test] | ||
| async fn null_origin_is_allowed_when_allowlisted() { | ||
| let service = service_with_allowed_origins(&["null"]); | ||
| let response = service.handle(init_request(Some("null"))).await; | ||
| assert_eq!(response.status(), http::StatusCode::OK); | ||
| } | ||
| #[tokio::test] | ||
| async fn null_origin_is_forbidden_when_not_allowlisted() { | ||
| let service = service_with_allowed_origins(&["http://localhost:8080"]); | ||
| let response = service.handle(init_request(Some("null"))).await; | ||
| assert_eq!(response.status(), http::StatusCode::FORBIDDEN); | ||
| } | ||
| } |
| #![cfg(not(feature = "local"))] | ||
| use std::collections::HashMap; | ||
| use std::{ | ||
| collections::HashMap, | ||
| sync::atomic::{AtomicUsize, Ordering}, | ||
| }; | ||
@@ -87,1 +90,284 @@ use futures::future::BoxFuture; | ||
| } | ||
| fn build_router() -> ToolRouter<TestHandler<()>> { | ||
| ToolRouter::<TestHandler<()>>::new() | ||
| .with_route((async_function_tool_attr(), async_function)) | ||
| .with_route((async_function2_tool_attr(), async_function2)) | ||
| + TestHandler::<()>::test_router_1() | ||
| + TestHandler::<()>::test_router_2() | ||
| } | ||
| #[test] | ||
| fn test_disable_route() { | ||
| let mut router = build_router(); | ||
| assert_eq!(router.list_all().len(), 4); | ||
| assert!(router.has_route("async_function")); | ||
| assert!(router.get("async_function").is_some()); | ||
| assert!(router.disable_route("async_function")); | ||
| assert_eq!(router.list_all().len(), 3); | ||
| assert!(!router.has_route("async_function")); | ||
| assert!(router.get("async_function").is_none()); | ||
| assert!(router.is_disabled("async_function")); | ||
| // other tools unaffected | ||
| assert!(router.has_route("async_function2")); | ||
| assert!(router.get("async_function2").is_some()); | ||
| assert!(!router.is_disabled("async_function2")); | ||
| } | ||
| #[test] | ||
| fn test_enable_route() { | ||
| let mut router = build_router(); | ||
| assert!(router.disable_route("async_function")); | ||
| assert!(!router.has_route("async_function")); | ||
| assert!(router.enable_route("async_function")); | ||
| assert!(router.has_route("async_function")); | ||
| assert!(router.get("async_function").is_some()); | ||
| assert!(!router.is_disabled("async_function")); | ||
| assert_eq!(router.list_all().len(), 4); | ||
| } | ||
| #[test] | ||
| fn test_with_disabled_builder() { | ||
| let router = build_router() | ||
| .with_disabled("async_function") | ||
| .with_disabled("sync_method"); | ||
| assert_eq!(router.list_all().len(), 2); | ||
| assert!(!router.has_route("async_function")); | ||
| assert!(!router.has_route("sync_method")); | ||
| assert!(router.has_route("async_function2")); | ||
| assert!(router.has_route("async_method")); | ||
| } | ||
| #[test] | ||
| fn test_disabled_tools_survive_merge() { | ||
| let mut router_a = ToolRouter::<TestHandler<()>>::new() | ||
| .with_route((async_function_tool_attr(), async_function)); | ||
| assert!(router_a.disable_route("async_function")); | ||
| let router_b = ToolRouter::<TestHandler<()>>::new() | ||
| .with_route((async_function2_tool_attr(), async_function2)); | ||
| router_a.merge(router_b); | ||
| assert_eq!(router_a.list_all().len(), 1); | ||
| assert!(router_a.is_disabled("async_function")); | ||
| assert!(router_a.has_route("async_function2")); | ||
| } | ||
| #[test] | ||
| fn test_disable_nonexistent_tool() { | ||
| let mut router = build_router(); | ||
| // should not panic; returns true because the name is newly added to disabled set | ||
| assert!(router.disable_route("does_not_exist")); | ||
| assert_eq!(router.list_all().len(), 4); | ||
| // is_disabled returns false for tools not in the map | ||
| assert!(!router.is_disabled("does_not_exist")); | ||
| } | ||
| #[test] | ||
| fn test_remove_route_preserves_disabled_state() { | ||
| let mut router = build_router(); | ||
| assert!(router.disable_route("async_function")); | ||
| assert!(router.is_disabled("async_function")); | ||
| router.remove_route("async_function"); | ||
| assert!(!router.has_route("async_function")); | ||
| // Disabled marker is preserved — is_disabled returns false (no route in map) | ||
| // but re-adding will inherit the disabled state (tested separately) | ||
| assert!(!router.is_disabled("async_function")); | ||
| } | ||
| #[test] | ||
| fn test_remove_route_then_readd_stays_disabled() { | ||
| let mut router = build_router(); | ||
| assert!(router.disable_route("async_function")); | ||
| router.remove_route("async_function"); | ||
| assert!(!router.has_route("async_function")); | ||
| // Re-add the route — it should inherit the disabled state | ||
| let other = ToolRouter::<TestHandler<()>>::new() | ||
| .with_route((async_function_tool_attr(), async_function)); | ||
| router.merge(other); | ||
| assert!(!router.has_route("async_function")); | ||
| assert!(router.is_disabled("async_function")); | ||
| assert!(router.get("async_function").is_none()); | ||
| } | ||
| #[test] | ||
| fn test_into_iter_skips_disabled() { | ||
| let router = build_router().with_disabled("async_function"); | ||
| let names: Vec<_> = router | ||
| .into_iter() | ||
| .map(|r| r.attr.name.to_string()) | ||
| .collect(); | ||
| assert_eq!(names.len(), 3); | ||
| assert!(!names.contains(&"async_function".to_string())); | ||
| } | ||
| #[test] | ||
| fn test_pre_disable_before_add_route() { | ||
| // Disabling a name before adding a route with that name should | ||
| // result in the route being disabled once added. | ||
| let router = ToolRouter::<TestHandler<()>>::new() | ||
| .with_disabled("async_function") | ||
| .with_route((async_function_tool_attr(), async_function)); | ||
| assert_eq!(router.list_all().len(), 0); | ||
| assert!(router.is_disabled("async_function")); | ||
| assert!(!router.has_route("async_function")); | ||
| } | ||
| #[test] | ||
| fn test_disabled_tool_invisible_across_all_queries() { | ||
| let router = build_router().with_disabled("async_function"); | ||
| // Not listed | ||
| let names: Vec<_> = router.list_all().iter().map(|t| t.name.clone()).collect(); | ||
| assert!(!names.contains(&"async_function".into())); | ||
| // Not retrievable | ||
| assert!(router.get("async_function").is_none()); | ||
| // Not routable | ||
| assert!(!router.has_route("async_function")); | ||
| // But known as disabled | ||
| assert!(router.is_disabled("async_function")); | ||
| } | ||
| #[test] | ||
| fn test_disable_route_then_add_route_blocks_tool() { | ||
| // Full pre-disable lifecycle via runtime mutation (not builder) | ||
| let mut router = ToolRouter::<TestHandler<()>>::new(); | ||
| router.disable_route("async_function"); | ||
| // Add route after disabling — tool should be blocked | ||
| let other = ToolRouter::<TestHandler<()>>::new() | ||
| .with_route((async_function_tool_attr(), async_function)); | ||
| router.merge(other); | ||
| assert!(router.is_disabled("async_function")); | ||
| assert!(!router.has_route("async_function")); | ||
| assert!(router.get("async_function").is_none()); | ||
| assert_eq!(router.list_all().len(), 0); | ||
| } | ||
| #[test] | ||
| fn test_disable_enable_return_false_cases() { | ||
| let mut router = build_router(); | ||
| // Repeated disable returns false | ||
| assert!(router.disable_route("async_function")); | ||
| assert!(!router.disable_route("async_function")); | ||
| // Enable returns true, then false on repeat | ||
| assert!(router.enable_route("async_function")); | ||
| assert!(!router.enable_route("async_function")); | ||
| // Enable on name never disabled returns false | ||
| assert!(!router.enable_route("async_function2")); | ||
| // Enable on unknown name returns false | ||
| assert!(!router.enable_route("unknown")); | ||
| } | ||
| // ── Notifier tests ────────────────────────────────────────────────────── | ||
| fn counter_notifier() -> ( | ||
| impl Fn() + Send + Sync + 'static, | ||
| std::sync::Arc<AtomicUsize>, | ||
| ) { | ||
| let counter = std::sync::Arc::new(AtomicUsize::new(0)); | ||
| let c = counter.clone(); | ||
| let notifier = move || { | ||
| c.fetch_add(1, Ordering::SeqCst); | ||
| }; | ||
| (notifier, counter) | ||
| } | ||
| #[test] | ||
| fn test_notifier_fires_on_disable_and_enable() { | ||
| let (notifier, counter) = counter_notifier(); | ||
| let mut router = build_router(); | ||
| router.set_notifier(notifier); | ||
| assert!(router.disable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 1); | ||
| assert!(!router.disable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 1); | ||
| assert!(router.enable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 2); | ||
| assert!(!router.enable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 2); | ||
| } | ||
| #[test] | ||
| fn test_notifier_skips_nonexistent_tools() { | ||
| let (notifier, counter) = counter_notifier(); | ||
| let mut router = build_router(); | ||
| router.set_notifier(notifier); | ||
| assert!(router.disable_route("does_not_exist")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 0); | ||
| assert!(router.enable_route("does_not_exist")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 0); | ||
| assert!(router.disable_route("future_tool")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 0); | ||
| assert!(router.enable_route("future_tool")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 0); | ||
| } | ||
| #[test] | ||
| fn test_no_notifier_no_panic() { | ||
| let mut router = build_router(); | ||
| assert!(router.disable_route("async_function")); | ||
| assert!(router.enable_route("async_function")); | ||
| assert!(router.disable_route("async_function")); | ||
| assert!(!router.disable_route("async_function")); | ||
| } | ||
| #[test] | ||
| fn test_clone_shares_notifier() { | ||
| let (notifier, counter) = counter_notifier(); | ||
| let mut router = build_router(); | ||
| router.set_notifier(notifier); | ||
| let mut cloned = router.clone(); | ||
| assert!(cloned.disable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 1); | ||
| assert!(router.disable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 2); | ||
| cloned.clear_notifier(); | ||
| assert!(cloned.enable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 2); | ||
| assert!(router.enable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 3); | ||
| } | ||
| #[test] | ||
| fn test_pre_init_disable_silent_but_correct() { | ||
| let mut router = build_router(); | ||
| assert!(router.disable_route("async_function")); | ||
| assert_eq!(router.list_all().len(), 3); | ||
| assert!(!router.has_route("async_function")); | ||
| let (notifier, counter) = counter_notifier(); | ||
| router.set_notifier(notifier); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 0); | ||
| assert!(router.enable_route("async_function")); | ||
| assert_eq!(counter.load(Ordering::SeqCst), 1); | ||
| } |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet