| use std::{borrow::Cow, collections::HashMap, sync::Arc}; | ||
| use bytes::Bytes; | ||
| use futures::{StreamExt, stream::BoxStream}; | ||
| use http::{HeaderName, HeaderValue, Method, Request, StatusCode, header::WWW_AUTHENTICATE}; | ||
| use http_body_util::{BodyExt, Full}; | ||
| use hyper::body::Incoming; | ||
| use hyper_util::rt::TokioIo; | ||
| use sse_stream::{Sse, SseStream}; | ||
| use tokio::net::UnixStream; | ||
| use crate::{ | ||
| model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, | ||
| transport::{ | ||
| common::http_header::{ | ||
| EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, | ||
| extract_scope_from_header, validate_custom_header, | ||
| }, | ||
| streamable_http_client::*, | ||
| }, | ||
| }; | ||
| #[derive(Debug, thiserror::Error)] | ||
| #[non_exhaustive] | ||
| pub enum UnixSocketError { | ||
| #[error("hyper error: {0}")] | ||
| Hyper(#[from] hyper::Error), | ||
| #[error("IO error: {0}")] | ||
| Io(#[from] std::io::Error), | ||
| #[error("HTTP error: {0}")] | ||
| Http(#[from] http::Error), | ||
| #[error("JSON error: {0}")] | ||
| Json(#[from] serde_json::Error), | ||
| } | ||
| impl From<UnixSocketError> for StreamableHttpError<UnixSocketError> { | ||
| fn from(e: UnixSocketError) -> Self { | ||
| StreamableHttpError::Client(e) | ||
| } | ||
| } | ||
| /// HTTP client that routes requests through a Unix domain socket. | ||
| /// | ||
| /// Implements [`StreamableHttpClient`] using `hyper` over `tokio::net::UnixStream`, | ||
| /// enabling MCP hosts in Kubernetes environments to connect through Envoy sidecars | ||
| /// or other Unix socket-based proxies. | ||
| /// | ||
| /// Each request opens a new Unix socket connection (no connection pooling). | ||
| /// This is appropriate when connecting through a sidecar proxy that manages | ||
| /// its own upstream connection pool. | ||
| /// | ||
| /// # Example | ||
| /// | ||
| /// ```rust,no_run | ||
| /// use rmcp::transport::{StreamableHttpClientTransport, UnixSocketHttpClient}; | ||
| /// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; | ||
| /// | ||
| /// let client = UnixSocketHttpClient::new("/var/run/envoy.sock", "http://mcp-server.internal/mcp"); | ||
| /// let config = StreamableHttpClientTransportConfig::with_uri("http://mcp-server.internal/mcp"); | ||
| /// let transport = StreamableHttpClientTransport::with_client(client, config); | ||
| /// ``` | ||
| #[derive(Clone, Debug)] | ||
| pub struct UnixSocketHttpClient { | ||
| socket_path: Arc<str>, | ||
| host_header: HeaderValue, | ||
| } | ||
| impl UnixSocketHttpClient { | ||
| /// Creates a new Unix socket HTTP client. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `socket_path` - Path to the Unix domain socket. Use `@name` syntax for Linux | ||
| /// abstract sockets (e.g., `@egress.sock` becomes `\0egress.sock`). | ||
| /// * `uri` - The MCP server URI. The authority (host:port) is extracted for the | ||
| /// HTTP `Host` header, since hyper does not auto-set it for Unix socket connections. | ||
| /// | ||
| /// # Panics | ||
| /// | ||
| /// Panics if `socket_path` is empty or is `@` with no name (empty abstract socket). | ||
| pub fn new(socket_path: &str, uri: &str) -> Self { | ||
| assert!( | ||
| !socket_path.is_empty() && socket_path != "@", | ||
| "socket_path must not be empty or a bare '@' (empty abstract socket name)" | ||
| ); | ||
| let host_header = uri | ||
| .parse::<http::Uri>() | ||
| .ok() | ||
| .and_then(|u| u.authority().cloned()) | ||
| .and_then(|a| HeaderValue::from_str(a.as_str()).ok()) | ||
| .unwrap_or_else(|| HeaderValue::from_static("localhost")); | ||
| Self { | ||
| socket_path: resolve_socket_path(socket_path).into(), | ||
| host_header, | ||
| } | ||
| } | ||
| } | ||
| /// Converts the `@`-prefixed abstract socket notation to the null-byte prefix | ||
| /// expected by the Linux kernel. Filesystem socket paths are returned unchanged. | ||
| fn resolve_socket_path(raw: &str) -> String { | ||
| if let Some(name) = raw.strip_prefix('@') { | ||
| format!("\0{name}") | ||
| } else { | ||
| raw.to_string() | ||
| } | ||
| } | ||
| async fn connect_unix(socket_path: &str) -> Result<UnixStream, std::io::Error> { | ||
| #[cfg(target_os = "linux")] | ||
| if let Some(abstract_name) = socket_path.strip_prefix('\0') { | ||
| let abstract_name = abstract_name.to_string(); | ||
| let std_stream = tokio::task::spawn_blocking(move || { | ||
| use std::os::linux::net::SocketAddrExt; | ||
| let addr = std::os::unix::net::SocketAddr::from_abstract_name(&abstract_name)?; | ||
| let stream = std::os::unix::net::UnixStream::connect_addr(&addr)?; | ||
| stream.set_nonblocking(true)?; | ||
| Ok::<_, std::io::Error>(stream) | ||
| }) | ||
| .await | ||
| .map_err(std::io::Error::other)??; | ||
| return UnixStream::from_std(std_stream); | ||
| } | ||
| UnixStream::connect(socket_path).await | ||
| } | ||
| /// Opens a new Unix socket connection and sends the HTTP request. | ||
| /// One connection per request — the sidecar proxy handles connection pooling. | ||
| async fn send_http_request( | ||
| socket_path: &str, | ||
| request: Request<Full<Bytes>>, | ||
| ) -> Result<http::Response<Incoming>, UnixSocketError> { | ||
| let stream = connect_unix(socket_path).await?; | ||
| let io = TokioIo::new(stream); | ||
| let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; | ||
| tokio::spawn(async move { | ||
| if let Err(e) = conn.await { | ||
| tracing::warn!("unix socket HTTP/1.1 connection error: {e}"); | ||
| } | ||
| }); | ||
| Ok(sender.send_request(request).await?) | ||
| } | ||
| /// Applies custom headers to a request builder, rejecting reserved headers. | ||
| fn apply_custom_headers( | ||
| mut builder: http::request::Builder, | ||
| custom_headers: HashMap<HeaderName, HeaderValue>, | ||
| ) -> Result<http::request::Builder, StreamableHttpError<UnixSocketError>> { | ||
| for (name, value) in custom_headers { | ||
| validate_custom_header(&name).map_err(StreamableHttpError::ReservedHeaderConflict)?; | ||
| builder = builder.header(name, value); | ||
| } | ||
| Ok(builder) | ||
| } | ||
| impl StreamableHttpClient for UnixSocketHttpClient { | ||
| type Error = UnixSocketError; | ||
| async fn post_message( | ||
| &self, | ||
| uri: Arc<str>, | ||
| message: ClientJsonRpcMessage, | ||
| session_id: Option<Arc<str>>, | ||
| auth_token: Option<String>, | ||
| custom_headers: HashMap<HeaderName, HeaderValue>, | ||
| ) -> Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> { | ||
| let json_body = serde_json::to_string(&message) | ||
| .map_err(|e| StreamableHttpError::Client(UnixSocketError::Json(e)))?; | ||
| let mut builder = Request::builder() | ||
| .method(Method::POST) | ||
| .uri(uri.as_ref()) | ||
| .header(http::header::HOST, self.host_header.clone()) | ||
| .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE) | ||
| .header( | ||
| http::header::ACCEPT, | ||
| format!("{EVENT_STREAM_MIME_TYPE}, {JSON_MIME_TYPE}"), | ||
| ); | ||
| if let Some(auth) = auth_token { | ||
| builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); | ||
| } | ||
| builder = apply_custom_headers(builder, custom_headers)?; | ||
| let session_was_attached = session_id.is_some(); | ||
| if let Some(sid) = session_id { | ||
| builder = builder.header(HEADER_SESSION_ID, sid.as_ref()); | ||
| } | ||
| let request = builder | ||
| .body(Full::new(Bytes::from(json_body))) | ||
| .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; | ||
| let response = send_http_request(&self.socket_path, request) | ||
| .await | ||
| .map_err(StreamableHttpError::Client)?; | ||
| let status = response.status(); | ||
| if status == StatusCode::UNAUTHORIZED { | ||
| if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { | ||
| let www_authenticate_header = header | ||
| .to_str() | ||
| .map_err(|_| { | ||
| StreamableHttpError::UnexpectedServerResponse(Cow::from( | ||
| "invalid www-authenticate header value", | ||
| )) | ||
| })? | ||
| .to_string(); | ||
| return Err(StreamableHttpError::AuthRequired(AuthRequiredError { | ||
| www_authenticate_header, | ||
| })); | ||
| } | ||
| } | ||
| if status == StatusCode::FORBIDDEN { | ||
| if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { | ||
| let header_str = header.to_str().map_err(|_| { | ||
| StreamableHttpError::UnexpectedServerResponse(Cow::from( | ||
| "invalid www-authenticate header value", | ||
| )) | ||
| })?; | ||
| let scope = extract_scope_from_header(header_str); | ||
| return Err(StreamableHttpError::InsufficientScope( | ||
| InsufficientScopeError { | ||
| www_authenticate_header: header_str.to_string(), | ||
| required_scope: scope, | ||
| }, | ||
| )); | ||
| } | ||
| } | ||
| if matches!(status, StatusCode::ACCEPTED | StatusCode::NO_CONTENT) { | ||
| return Ok(StreamableHttpPostResponse::Accepted); | ||
| } | ||
| if status == StatusCode::NOT_FOUND && session_was_attached { | ||
| return Err(StreamableHttpError::SessionExpired); | ||
| } | ||
| if !status.is_success() { | ||
| let body = response | ||
| .into_body() | ||
| .collect() | ||
| .await | ||
| .map(|c| String::from_utf8_lossy(&c.to_bytes()).into_owned()) | ||
| .unwrap_or_else(|_| "<failed to read response body>".to_owned()); | ||
| return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned( | ||
| format!("HTTP {status}: {body}"), | ||
| ))); | ||
| } | ||
| let content_type = response.headers().get(http::header::CONTENT_TYPE).cloned(); | ||
| let session_id = response | ||
| .headers() | ||
| .get(HEADER_SESSION_ID) | ||
| .and_then(|v| v.to_str().ok()) | ||
| .map(|s| s.to_string()); | ||
| match content_type { | ||
| Some(ref ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => { | ||
| let sse_stream = SseStream::new(response.into_body()).boxed(); | ||
| Ok(StreamableHttpPostResponse::Sse(sse_stream, session_id)) | ||
| } | ||
| Some(ref ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => { | ||
| let body = response | ||
| .into_body() | ||
| .collect() | ||
| .await | ||
| .map_err(|e| StreamableHttpError::Client(UnixSocketError::Hyper(e)))? | ||
| .to_bytes(); | ||
| match serde_json::from_slice::<ServerJsonRpcMessage>(&body) { | ||
| Ok(message) => Ok(StreamableHttpPostResponse::Json(message, session_id)), | ||
| Err(e) => { | ||
| tracing::warn!( | ||
| "could not parse JSON response as ServerJsonRpcMessage, treating as accepted: {e}" | ||
| ); | ||
| Ok(StreamableHttpPostResponse::Accepted) | ||
| } | ||
| } | ||
| } | ||
| _ => Err(StreamableHttpError::UnexpectedContentType( | ||
| content_type.map(|ct| String::from_utf8_lossy(ct.as_bytes()).into_owned()), | ||
| )), | ||
| } | ||
| } | ||
| async fn delete_session( | ||
| &self, | ||
| uri: Arc<str>, | ||
| session_id: Arc<str>, | ||
| auth_token: Option<String>, | ||
| custom_headers: HashMap<HeaderName, HeaderValue>, | ||
| ) -> Result<(), StreamableHttpError<Self::Error>> { | ||
| let mut builder = Request::builder() | ||
| .method(Method::DELETE) | ||
| .uri(uri.as_ref()) | ||
| .header(http::header::HOST, self.host_header.clone()) | ||
| .header(HEADER_SESSION_ID, session_id.as_ref()); | ||
| if let Some(auth) = auth_token { | ||
| builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); | ||
| } | ||
| builder = apply_custom_headers(builder, custom_headers)?; | ||
| let request = builder | ||
| .body(Full::new(Bytes::new())) | ||
| .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; | ||
| let response = send_http_request(&self.socket_path, request) | ||
| .await | ||
| .map_err(StreamableHttpError::Client)?; | ||
| if response.status() == StatusCode::METHOD_NOT_ALLOWED { | ||
| tracing::debug!("this server doesn't support deleting session"); | ||
| return Ok(()); | ||
| } | ||
| if !response.status().is_success() { | ||
| return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned( | ||
| format!("delete_session returned {}", response.status()), | ||
| ))); | ||
| } | ||
| Ok(()) | ||
| } | ||
| async fn get_stream( | ||
| &self, | ||
| uri: Arc<str>, | ||
| session_id: Arc<str>, | ||
| last_event_id: Option<String>, | ||
| auth_token: Option<String>, | ||
| custom_headers: HashMap<HeaderName, HeaderValue>, | ||
| ) -> Result<BoxStream<'static, Result<Sse, sse_stream::Error>>, StreamableHttpError<Self::Error>> | ||
| { | ||
| let mut builder = Request::builder() | ||
| .method(Method::GET) | ||
| .uri(uri.as_ref()) | ||
| .header(http::header::HOST, self.host_header.clone()) | ||
| .header( | ||
| http::header::ACCEPT, | ||
| format!("{EVENT_STREAM_MIME_TYPE}, {JSON_MIME_TYPE}"), | ||
| ) | ||
| .header(HEADER_SESSION_ID, session_id.as_ref()); | ||
| if let Some(last_id) = last_event_id { | ||
| builder = builder.header(HEADER_LAST_EVENT_ID, last_id); | ||
| } | ||
| if let Some(auth) = auth_token { | ||
| builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); | ||
| } | ||
| builder = apply_custom_headers(builder, custom_headers)?; | ||
| let request = builder | ||
| .body(Full::new(Bytes::new())) | ||
| .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; | ||
| let response = send_http_request(&self.socket_path, request) | ||
| .await | ||
| .map_err(StreamableHttpError::Client)?; | ||
| if response.status() == StatusCode::METHOD_NOT_ALLOWED { | ||
| return Err(StreamableHttpError::ServerDoesNotSupportSse); | ||
| } | ||
| if response.status() == StatusCode::UNAUTHORIZED { | ||
| if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { | ||
| let www_authenticate_header = header | ||
| .to_str() | ||
| .map_err(|_| { | ||
| StreamableHttpError::UnexpectedServerResponse(Cow::from( | ||
| "invalid www-authenticate header value", | ||
| )) | ||
| })? | ||
| .to_string(); | ||
| return Err(StreamableHttpError::AuthRequired(AuthRequiredError { | ||
| www_authenticate_header, | ||
| })); | ||
| } | ||
| } | ||
| if response.status() == StatusCode::FORBIDDEN { | ||
| if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { | ||
| let header_str = header.to_str().map_err(|_| { | ||
| StreamableHttpError::UnexpectedServerResponse(Cow::from( | ||
| "invalid www-authenticate header value", | ||
| )) | ||
| })?; | ||
| let scope = extract_scope_from_header(header_str); | ||
| return Err(StreamableHttpError::InsufficientScope( | ||
| InsufficientScopeError { | ||
| www_authenticate_header: header_str.to_string(), | ||
| required_scope: scope, | ||
| }, | ||
| )); | ||
| } | ||
| } | ||
| if !response.status().is_success() { | ||
| return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned( | ||
| format!("get_stream returned {}", response.status()), | ||
| ))); | ||
| } | ||
| match response.headers().get(http::header::CONTENT_TYPE) { | ||
| Some(ct) => { | ||
| if !ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) | ||
| && !ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) | ||
| { | ||
| return Err(StreamableHttpError::UnexpectedContentType(Some( | ||
| String::from_utf8_lossy(ct.as_bytes()).to_string(), | ||
| ))); | ||
| } | ||
| } | ||
| None => { | ||
| return Err(StreamableHttpError::UnexpectedContentType(None)); | ||
| } | ||
| } | ||
| Ok(SseStream::new(response.into_body()).boxed()) | ||
| } | ||
| } | ||
| impl StreamableHttpClientTransport<UnixSocketHttpClient> { | ||
| /// Creates a new transport connecting through a Unix domain socket. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `socket_path` - Path to the Unix domain socket. Use `@name` for Linux abstract sockets. | ||
| /// * `uri` - The MCP server URI (used for HTTP Host header and request path). | ||
| pub fn from_unix_socket(socket_path: &str, uri: impl Into<Arc<str>>) -> Self { | ||
| let uri: Arc<str> = uri.into(); | ||
| let client = UnixSocketHttpClient::new(socket_path, &uri); | ||
| let config = StreamableHttpClientTransportConfig { | ||
| uri, | ||
| ..Default::default() | ||
| }; | ||
| StreamableHttpClientTransport::with_client(client, config) | ||
| } | ||
| /// Creates a new transport connecting through a Unix domain socket with custom config. | ||
| /// | ||
| /// # Arguments | ||
| /// | ||
| /// * `socket_path` - Path to the Unix domain socket. Use `@name` for Linux abstract sockets. | ||
| /// * `config` - Transport configuration (URI, retry policy, custom headers, etc.). | ||
| pub fn from_unix_socket_with_config( | ||
| socket_path: &str, | ||
| config: StreamableHttpClientTransportConfig, | ||
| ) -> Self { | ||
| let client = UnixSocketHttpClient::new(socket_path, &config.uri); | ||
| StreamableHttpClientTransport::with_client(client, config) | ||
| } | ||
| } | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| #[test] | ||
| fn resolve_abstract_socket() { | ||
| assert_eq!(resolve_socket_path("@egress.sock"), "\0egress.sock"); | ||
| } | ||
| #[test] | ||
| fn resolve_filesystem_socket() { | ||
| assert_eq!( | ||
| resolve_socket_path("/var/run/envoy.sock"), | ||
| "/var/run/envoy.sock" | ||
| ); | ||
| } | ||
| #[test] | ||
| fn resolve_empty_abstract() { | ||
| assert_eq!(resolve_socket_path("@"), "\0"); | ||
| } | ||
| #[test] | ||
| #[should_panic(expected = "socket_path must not be empty")] | ||
| fn rejects_bare_at_symbol() { | ||
| UnixSocketHttpClient::new("@", "http://localhost/mcp"); | ||
| } | ||
| #[test] | ||
| #[should_panic(expected = "socket_path must not be empty")] | ||
| fn rejects_empty_path() { | ||
| UnixSocketHttpClient::new("", "http://localhost/mcp"); | ||
| } | ||
| #[test] | ||
| fn host_header_auto_derived() { | ||
| let client = | ||
| UnixSocketHttpClient::new("/var/run/envoy.sock", "http://mcp-server.internal/mcp"); | ||
| assert_eq!(client.host_header, "mcp-server.internal"); | ||
| } | ||
| #[test] | ||
| fn host_header_with_port() { | ||
| let client = | ||
| UnixSocketHttpClient::new("/var/run/envoy.sock", "http://mcp-server.internal:8080/mcp"); | ||
| assert_eq!(client.host_header, "mcp-server.internal:8080"); | ||
| } | ||
| #[test] | ||
| fn host_header_fallback_on_path_only_uri() { | ||
| let client = UnixSocketHttpClient::new("/var/run/envoy.sock", "/mcp"); | ||
| assert_eq!(client.host_header, "localhost"); | ||
| } | ||
| #[test] | ||
| fn reserved_header_rejected() { | ||
| let mut headers = HashMap::new(); | ||
| headers.insert( | ||
| HeaderName::from_static("accept"), | ||
| HeaderValue::from_static("text/plain"), | ||
| ); | ||
| let builder = Request::builder(); | ||
| let result = apply_custom_headers(builder, headers); | ||
| assert!(matches!( | ||
| result, | ||
| Err(StreamableHttpError::ReservedHeaderConflict(_)) | ||
| )); | ||
| } | ||
| #[test] | ||
| fn mcp_protocol_version_allowed_through() { | ||
| let mut headers = HashMap::new(); | ||
| headers.insert( | ||
| HeaderName::from_static("mcp-protocol-version"), | ||
| HeaderValue::from_static("2025-03-26"), | ||
| ); | ||
| let builder = Request::builder().uri("http://localhost/mcp").method("GET"); | ||
| let result = apply_custom_headers(builder, headers); | ||
| assert!(result.is_ok()); | ||
| } | ||
| } |
| #![cfg(not(feature = "local"))] | ||
| // cargo test --test test_inflight_response_drain --features "client server" | ||
| use std::{ | ||
| pin::Pin, | ||
| sync::{ | ||
| Arc, | ||
| atomic::{AtomicBool, Ordering}, | ||
| }, | ||
| task::{Context, Poll}, | ||
| time::Duration, | ||
| }; | ||
| use rmcp::{ | ||
| ServerHandler, ServiceExt, | ||
| handler::server::{router::tool::ToolRouter, wrapper::Parameters}, | ||
| model::{CallToolRequestParams, ClientInfo, ServerCapabilities, ServerInfo}, | ||
| service::QuitReason, | ||
| tool, tool_handler, tool_router, | ||
| }; | ||
| use tokio::io::{AsyncRead, ReadBuf}; | ||
| // A slow tool server that sleeps before returning a response. | ||
| #[derive(Debug, Clone)] | ||
| struct SlowToolServer { | ||
| tool_router: ToolRouter<Self>, | ||
| } | ||
| impl SlowToolServer { | ||
| fn new() -> Self { | ||
| Self { | ||
| tool_router: Self::tool_router(), | ||
| } | ||
| } | ||
| } | ||
| #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] | ||
| struct SlowToolRequest { | ||
| #[schemars(description = "how long to sleep in milliseconds")] | ||
| sleep_ms: u64, | ||
| } | ||
| #[tool_router] | ||
| impl SlowToolServer { | ||
| #[tool(description = "A tool that sleeps then returns")] | ||
| async fn slow_tool( | ||
| &self, | ||
| Parameters(SlowToolRequest { sleep_ms }): Parameters<SlowToolRequest>, | ||
| ) -> String { | ||
| tokio::time::sleep(Duration::from_millis(sleep_ms)).await; | ||
| format!("done after {}ms", sleep_ms) | ||
| } | ||
| } | ||
| #[tool_handler] | ||
| impl ServerHandler for SlowToolServer { | ||
| fn get_info(&self) -> ServerInfo { | ||
| ServerInfo::new(ServerCapabilities::builder().enable_tools().build()) | ||
| } | ||
| } | ||
| #[derive(Debug, Clone, Default)] | ||
| struct DummyClientHandler; | ||
| impl rmcp::ClientHandler for DummyClientHandler { | ||
| fn get_info(&self) -> ClientInfo { | ||
| ClientInfo::default() | ||
| } | ||
| } | ||
| /// An `AsyncRead` wrapper that delegates to the inner reader until signalled, | ||
| /// then returns EOF (read 0 bytes). | ||
| struct ClosableReader<R> { | ||
| inner: R, | ||
| eof_flag: Arc<AtomicBool>, | ||
| } | ||
| impl<R: AsyncRead + Unpin> AsyncRead for ClosableReader<R> { | ||
| fn poll_read( | ||
| mut self: Pin<&mut Self>, | ||
| cx: &mut Context<'_>, | ||
| buf: &mut ReadBuf<'_>, | ||
| ) -> Poll<std::io::Result<()>> { | ||
| if self.eof_flag.load(Ordering::Acquire) { | ||
| return Poll::Ready(Ok(())); | ||
| } | ||
| Pin::new(&mut self.inner).poll_read(cx, buf) | ||
| } | ||
| } | ||
| /// When the server's input stream returns EOF while a tool handler is still | ||
| /// in-flight, the drain phase should flush pending responses before closing. | ||
| #[tokio::test] | ||
| async fn test_inflight_response_drain_on_eof() -> anyhow::Result<()> { | ||
| // Two unidirectional channels: | ||
| // client_write → server_read (client sends requests to server) | ||
| // server_write → client_read (server sends responses to client) | ||
| let (client_write, server_read) = tokio::io::duplex(4096); | ||
| let (server_write, client_read) = tokio::io::duplex(4096); | ||
| // Wrap the server's read side so we can signal EOF from the test. | ||
| let eof_flag = Arc::new(AtomicBool::new(false)); | ||
| let closable_read = ClosableReader { | ||
| inner: server_read, | ||
| eof_flag: eof_flag.clone(), | ||
| }; | ||
| let server_transport = (closable_read, server_write); | ||
| let client_transport = (client_read, client_write); | ||
| // Start server with slow tool handler | ||
| let server_handle = tokio::spawn(async move { | ||
| let server = SlowToolServer::new(); | ||
| let running = server.serve(server_transport).await?; | ||
| let reason = running.waiting().await?; | ||
| assert!( | ||
| matches!(reason, QuitReason::Closed), | ||
| "expected Closed quit reason, got {:?}", | ||
| reason, | ||
| ); | ||
| anyhow::Ok(()) | ||
| }); | ||
| // Start client | ||
| let client = DummyClientHandler.serve(client_transport).await?; | ||
| // Call the slow tool (200ms sleep). Concurrently, signal the server's | ||
| // read side to return EOF after the request has been sent but before | ||
| // the handler finishes. | ||
| let tool_future = client.call_tool( | ||
| CallToolRequestParams::new("slow_tool").with_arguments( | ||
| serde_json::json!({ "sleep_ms": 200 }) | ||
| .as_object() | ||
| .unwrap() | ||
| .clone(), | ||
| ), | ||
| ); | ||
| let (tool_result, _) = tokio::join!(tool_future, async { | ||
| // Wait for the request to be sent and received by the server, | ||
| // then signal EOF on the server's read side. | ||
| tokio::time::sleep(Duration::from_millis(50)).await; | ||
| eof_flag.store(true, Ordering::Release); | ||
| }); | ||
| // The tool result should still arrive thanks to the drain phase. | ||
| let result = tool_result?; | ||
| let text = result | ||
| .content | ||
| .first() | ||
| .and_then(|c| c.raw.as_text()) | ||
| .map(|t| t.text.as_str()) | ||
| .expect("expected text content in tool result"); | ||
| assert_eq!(text, "done after 200ms"); | ||
| server_handle.await??; | ||
| Ok(()) | ||
| } |
| #![cfg(all( | ||
| feature = "transport-streamable-http-client", | ||
| feature = "transport-streamable-http-client-reqwest", | ||
| not(feature = "local") | ||
| ))] | ||
| use std::{collections::HashMap, sync::Arc}; | ||
| use rmcp::{ | ||
| model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, | ||
| transport::streamable_http_client::{ | ||
| StreamableHttpClient, StreamableHttpError, StreamableHttpPostResponse, | ||
| }, | ||
| }; | ||
| /// Spin up a minimal axum server that always responds with the given status, | ||
| /// content-type, and body — no MCP logic involved. | ||
| async fn spawn_mock_server(status: u16, content_type: &'static str, body: &'static str) -> String { | ||
| use axum::{Router, body::Body, http::Response, routing::post}; | ||
| let router = Router::new().route( | ||
| "/mcp", | ||
| post(move || async move { | ||
| Response::builder() | ||
| .status(status) | ||
| .header("content-type", content_type) | ||
| .body(Body::from(body)) | ||
| .unwrap() | ||
| }), | ||
| ); | ||
| let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); | ||
| let addr = listener.local_addr().unwrap(); | ||
| tokio::spawn(async move { | ||
| axum::serve(listener, router).await.unwrap(); | ||
| }); | ||
| format!("http://{addr}/mcp") | ||
| } | ||
| fn ping_message() -> ClientJsonRpcMessage { | ||
| ClientJsonRpcMessage::request( | ||
| ClientRequest::PingRequest(PingRequest::default()), | ||
| RequestId::Number(1), | ||
| ) | ||
| } | ||
| /// HTTP 4xx with Content-Type: application/json and a valid JSON-RPC error body | ||
| /// must be surfaced as `StreamableHttpPostResponse::Json`, not swallowed as a | ||
| /// transport error. | ||
| #[tokio::test] | ||
| async fn http_4xx_json_rpc_error_body_is_surfaced_as_json_response() { | ||
| let body = r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#; | ||
| let url = spawn_mock_server(400, "application/json", body).await; | ||
| let client = reqwest::Client::new(); | ||
| let result = client | ||
| .post_message( | ||
| Arc::from(url.as_str()), | ||
| ping_message(), | ||
| None, | ||
| None, | ||
| HashMap::new(), | ||
| ) | ||
| .await; | ||
| match result { | ||
| Ok(StreamableHttpPostResponse::Json(msg, _)) => { | ||
| let json = serde_json::to_value(&msg).unwrap(); | ||
| assert_eq!(json["error"]["code"], -32600); | ||
| assert_eq!(json["error"]["message"], "Invalid Request"); | ||
| } | ||
| other => panic!("expected Json response, got: {other:?}"), | ||
| } | ||
| } | ||
| /// HTTP 4xx with non-JSON content-type must still return `UnexpectedServerResponse` | ||
| /// (no regression on the original error path). | ||
| #[tokio::test] | ||
| async fn http_4xx_non_json_body_returns_unexpected_server_response() { | ||
| let url = spawn_mock_server(400, "text/plain", "Bad Request").await; | ||
| let client = reqwest::Client::new(); | ||
| let result = client | ||
| .post_message( | ||
| Arc::from(url.as_str()), | ||
| ping_message(), | ||
| None, | ||
| None, | ||
| HashMap::new(), | ||
| ) | ||
| .await; | ||
| match result { | ||
| Err(StreamableHttpError::UnexpectedServerResponse(_)) => {} | ||
| other => panic!("expected UnexpectedServerResponse, got: {other:?}"), | ||
| } | ||
| } | ||
| /// HTTP 4xx with Content-Type: application/json but a body that is NOT a valid | ||
| /// JSON-RPC message must fall back to `UnexpectedServerResponse`. | ||
| #[tokio::test] | ||
| async fn http_4xx_malformed_json_body_falls_back_to_unexpected_server_response() { | ||
| let url = spawn_mock_server(400, "application/json", r#"{"error":"not jsonrpc"}"#).await; | ||
| let client = reqwest::Client::new(); | ||
| let result = client | ||
| .post_message( | ||
| Arc::from(url.as_str()), | ||
| ping_message(), | ||
| None, | ||
| None, | ||
| HashMap::new(), | ||
| ) | ||
| .await; | ||
| match result { | ||
| Err(StreamableHttpError::UnexpectedServerResponse(_)) => {} | ||
| other => panic!("expected UnexpectedServerResponse, got: {other:?}"), | ||
| } | ||
| } |
| #![cfg(all( | ||
| unix, | ||
| feature = "transport-streamable-http-client-unix-socket", | ||
| not(feature = "local") | ||
| ))] | ||
| use std::{collections::HashMap, sync::Arc}; | ||
| use axum::{ | ||
| Router, body::Bytes, extract::State, http::StatusCode, response::IntoResponse, routing::post, | ||
| }; | ||
| use http::{HeaderName, HeaderValue}; | ||
| use hyper_util::rt::TokioIo; | ||
| use rmcp::{ | ||
| ServiceExt, | ||
| transport::{ | ||
| StreamableHttpClientTransport, UnixSocketHttpClient, | ||
| streamable_http_client::StreamableHttpClientTransportConfig, | ||
| }, | ||
| }; | ||
| use serde_json::json; | ||
| use tokio::sync::Mutex; | ||
| #[derive(Clone)] | ||
| struct ServerState { | ||
| received_headers: Arc<Mutex<HashMap<String, String>>>, | ||
| initialize_called: Arc<tokio::sync::Notify>, | ||
| } | ||
| async fn mcp_handler( | ||
| State(state): State<ServerState>, | ||
| headers: http::HeaderMap, | ||
| body: Bytes, | ||
| ) -> impl IntoResponse { | ||
| let mut headers_map = HashMap::new(); | ||
| for (name, value) in headers.iter() { | ||
| let name_str = name.as_str(); | ||
| if name_str.starts_with("x-") || name_str == "host" { | ||
| if let Ok(v) = value.to_str() { | ||
| headers_map.insert(name_str.to_string(), v.to_string()); | ||
| } | ||
| } | ||
| } | ||
| let mut stored = state.received_headers.lock().await; | ||
| stored.extend(headers_map); | ||
| drop(stored); | ||
| if let Ok(json_body) = serde_json::from_slice::<serde_json::Value>(&body) { | ||
| if let Some(method) = json_body.get("method").and_then(|m| m.as_str()) { | ||
| if method == "initialize" { | ||
| state.initialize_called.notify_one(); | ||
| let response = json!({ | ||
| "jsonrpc": "2.0", | ||
| "id": json_body.get("id"), | ||
| "result": { | ||
| "protocolVersion": "2024-11-05", | ||
| "capabilities": {}, | ||
| "serverInfo": { | ||
| "name": "test-unix-server", | ||
| "version": "1.0.0" | ||
| } | ||
| } | ||
| }); | ||
| return ( | ||
| StatusCode::OK, | ||
| [ | ||
| (http::header::CONTENT_TYPE, "application/json"), | ||
| ( | ||
| http::HeaderName::from_static("mcp-session-id"), | ||
| "unix-test-session", | ||
| ), | ||
| ], | ||
| response.to_string(), | ||
| ); | ||
| } else if method == "notifications/initialized" { | ||
| return ( | ||
| StatusCode::ACCEPTED, | ||
| [ | ||
| (http::header::CONTENT_TYPE, "application/json"), | ||
| ( | ||
| http::HeaderName::from_static("mcp-session-id"), | ||
| "unix-test-session", | ||
| ), | ||
| ], | ||
| String::new(), | ||
| ); | ||
| } | ||
| } | ||
| } | ||
| let request_id = serde_json::from_slice::<serde_json::Value>(&body) | ||
| .ok() | ||
| .and_then(|j| j.get("id").cloned()) | ||
| .unwrap_or(serde_json::Value::Null); | ||
| let response = json!({ | ||
| "jsonrpc": "2.0", | ||
| "id": request_id, | ||
| "result": {} | ||
| }); | ||
| ( | ||
| StatusCode::OK, | ||
| [ | ||
| (http::header::CONTENT_TYPE, "application/json"), | ||
| ( | ||
| http::HeaderName::from_static("mcp-session-id"), | ||
| "unix-test-session", | ||
| ), | ||
| ], | ||
| response.to_string(), | ||
| ) | ||
| } | ||
| /// Spawns an HTTP/1.1 server on a Unix socket using hyper directly. | ||
| /// Avoids `axum::serve(UnixListener, ...)` which uses `spawn_local` on Linux. | ||
| fn spawn_unix_server( | ||
| listener: tokio::net::UnixListener, | ||
| app: Router, | ||
| ) -> tokio::task::JoinHandle<()> { | ||
| tokio::spawn(async move { | ||
| while let Ok((stream, _)) = listener.accept().await { | ||
| let tower_service = app.clone(); | ||
| tokio::spawn(async move { | ||
| let io = TokioIo::new(stream); | ||
| let hyper_service = hyper::service::service_fn( | ||
| move |req: hyper::Request<hyper::body::Incoming>| { | ||
| let mut tower_service = tower_service.clone(); | ||
| async move { | ||
| use tower_service::Service; | ||
| tower_service.call(req).await | ||
| } | ||
| }, | ||
| ); | ||
| hyper::server::conn::http1::Builder::new() | ||
| .serve_connection(io, hyper_service) | ||
| .await | ||
| .ok(); | ||
| }); | ||
| } | ||
| }) | ||
| } | ||
| /// Integration test: MCP client connects and completes handshake over a Unix domain socket. | ||
| #[tokio::test] | ||
| async fn test_unix_socket_mcp_handshake() -> anyhow::Result<()> { | ||
| let dir = std::env::temp_dir().join(format!("rmcp-test-{}", std::process::id())); | ||
| std::fs::create_dir_all(&dir)?; | ||
| let socket_path = dir.join("mcp.sock"); | ||
| let _ = std::fs::remove_file(&socket_path); | ||
| let state = ServerState { | ||
| received_headers: Arc::new(Mutex::new(HashMap::new())), | ||
| initialize_called: Arc::new(tokio::sync::Notify::new()), | ||
| }; | ||
| let app = Router::new() | ||
| .route("/mcp", post(mcp_handler)) | ||
| .with_state(state.clone()); | ||
| let listener = tokio::net::UnixListener::bind(&socket_path)?; | ||
| let server_handle = spawn_unix_server(listener, app); | ||
| let socket_str = socket_path.to_str().unwrap(); | ||
| let uri = "http://mcp-server.internal/mcp"; | ||
| let client = UnixSocketHttpClient::new(socket_str, uri); | ||
| let config = StreamableHttpClientTransportConfig::with_uri(uri); | ||
| let transport = StreamableHttpClientTransport::with_client(client, config); | ||
| let mcp_client = ().serve(transport).await.expect("MCP handshake should succeed"); | ||
| tokio::time::timeout( | ||
| std::time::Duration::from_secs(5), | ||
| state.initialize_called.notified(), | ||
| ) | ||
| .await | ||
| .expect("Initialize request should be received"); | ||
| let headers = state.received_headers.lock().await; | ||
| assert_eq!( | ||
| headers.get("host"), | ||
| Some(&"mcp-server.internal".to_string()), | ||
| "Host header should be derived from URI" | ||
| ); | ||
| drop(mcp_client); | ||
| server_handle.abort(); | ||
| let _ = std::fs::remove_file(&socket_path); | ||
| let _ = std::fs::remove_dir(&dir); | ||
| Ok(()) | ||
| } | ||
| /// Integration test: Custom headers are sent through the Unix socket transport. | ||
| #[tokio::test] | ||
| async fn test_unix_socket_custom_headers() -> anyhow::Result<()> { | ||
| let dir = std::env::temp_dir().join(format!("rmcp-test-headers-{}", std::process::id())); | ||
| std::fs::create_dir_all(&dir)?; | ||
| let socket_path = dir.join("mcp.sock"); | ||
| let _ = std::fs::remove_file(&socket_path); | ||
| let state = ServerState { | ||
| received_headers: Arc::new(Mutex::new(HashMap::new())), | ||
| initialize_called: Arc::new(tokio::sync::Notify::new()), | ||
| }; | ||
| let app = Router::new() | ||
| .route("/mcp", post(mcp_handler)) | ||
| .with_state(state.clone()); | ||
| let listener = tokio::net::UnixListener::bind(&socket_path)?; | ||
| let server_handle = spawn_unix_server(listener, app); | ||
| let mut custom_headers = HashMap::new(); | ||
| custom_headers.insert( | ||
| HeaderName::from_static("x-test-header"), | ||
| HeaderValue::from_static("test-value-123"), | ||
| ); | ||
| custom_headers.insert( | ||
| HeaderName::from_static("x-client-id"), | ||
| HeaderValue::from_static("unix-test-client"), | ||
| ); | ||
| let socket_str = socket_path.to_str().unwrap(); | ||
| let uri = "http://mcp-server.internal/mcp"; | ||
| let client = UnixSocketHttpClient::new(socket_str, uri); | ||
| let config = StreamableHttpClientTransportConfig::with_uri(uri).custom_headers(custom_headers); | ||
| let transport = StreamableHttpClientTransport::with_client(client, config); | ||
| let mcp_client = ().serve(transport).await.expect("MCP handshake should succeed"); | ||
| tokio::time::timeout( | ||
| std::time::Duration::from_secs(5), | ||
| state.initialize_called.notified(), | ||
| ) | ||
| .await | ||
| .expect("Initialize request should be received"); | ||
| let headers = state.received_headers.lock().await; | ||
| assert_eq!( | ||
| headers.get("x-test-header"), | ||
| Some(&"test-value-123".to_string()), | ||
| "Custom header x-test-header should be received" | ||
| ); | ||
| assert_eq!( | ||
| headers.get("x-client-id"), | ||
| Some(&"unix-test-client".to_string()), | ||
| "Custom header x-client-id should be received" | ||
| ); | ||
| drop(mcp_client); | ||
| server_handle.abort(); | ||
| let _ = std::fs::remove_file(&socket_path); | ||
| let _ = std::fs::remove_dir(&dir); | ||
| Ok(()) | ||
| } | ||
| /// Integration test: Convenience constructor `from_unix_socket` works end-to-end. | ||
| #[tokio::test] | ||
| async fn test_unix_socket_convenience_constructor() -> anyhow::Result<()> { | ||
| let dir = std::env::temp_dir().join(format!("rmcp-test-conv-{}", std::process::id())); | ||
| std::fs::create_dir_all(&dir)?; | ||
| let socket_path = dir.join("mcp.sock"); | ||
| let _ = std::fs::remove_file(&socket_path); | ||
| let state = ServerState { | ||
| received_headers: Arc::new(Mutex::new(HashMap::new())), | ||
| initialize_called: Arc::new(tokio::sync::Notify::new()), | ||
| }; | ||
| let app = Router::new() | ||
| .route("/mcp", post(mcp_handler)) | ||
| .with_state(state.clone()); | ||
| let listener = tokio::net::UnixListener::bind(&socket_path)?; | ||
| let server_handle = spawn_unix_server(listener, app); | ||
| let socket_str = socket_path.to_str().unwrap(); | ||
| let transport = | ||
| StreamableHttpClientTransport::from_unix_socket(socket_str, "http://localhost/mcp"); | ||
| let mcp_client = ().serve(transport).await.expect("MCP handshake should succeed"); | ||
| tokio::time::timeout( | ||
| std::time::Duration::from_secs(5), | ||
| state.initialize_called.notified(), | ||
| ) | ||
| .await | ||
| .expect("Initialize request should be received"); | ||
| drop(mcp_client); | ||
| server_handle.abort(); | ||
| let _ = std::fs::remove_file(&socket_path); | ||
| let _ = std::fs::remove_dir(&dir); | ||
| Ok(()) | ||
| } |
| { | ||
| "git": { | ||
| "sha1": "3bd75220708b2e9f8c74a3fe3277ac5d4f03f478" | ||
| "sha1": "ac749e3cedfc036a5b77960337669c7cf2338035" | ||
| }, | ||
| "path_in_vcs": "crates/rmcp" | ||
| } |
+71
-3
@@ -15,3 +15,3 @@ # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO | ||
| name = "rmcp" | ||
| version = "1.2.0" | ||
| version = "1.3.0" | ||
| build = "build.rs" | ||
@@ -28,3 +28,2 @@ autolib = false | ||
| license = "Apache-2.0" | ||
| license-file = "LICENSE" | ||
| repository = "https://github.com/modelcontextprotocol/rust-sdk/" | ||
@@ -63,2 +62,3 @@ resolver = "2" | ||
| elicitation = ["dep:url"] | ||
| local = ["rmcp-macros?/local"] | ||
| macros = [ | ||
@@ -119,2 +119,11 @@ "dep:rmcp-macros", | ||
| ] | ||
| transport-streamable-http-client-unix-socket = [ | ||
| "transport-streamable-http-client", | ||
| "dep:hyper", | ||
| "dep:hyper-util", | ||
| "dep:http-body-util", | ||
| "dep:http", | ||
| "dep:bytes", | ||
| "tokio/net", | ||
| ] | ||
| transport-streamable-http-server = [ | ||
@@ -200,2 +209,6 @@ "transport-streamable-http-server-session", | ||
| [[test]] | ||
| name = "test_inflight_response_drain" | ||
| path = "tests/test_inflight_response_drain.rs" | ||
| [[test]] | ||
| name = "test_json_schema_detection" | ||
@@ -295,2 +308,10 @@ path = "tests/test_json_schema_detection.rs" | ||
| [[test]] | ||
| name = "test_streamable_http_4xx_error_body" | ||
| path = "tests/test_streamable_http_4xx_error_body.rs" | ||
| required-features = [ | ||
| "transport-streamable-http-client", | ||
| "transport-streamable-http-client-reqwest", | ||
| ] | ||
| [[test]] | ||
| name = "test_streamable_http_json_response" | ||
@@ -318,2 +339,9 @@ path = "tests/test_streamable_http_json_response.rs" | ||
| path = "tests/test_streamable_http_stale_session.rs" | ||
| required-features = [ | ||
| "server", | ||
| "client", | ||
| "transport-streamable-http-server", | ||
| "transport-streamable-http-client", | ||
| "transport-streamable-http-client-reqwest", | ||
| ] | ||
@@ -366,2 +394,11 @@ [[test]] | ||
| [[test]] | ||
| name = "test_unix_socket_transport" | ||
| path = "tests/test_unix_socket_transport.rs" | ||
| required-features = [ | ||
| "client", | ||
| "server", | ||
| "transport-streamable-http-client-unix-socket", | ||
| ] | ||
| [[test]] | ||
| name = "test_with_js" | ||
@@ -414,2 +451,15 @@ path = "tests/test_with_js.rs" | ||
| [dependencies.hyper] | ||
| version = "1" | ||
| features = [ | ||
| "client", | ||
| "http1", | ||
| ] | ||
| optional = true | ||
| [dependencies.hyper-util] | ||
| version = "0.1" | ||
| features = ["tokio"] | ||
| optional = true | ||
| [dependencies.jsonwebtoken] | ||
@@ -450,3 +500,3 @@ version = "10" | ||
| [dependencies.rmcp-macros] | ||
| version = "1.2.0" | ||
| version = "1.3.0" | ||
| optional = true | ||
@@ -522,2 +572,13 @@ | ||
| [dev-dependencies.hyper] | ||
| version = "1" | ||
| features = [ | ||
| "server", | ||
| "http1", | ||
| ] | ||
| [dev-dependencies.hyper-util] | ||
| version = "0.1" | ||
| features = ["tokio"] | ||
| [dev-dependencies.schemars] | ||
@@ -531,2 +592,5 @@ version = "1.1.0" | ||
| [dev-dependencies.tower-service] | ||
| version = "0.3" | ||
| [dev-dependencies.tracing-subscriber] | ||
@@ -556,1 +620,5 @@ version = "0.3" | ||
| features = ["serde"] | ||
| [lints.clippy] | ||
| exhaustive_enums = "warn" | ||
| exhaustive_structs = "warn" |
+23
-0
@@ -10,2 +10,25 @@ # Changelog | ||
| ## [1.3.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v1.2.0...rmcp-v1.3.0) - 2026-03-24 | ||
| ### Added | ||
| - *(transport)* add Unix domain socket client for streamable HTTP ([#749](https://github.com/modelcontextprotocol/rust-sdk/pull/749)) | ||
| - *(auth)* implement SEP-2207 OIDC-flavored refresh token guidance ([#676](https://github.com/modelcontextprotocol/rust-sdk/pull/676)) | ||
| - add configuration for transparent session re-init ([#760](https://github.com/modelcontextprotocol/rust-sdk/pull/760)) | ||
| - add local feature for !Send tool handler support ([#740](https://github.com/modelcontextprotocol/rust-sdk/pull/740)) | ||
| ### Fixed | ||
| - prevent CallToolResult and GetTaskPayloadResult from shadowing CustomResult in untagged enums ([#771](https://github.com/modelcontextprotocol/rust-sdk/pull/771)) | ||
| - drain in-flight responses on stdin EOF ([#759](https://github.com/modelcontextprotocol/rust-sdk/pull/759)) | ||
| - remove default type param from StreamableHttpService ([#758](https://github.com/modelcontextprotocol/rust-sdk/pull/758)) | ||
| - use cfg-gated Send+Sync supertraits to avoid semver break ([#757](https://github.com/modelcontextprotocol/rust-sdk/pull/757)) | ||
| - *(rmcp)* surface JSON-RPC error bodies on HTTP 4xx responses ([#748](https://github.com/modelcontextprotocol/rust-sdk/pull/748)) | ||
| - default CallToolResult content to empty vec on missing field ([#752](https://github.com/modelcontextprotocol/rust-sdk/pull/752)) | ||
| - *(auth)* redact secrets in Debug output for StoredCredentials and StoredAuthorizationState ([#744](https://github.com/modelcontextprotocol/rust-sdk/pull/744)) | ||
| ### Other | ||
| - fix all clippy warnings across workspace ([#746](https://github.com/modelcontextprotocol/rust-sdk/pull/746)) | ||
| ## [1.2.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v1.1.1...rmcp-v1.2.0) - 2026-03-11 | ||
@@ -12,0 +35,0 @@ |
+1
-1
@@ -55,3 +55,3 @@ <style> | ||
| | **stdio** | [`TokioChildProcess`](crate::transport::TokioChildProcess) | [`stdio`](crate::transport::stdio) | | ||
| | **Streamable HTTP** | [`StreamableHttpClientTransport`](crate::transport::StreamableHttpClientTransport) | [`StreamableHttpService`](crate::transport::StreamableHttpService) | | ||
| | **Streamable HTTP** | [`StreamableHttpClientTransport`](crate::transport::StreamableHttpClientTransport) | `StreamableHttpService` | | ||
@@ -58,0 +58,0 @@ Any type that implements the [`Transport`](crate::transport::Transport) trait can be used. The [`IntoTransport`](crate::transport::IntoTransport) helper trait provides automatic conversions from: |
+31
-28
@@ -7,3 +7,5 @@ pub mod progress; | ||
| model::*, | ||
| service::{NotificationContext, RequestContext, RoleClient, Service, ServiceRole}, | ||
| service::{ | ||
| MaybeSendFuture, NotificationContext, RequestContext, RoleClient, Service, ServiceRole, | ||
| }, | ||
| }; | ||
@@ -87,3 +89,3 @@ | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Ok(())) | ||
@@ -96,3 +98,3 @@ } | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<CreateMessageResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CreateMessageResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err( | ||
@@ -106,3 +108,3 @@ McpError::method_not_found::<CreateMessageRequestMethod>(), | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<ListRootsResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<ListRootsResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Ok(ListRootsResult::default())) | ||
@@ -169,3 +171,4 @@ } | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<CreateElicitationResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CreateElicitationResult, McpError>> + MaybeSendFuture + '_ | ||
| { | ||
| // Default implementation declines all requests - real clients should override this | ||
@@ -183,3 +186,3 @@ let _ = (request, context); | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ { | ||
| let CustomRequest { method, .. } = request; | ||
@@ -198,3 +201,3 @@ let _ = context; | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
@@ -206,3 +209,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
@@ -214,3 +217,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
@@ -222,3 +225,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
@@ -229,3 +232,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
@@ -236,3 +239,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
@@ -243,3 +246,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
@@ -252,3 +255,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
@@ -260,3 +263,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| let _ = (notification, context); | ||
@@ -287,3 +290,3 @@ std::future::ready(()) | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| (**self).ping(context) | ||
@@ -296,3 +299,3 @@ } | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<CreateMessageResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CreateMessageResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).create_message(params, context) | ||
@@ -304,3 +307,3 @@ } | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<ListRootsResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<ListRootsResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).list_roots(context) | ||
@@ -313,3 +316,3 @@ } | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<CreateElicitationResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CreateElicitationResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).create_elicitation(request, context) | ||
@@ -322,3 +325,3 @@ } | ||
| context: RequestContext<RoleClient>, | ||
| ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).on_custom_request(request, context) | ||
@@ -331,3 +334,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_cancelled(params, context) | ||
@@ -340,3 +343,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_progress(params, context) | ||
@@ -349,3 +352,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_logging_message(params, context) | ||
@@ -358,3 +361,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_resource_updated(params, context) | ||
@@ -366,3 +369,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_resource_list_changed(context) | ||
@@ -374,3 +377,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_tool_list_changed(context) | ||
@@ -382,3 +385,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_prompt_list_changed(context) | ||
@@ -391,3 +394,3 @@ } | ||
| context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_custom_notification(notification, context) | ||
@@ -394,0 +397,0 @@ } |
+240
-221
@@ -6,3 +6,5 @@ use std::sync::Arc; | ||
| model::{TaskSupport, *}, | ||
| service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole}, | ||
| service::{ | ||
| MaybeSendFuture, NotificationContext, RequestContext, RoleServer, Service, ServiceRole, | ||
| }, | ||
| }; | ||
@@ -162,207 +164,224 @@ | ||
| #[allow(unused_variables)] | ||
| pub trait ServerHandler: Sized + Send + Sync + 'static { | ||
| fn enqueue_task( | ||
| &self, | ||
| _request: CallToolRequestParams, | ||
| _context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ { | ||
| std::future::ready(Err(McpError::internal_error( | ||
| "Task processing not implemented".to_string(), | ||
| None, | ||
| ))) | ||
| } | ||
| fn ping( | ||
| &self, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| std::future::ready(Ok(())) | ||
| } | ||
| // handle requests | ||
| fn initialize( | ||
| &self, | ||
| request: InitializeRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ { | ||
| if context.peer.peer_info().is_none() { | ||
| context.peer.set_peer_info(request); | ||
| macro_rules! server_handler_methods { | ||
| () => { | ||
| fn enqueue_task( | ||
| &self, | ||
| _request: CallToolRequestParams, | ||
| _context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err(McpError::internal_error( | ||
| "Task processing not implemented".to_string(), | ||
| None, | ||
| ))) | ||
| } | ||
| std::future::ready(Ok(self.get_info())) | ||
| } | ||
| fn complete( | ||
| &self, | ||
| request: CompleteRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ { | ||
| std::future::ready(Ok(CompleteResult::default())) | ||
| } | ||
| fn set_level( | ||
| &self, | ||
| request: SetLevelRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<SetLevelRequestMethod>())) | ||
| } | ||
| fn get_prompt( | ||
| &self, | ||
| request: GetPromptRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<GetPromptRequestMethod>())) | ||
| } | ||
| fn list_prompts( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ { | ||
| std::future::ready(Ok(ListPromptsResult::default())) | ||
| } | ||
| fn list_resources( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ { | ||
| std::future::ready(Ok(ListResourcesResult::default())) | ||
| } | ||
| fn list_resource_templates( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_ { | ||
| std::future::ready(Ok(ListResourceTemplatesResult::default())) | ||
| } | ||
| fn read_resource( | ||
| &self, | ||
| request: ReadResourceRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ { | ||
| std::future::ready(Err( | ||
| McpError::method_not_found::<ReadResourceRequestMethod>(), | ||
| )) | ||
| } | ||
| fn subscribe( | ||
| &self, | ||
| request: SubscribeRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<SubscribeRequestMethod>())) | ||
| } | ||
| fn unsubscribe( | ||
| &self, | ||
| request: UnsubscribeRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<UnsubscribeRequestMethod>())) | ||
| } | ||
| fn call_tool( | ||
| &self, | ||
| request: CallToolRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<CallToolRequestMethod>())) | ||
| } | ||
| fn list_tools( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ { | ||
| std::future::ready(Ok(ListToolsResult::default())) | ||
| } | ||
| /// Get a tool definition by name. | ||
| /// | ||
| /// The default implementation returns `None`, which bypasses validation. | ||
| /// When using `#[tool_handler]`, this method is automatically implemented. | ||
| fn get_tool(&self, _name: &str) -> Option<Tool> { | ||
| None | ||
| } | ||
| fn on_custom_request( | ||
| &self, | ||
| request: CustomRequest, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ { | ||
| let CustomRequest { method, .. } = request; | ||
| let _ = context; | ||
| std::future::ready(Err(McpError::new( | ||
| ErrorCode::METHOD_NOT_FOUND, | ||
| method, | ||
| None, | ||
| ))) | ||
| } | ||
| fn ping( | ||
| &self, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Ok(())) | ||
| } | ||
| // handle requests | ||
| fn initialize( | ||
| &self, | ||
| request: InitializeRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ { | ||
| if context.peer.peer_info().is_none() { | ||
| context.peer.set_peer_info(request); | ||
| } | ||
| std::future::ready(Ok(self.get_info())) | ||
| } | ||
| fn complete( | ||
| &self, | ||
| request: CompleteRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Ok(CompleteResult::default())) | ||
| } | ||
| fn set_level( | ||
| &self, | ||
| request: SetLevelRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<SetLevelRequestMethod>())) | ||
| } | ||
| fn get_prompt( | ||
| &self, | ||
| request: GetPromptRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<GetPromptRequestMethod>())) | ||
| } | ||
| fn list_prompts( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Ok(ListPromptsResult::default())) | ||
| } | ||
| fn list_resources( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Ok(ListResourcesResult::default())) | ||
| } | ||
| fn list_resource_templates( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> | ||
| + MaybeSendFuture | ||
| + '_ { | ||
| std::future::ready(Ok(ListResourceTemplatesResult::default())) | ||
| } | ||
| fn read_resource( | ||
| &self, | ||
| request: ReadResourceRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err( | ||
| McpError::method_not_found::<ReadResourceRequestMethod>(), | ||
| )) | ||
| } | ||
| fn subscribe( | ||
| &self, | ||
| request: SubscribeRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<SubscribeRequestMethod>())) | ||
| } | ||
| fn unsubscribe( | ||
| &self, | ||
| request: UnsubscribeRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err( | ||
| McpError::method_not_found::<UnsubscribeRequestMethod>(), | ||
| )) | ||
| } | ||
| fn call_tool( | ||
| &self, | ||
| request: CallToolRequestParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<CallToolRequestMethod>())) | ||
| } | ||
| fn list_tools( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Ok(ListToolsResult::default())) | ||
| } | ||
| /// Get a tool definition by name. | ||
| /// | ||
| /// The default implementation returns `None`, which bypasses validation. | ||
| /// When using `#[tool_handler]`, this method is automatically implemented. | ||
| fn get_tool(&self, _name: &str) -> Option<Tool> { | ||
| None | ||
| } | ||
| fn on_custom_request( | ||
| &self, | ||
| request: CustomRequest, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ { | ||
| let CustomRequest { method, .. } = request; | ||
| let _ = context; | ||
| std::future::ready(Err(McpError::new( | ||
| ErrorCode::METHOD_NOT_FOUND, | ||
| method, | ||
| None, | ||
| ))) | ||
| } | ||
| fn on_cancelled( | ||
| &self, | ||
| notification: CancelledNotificationParam, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| std::future::ready(()) | ||
| } | ||
| fn on_progress( | ||
| &self, | ||
| notification: ProgressNotificationParam, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| std::future::ready(()) | ||
| } | ||
| fn on_initialized( | ||
| &self, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| tracing::info!("client initialized"); | ||
| std::future::ready(()) | ||
| } | ||
| fn on_roots_list_changed( | ||
| &self, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| std::future::ready(()) | ||
| } | ||
| fn on_custom_notification( | ||
| &self, | ||
| notification: CustomNotification, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| let _ = (notification, context); | ||
| std::future::ready(()) | ||
| } | ||
| fn on_cancelled( | ||
| &self, | ||
| notification: CancelledNotificationParam, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
| } | ||
| fn on_progress( | ||
| &self, | ||
| notification: ProgressNotificationParam, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
| } | ||
| fn on_initialized( | ||
| &self, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| tracing::info!("client initialized"); | ||
| std::future::ready(()) | ||
| } | ||
| fn on_roots_list_changed( | ||
| &self, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| std::future::ready(()) | ||
| } | ||
| fn on_custom_notification( | ||
| &self, | ||
| notification: CustomNotification, | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| let _ = (notification, context); | ||
| std::future::ready(()) | ||
| } | ||
| fn get_info(&self) -> ServerInfo { | ||
| ServerInfo::default() | ||
| } | ||
| fn get_info(&self) -> ServerInfo { | ||
| ServerInfo::default() | ||
| } | ||
| fn list_tasks( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>())) | ||
| } | ||
| fn list_tasks( | ||
| &self, | ||
| request: Option<PaginatedRequestParams>, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Err(McpError::method_not_found::<ListTasksMethod>())) | ||
| } | ||
| fn get_task_info( | ||
| &self, | ||
| request: GetTaskInfoParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetTaskResult, McpError>> + Send + '_ { | ||
| let _ = (request, context); | ||
| std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>())) | ||
| } | ||
| fn get_task_info( | ||
| &self, | ||
| request: GetTaskInfoParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ { | ||
| let _ = (request, context); | ||
| std::future::ready(Err(McpError::method_not_found::<GetTaskInfoMethod>())) | ||
| } | ||
| fn get_task_result( | ||
| &self, | ||
| request: GetTaskResultParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + Send + '_ { | ||
| let _ = (request, context); | ||
| std::future::ready(Err(McpError::method_not_found::<GetTaskResultMethod>())) | ||
| } | ||
| fn get_task_result( | ||
| &self, | ||
| request: GetTaskResultParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ { | ||
| let _ = (request, context); | ||
| std::future::ready(Err(McpError::method_not_found::<GetTaskResultMethod>())) | ||
| } | ||
| fn cancel_task( | ||
| &self, | ||
| request: CancelTaskParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + Send + '_ { | ||
| let _ = (request, context); | ||
| std::future::ready(Err(McpError::method_not_found::<CancelTaskMethod>())) | ||
| } | ||
| fn cancel_task( | ||
| &self, | ||
| request: CancelTaskParams, | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ { | ||
| let _ = (request, context); | ||
| std::future::ready(Err(McpError::method_not_found::<CancelTaskMethod>())) | ||
| } | ||
| }; | ||
| } | ||
| #[allow(unused_variables)] | ||
| #[cfg(not(feature = "local"))] | ||
| pub trait ServerHandler: Sized + Send + Sync + 'static { | ||
| server_handler_methods!(); | ||
| } | ||
| #[allow(unused_variables)] | ||
| #[cfg(feature = "local")] | ||
| pub trait ServerHandler: Sized + 'static { | ||
| server_handler_methods!(); | ||
| } | ||
| macro_rules! impl_server_handler_for_wrapper { | ||
@@ -375,3 +394,3 @@ ($wrapper:ident) => { | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).enqueue_task(request, context) | ||
@@ -383,3 +402,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| (**self).ping(context) | ||
@@ -392,3 +411,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<InitializeResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).initialize(request, context) | ||
@@ -401,3 +420,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CompleteResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).complete(request, context) | ||
@@ -410,3 +429,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| (**self).set_level(request, context) | ||
@@ -419,3 +438,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetPromptResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).get_prompt(request, context) | ||
@@ -428,3 +447,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).list_prompts(request, context) | ||
@@ -437,3 +456,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).list_resources(request, context) | ||
@@ -446,3 +465,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + Send + '_ | ||
| ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + MaybeSendFuture + '_ | ||
| { | ||
@@ -456,3 +475,3 @@ (**self).list_resource_templates(request, context) | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).read_resource(request, context) | ||
@@ -465,3 +484,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| (**self).subscribe(request, context) | ||
@@ -474,3 +493,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| (**self).unsubscribe(request, context) | ||
@@ -483,3 +502,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CallToolResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).call_tool(request, context) | ||
@@ -492,3 +511,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListToolsResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).list_tools(request, context) | ||
@@ -505,3 +524,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CustomResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).on_custom_request(request, context) | ||
@@ -514,3 +533,3 @@ } | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_cancelled(notification, context) | ||
@@ -523,3 +542,3 @@ } | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_progress(notification, context) | ||
@@ -531,3 +550,3 @@ } | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_initialized(context) | ||
@@ -539,3 +558,3 @@ } | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_roots_list_changed(context) | ||
@@ -548,3 +567,3 @@ } | ||
| context: NotificationContext<RoleServer>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| (**self).on_custom_notification(notification, context) | ||
@@ -561,3 +580,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<ListTasksResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).list_tasks(request, context) | ||
@@ -570,3 +589,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetTaskResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).get_task_info(request, context) | ||
@@ -579,3 +598,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).get_task_result(request, context) | ||
@@ -588,3 +607,3 @@ } | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ { | ||
| (**self).cancel_task(request, context) | ||
@@ -591,0 +610,0 @@ } |
@@ -141,2 +141,3 @@ //! Common utilities shared between tool and prompt handlers | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct Extension<T>(pub T); | ||
@@ -186,2 +187,3 @@ | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RequestId(pub crate::model::RequestId); | ||
@@ -188,0 +190,0 @@ |
@@ -9,3 +9,4 @@ //! Prompt handling infrastructure for MCP servers | ||
| use futures::future::{BoxFuture, FutureExt}; | ||
| #[cfg(not(feature = "local"))] | ||
| use futures::future::BoxFuture; | ||
| use serde::de::DeserializeOwned; | ||
@@ -19,6 +20,7 @@ | ||
| model::{GetPromptResult, PromptMessage}, | ||
| service::RequestContext, | ||
| service::{MaybeBoxFuture, MaybeSend, MaybeSendFuture, RequestContext}, | ||
| }; | ||
| /// Context for prompt retrieval operations | ||
| #[non_exhaustive] | ||
| pub struct PromptContext<'a, S> { | ||
@@ -62,6 +64,7 @@ pub server: &'a S, | ||
| context: PromptContext<'_, S>, | ||
| ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>; | ||
| ) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>>; | ||
| } | ||
| /// Type alias for dynamic prompt handlers | ||
| #[cfg(not(feature = "local"))] | ||
| pub type DynGetPromptHandler<S> = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result<GetPromptResult, crate::ErrorData>> | ||
@@ -71,2 +74,10 @@ + Send | ||
| #[cfg(feature = "local")] | ||
| pub type DynGetPromptHandler<S> = dyn for<'a> Fn( | ||
| PromptContext<'a, S>, | ||
| ) -> futures::future::LocalBoxFuture< | ||
| 'a, | ||
| Result<GetPromptResult, crate::ErrorData>, | ||
| >; | ||
| /// Adapter type for async methods that return `Vec<PromptMessage>` | ||
@@ -114,2 +125,3 @@ pub struct AsyncMethodAdapter<T>(PhantomData<T>); | ||
| #[project = IntoGetPromptResultFutProj] | ||
| #[non_exhaustive] | ||
| pub enum IntoGetPromptResultFut<F, R> { | ||
@@ -149,2 +161,3 @@ Pending { | ||
| // Prompt-specific extractor for prompt name | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct PromptName(pub String); | ||
@@ -200,7 +213,7 @@ | ||
| $( | ||
| $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send, | ||
| $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + MaybeSendFuture, | ||
| )* | ||
| F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R> + Send, | ||
| R: IntoGetPromptResult + Send + 'static, | ||
| S: Send + Sync + 'static, | ||
| F: FnOnce(&S, $($Tn,)*) -> MaybeBoxFuture<'_, R> + MaybeSendFuture, | ||
| R: IntoGetPromptResult + MaybeSendFuture + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -211,3 +224,3 @@ #[allow(unused_variables, non_snake_case, unused_mut)] | ||
| mut context: PromptContext<'_, S>, | ||
| ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>> | ||
| ) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>> | ||
| { | ||
@@ -218,3 +231,3 @@ $( | ||
| Ok(value) => value, | ||
| Err(e) => return std::future::ready(Err(e)).boxed(), | ||
| Err(e) => return Box::pin(std::future::ready(Err(e))), | ||
| }; | ||
@@ -224,6 +237,6 @@ )* | ||
| let fut = self(service, $($Tn,)*); | ||
| async move { | ||
| Box::pin(async move { | ||
| let result = fut.await; | ||
| result.into_get_prompt_result() | ||
| }.boxed() | ||
| }) | ||
| } | ||
@@ -237,7 +250,7 @@ } | ||
| $( | ||
| $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send, | ||
| $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + MaybeSendFuture, | ||
| )* | ||
| F: FnOnce(&S, $($Tn,)*) -> R + Send, | ||
| R: IntoGetPromptResult + Send, | ||
| S: Send + Sync, | ||
| F: FnOnce(&S, $($Tn,)*) -> R + MaybeSendFuture, | ||
| R: IntoGetPromptResult + MaybeSendFuture, | ||
| S: MaybeSend, | ||
| { | ||
@@ -248,3 +261,3 @@ #[allow(unused_variables, non_snake_case, unused_mut)] | ||
| mut context: PromptContext<'_, S>, | ||
| ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>> | ||
| ) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>> | ||
| { | ||
@@ -255,3 +268,3 @@ $( | ||
| Ok(value) => value, | ||
| Err(e) => return std::future::ready(Err(e)).boxed(), | ||
| Err(e) => return Box::pin(std::future::ready(Err(e))), | ||
| }; | ||
@@ -261,3 +274,3 @@ )* | ||
| let result = self(service, $($Tn,)*); | ||
| std::future::ready(result.into_get_prompt_result()).boxed() | ||
| Box::pin(std::future::ready(result.into_get_prompt_result())) | ||
| } | ||
@@ -271,8 +284,8 @@ } | ||
| $( | ||
| $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send + 'static, | ||
| $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + MaybeSendFuture + 'static, | ||
| )* | ||
| F: FnOnce($($Tn,)*) -> Fut + Send + 'static, | ||
| Fut: Future<Output = Result<R, crate::ErrorData>> + Send + 'static, | ||
| R: IntoGetPromptResult + Send + 'static, | ||
| S: Send + Sync + 'static, | ||
| F: FnOnce($($Tn,)*) -> Fut + MaybeSendFuture + 'static, | ||
| Fut: Future<Output = Result<R, crate::ErrorData>> + MaybeSendFuture + 'static, | ||
| R: IntoGetPromptResult + MaybeSendFuture + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -283,3 +296,3 @@ #[allow(unused_variables, non_snake_case, unused_mut)] | ||
| mut context: PromptContext<'_, S>, | ||
| ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>> | ||
| ) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>> | ||
| { | ||
@@ -291,3 +304,3 @@ // Extract all parameters before moving into the async block | ||
| Ok(value) => value, | ||
| Err(e) => return std::future::ready(Err(e)).boxed(), | ||
| Err(e) => return Box::pin(std::future::ready(Err(e))), | ||
| }; | ||
@@ -310,7 +323,7 @@ )* | ||
| $( | ||
| $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + Send + 'static, | ||
| $Tn: for<'a> FromContextPart<PromptContext<'a, S>> + MaybeSendFuture + 'static, | ||
| )* | ||
| F: FnOnce($($Tn,)*) -> Result<R, crate::ErrorData> + Send + 'static, | ||
| R: IntoGetPromptResult + Send + 'static, | ||
| S: Send + Sync, | ||
| F: FnOnce($($Tn,)*) -> Result<R, crate::ErrorData> + MaybeSendFuture + 'static, | ||
| R: IntoGetPromptResult + MaybeSendFuture + 'static, | ||
| S: MaybeSend, | ||
| { | ||
@@ -321,3 +334,3 @@ #[allow(unused_variables, non_snake_case, unused_mut)] | ||
| mut context: PromptContext<'_, S>, | ||
| ) -> BoxFuture<'_, Result<GetPromptResult, crate::ErrorData>> | ||
| ) -> MaybeBoxFuture<'_, Result<GetPromptResult, crate::ErrorData>> | ||
| { | ||
@@ -328,7 +341,7 @@ $( | ||
| Ok(value) => value, | ||
| Err(e) => return std::future::ready(Err(e)).boxed(), | ||
| Err(e) => return Box::pin(std::future::ready(Err(e))), | ||
| }; | ||
| )* | ||
| let result = self($($Tn,)*); | ||
| std::future::ready(result.and_then(|r| r.into_get_prompt_result())).boxed() | ||
| Box::pin(std::future::ready(result.and_then(|r| r.into_get_prompt_result()))) | ||
| } | ||
@@ -335,0 +348,0 @@ } |
@@ -16,2 +16,3 @@ use std::sync::Arc; | ||
| #[non_exhaustive] | ||
| pub struct Router<S> { | ||
@@ -18,0 +19,0 @@ pub tool_router: tool::ToolRouter<S>, |
| use std::{borrow::Cow, sync::Arc}; | ||
| use futures::future::BoxFuture; | ||
| use crate::{ | ||
| handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext}, | ||
| model::{GetPromptResult, Prompt}, | ||
| service::{MaybeBoxFuture, MaybeSend}, | ||
| }; | ||
| #[non_exhaustive] | ||
| pub struct PromptRoute<S> { | ||
@@ -35,6 +35,6 @@ #[allow(clippy::type_complexity)] | ||
| impl<S: Send + Sync + 'static> PromptRoute<S> { | ||
| impl<S: MaybeSend + 'static> PromptRoute<S> { | ||
| pub fn new<H, A: 'static>(attr: impl Into<Prompt>, handler: H) -> Self | ||
| where | ||
| H: GetPromptHandler<S, A> + Send + Sync + Clone + 'static, | ||
| H: GetPromptHandler<S, A> + MaybeSend + Clone + 'static, | ||
| { | ||
@@ -54,5 +54,4 @@ Self { | ||
| PromptContext<'a, S>, | ||
| ) -> BoxFuture<'a, Result<GetPromptResult, crate::ErrorData>> | ||
| + Send | ||
| + Sync | ||
| ) -> MaybeBoxFuture<'a, Result<GetPromptResult, crate::ErrorData>> | ||
| + MaybeSend | ||
| + 'static, | ||
@@ -77,5 +76,5 @@ { | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| A: 'static, | ||
| H: GetPromptHandler<S, A> + Send + Sync + Clone + 'static, | ||
| H: GetPromptHandler<S, A> + MaybeSend + Clone + 'static, | ||
| P: Into<Prompt>, | ||
@@ -90,3 +89,3 @@ { | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -99,2 +98,3 @@ fn into_prompt_route(self) -> PromptRoute<S> { | ||
| /// Adapter for functions generated by the #\[prompt\] macro | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct PromptAttrGenerateFunctionAdapter; | ||
@@ -104,3 +104,3 @@ | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| F: Fn() -> PromptRoute<S>, | ||
@@ -114,2 +114,3 @@ { | ||
| #[derive(Debug)] | ||
| #[non_exhaustive] | ||
| pub struct PromptRouter<S> { | ||
@@ -147,3 +148,3 @@ #[allow(clippy::type_complexity)] | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -206,3 +207,3 @@ pub fn new() -> Self { | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -219,3 +220,3 @@ type Output = Self; | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -222,0 +223,0 @@ fn add_assign(&mut self, other: PromptRouter<S>) { |
@@ -127,3 +127,2 @@ //! Tools for MCP servers. | ||
| use futures::{FutureExt, future::BoxFuture}; | ||
| use schemars::JsonSchema; | ||
@@ -138,4 +137,6 @@ pub use tool_traits::{AsyncTool, SyncTool, ToolBase}; | ||
| model::{CallToolResult, Tool, ToolAnnotations}, | ||
| service::{MaybeBoxFuture, MaybeSend}, | ||
| }; | ||
| #[non_exhaustive] | ||
| pub struct ToolRoute<S> { | ||
@@ -166,6 +167,6 @@ #[allow(clippy::type_complexity)] | ||
| impl<S: Send + Sync + 'static> ToolRoute<S> { | ||
| impl<S: MaybeSend + 'static> ToolRoute<S> { | ||
| pub fn new<C, A>(attr: impl Into<Tool>, call: C) -> Self | ||
| where | ||
| C: CallToolHandler<S, A> + Send + Sync + Clone + 'static, | ||
| C: CallToolHandler<S, A> + MaybeSend + Clone + 'static, | ||
| { | ||
@@ -175,3 +176,3 @@ Self { | ||
| let call = call.clone(); | ||
| context.invoke(call).boxed() | ||
| context.invoke(call) | ||
| }), | ||
@@ -185,5 +186,4 @@ attr: attr.into(), | ||
| ToolCallContext<'a, S>, | ||
| ) -> BoxFuture<'a, Result<CallToolResult, crate::ErrorData>> | ||
| + Send | ||
| + Sync | ||
| ) -> MaybeBoxFuture<'a, Result<CallToolResult, crate::ErrorData>> | ||
| + MaybeSend | ||
| + 'static, | ||
@@ -207,4 +207,4 @@ { | ||
| where | ||
| S: Send + Sync + 'static, | ||
| C: CallToolHandler<S, A> + Send + Sync + Clone + 'static, | ||
| S: MaybeSend + 'static, | ||
| C: CallToolHandler<S, A> + MaybeSend + Clone + 'static, | ||
| T: Into<Tool>, | ||
@@ -219,3 +219,3 @@ { | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -227,6 +227,7 @@ fn into_tool_route(self) -> ToolRoute<S> { | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct ToolAttrGenerateFunctionAdapter; | ||
| impl<S, F> IntoToolRoute<S, ToolAttrGenerateFunctionAdapter> for F | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| F: Fn() -> ToolRoute<S>, | ||
@@ -241,3 +242,3 @@ { | ||
| where | ||
| Self: CallToolHandler<S, A> + Send + Sync + Clone + 'static, | ||
| Self: CallToolHandler<S, A> + MaybeSend + Clone + 'static, | ||
| { | ||
@@ -249,3 +250,3 @@ fn name(self, name: impl Into<Cow<'static, str>>) -> WithToolAttr<Self, S, A>; | ||
| where | ||
| C: CallToolHandler<S, A> + Send + Sync + Clone + 'static, | ||
| C: CallToolHandler<S, A> + MaybeSend + Clone + 'static, | ||
| { | ||
@@ -265,5 +266,6 @@ fn name(self, name: impl Into<Cow<'static, str>>) -> WithToolAttr<Self, S, A> { | ||
| #[non_exhaustive] | ||
| pub struct WithToolAttr<C, S, A> | ||
| where | ||
| C: CallToolHandler<S, A> + Send + Sync + Clone + 'static, | ||
| C: CallToolHandler<S, A> + MaybeSend + Clone + 'static, | ||
| { | ||
@@ -277,4 +279,4 @@ pub attr: crate::model::Tool, | ||
| where | ||
| C: CallToolHandler<S, A> + Send + Sync + Clone + 'static, | ||
| S: Send + Sync + 'static, | ||
| C: CallToolHandler<S, A> + MaybeSend + Clone + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -288,3 +290,3 @@ fn into_tool_route(self) -> ToolRoute<S> { | ||
| where | ||
| C: CallToolHandler<S, A> + Send + Sync + Clone + 'static, | ||
| C: CallToolHandler<S, A> + MaybeSend + Clone + 'static, | ||
| { | ||
@@ -309,2 +311,3 @@ pub fn description(mut self, description: impl Into<Cow<'static, str>>) -> Self { | ||
| #[derive(Debug)] | ||
| #[non_exhaustive] | ||
| pub struct ToolRouter<S> { | ||
@@ -345,3 +348,3 @@ #[allow(clippy::type_complexity)] | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -446,3 +449,3 @@ pub fn new() -> Self { | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -459,3 +462,3 @@ type Output = Self; | ||
| where | ||
| S: Send + Sync + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -462,0 +465,0 @@ fn add_assign(&mut self, other: ToolRouter<S>) { |
@@ -1,2 +0,2 @@ | ||
| use std::{borrow::Cow, pin::Pin, sync::Arc}; | ||
| use std::{borrow::Cow, future::Future, sync::Arc}; | ||
@@ -14,2 +14,3 @@ use serde::{Deserialize, Serialize}; | ||
| schemars::JsonSchema, | ||
| service::{MaybeSend, MaybeSendFuture}, | ||
| }; | ||
@@ -88,3 +89,4 @@ | ||
| /// Examples are shown in [the module-level documentation][crate::handler::server::router::tool]. | ||
| pub trait SyncTool<S: Sync + Send + 'static>: ToolBase { | ||
| #[allow(private_bounds)] | ||
| pub trait SyncTool<S: MaybeSend + 'static>: ToolBase { | ||
| fn invoke(service: &S, param: Self::Parameter) -> Result<Self::Output, Self::Error>; | ||
@@ -97,7 +99,8 @@ } | ||
| /// Examples are shown in [the module-level documentation][crate::handler::server::router::tool]. | ||
| pub trait AsyncTool<S: Sync + Send + 'static>: ToolBase { | ||
| #[allow(private_bounds)] | ||
| pub trait AsyncTool<S: MaybeSend + 'static>: ToolBase { | ||
| fn invoke( | ||
| service: &S, | ||
| param: Self::Parameter, | ||
| ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send; | ||
| ) -> impl Future<Output = Result<Self::Output, Self::Error>> + MaybeSendFuture; | ||
| } | ||
@@ -119,3 +122,3 @@ | ||
| pub(crate) fn sync_tool_wrapper<S: Sync + Send + 'static, T: SyncTool<S>>( | ||
| pub(crate) fn sync_tool_wrapper<S: MaybeSend + 'static, T: SyncTool<S>>( | ||
| service: &S, | ||
@@ -127,3 +130,3 @@ Parameters(params): Parameters<T::Parameter>, | ||
| pub(crate) fn sync_tool_wrapper_with_empty_params<S: Sync + Send + 'static, T: SyncTool<S>>( | ||
| pub(crate) fn sync_tool_wrapper_with_empty_params<S: MaybeSend + 'static, T: SyncTool<S>>( | ||
| service: &S, | ||
@@ -136,7 +139,6 @@ ) -> Result<Json<T::Output>, ErrorData> { | ||
| #[expect(clippy::type_complexity)] | ||
| pub(crate) fn async_tool_wrapper<S: Sync + Send + 'static, T: AsyncTool<S>>( | ||
| pub(crate) fn async_tool_wrapper<S: MaybeSend + 'static, T: AsyncTool<S>>( | ||
| service: &S, | ||
| Parameters(params): Parameters<T::Parameter>, | ||
| ) -> Pin<Box<dyn Future<Output = Result<Json<T::Output>, ErrorData>> + Send + '_>> { | ||
| ) -> crate::service::MaybeBoxFuture<'_, Result<Json<T::Output>, ErrorData>> { | ||
| Box::pin(async move { | ||
@@ -150,6 +152,5 @@ T::invoke(service, params) | ||
| #[expect(clippy::type_complexity)] | ||
| pub(crate) fn async_tool_wrapper_with_empty_params<S: Sync + Send + 'static, T: AsyncTool<S>>( | ||
| pub(crate) fn async_tool_wrapper_with_empty_params<S: MaybeSend + 'static, T: AsyncTool<S>>( | ||
| service: &S, | ||
| ) -> Pin<Box<dyn Future<Output = Result<Json<T::Output>, ErrorData>> + Send + '_>> { | ||
| ) -> crate::service::MaybeBoxFuture<'_, Result<Json<T::Output>, ErrorData>> { | ||
| Box::pin(async move { | ||
@@ -156,0 +157,0 @@ T::invoke(service, T::Parameter::default()) |
@@ -7,3 +7,4 @@ use std::{ | ||
| use futures::future::{BoxFuture, FutureExt}; | ||
| #[cfg(not(feature = "local"))] | ||
| use futures::future::BoxFuture; | ||
| use serde::de::DeserializeOwned; | ||
@@ -20,3 +21,3 @@ | ||
| model::{CallToolRequestParams, CallToolResult, IntoContents, JsonObject}, | ||
| service::RequestContext, | ||
| service::{MaybeBoxFuture, MaybeSend, MaybeSendFuture, RequestContext}, | ||
| }; | ||
@@ -33,2 +34,3 @@ | ||
| } | ||
| #[non_exhaustive] | ||
| pub struct ToolCallContext<'s, S> { | ||
@@ -109,2 +111,3 @@ pub request_context: RequestContext<RoleServer>, | ||
| #[project = IntoCallToolResultFutProj] | ||
| #[non_exhaustive] | ||
| pub enum IntoCallToolResultFut<F, R> { | ||
@@ -153,5 +156,6 @@ Pending { | ||
| context: ToolCallContext<'_, S>, | ||
| ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>; | ||
| ) -> MaybeBoxFuture<'_, Result<CallToolResult, crate::ErrorData>>; | ||
| } | ||
| #[cfg(not(feature = "local"))] | ||
| pub type DynCallToolHandler<S> = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>> | ||
@@ -161,3 +165,11 @@ + Send | ||
| #[cfg(feature = "local")] | ||
| pub type DynCallToolHandler<S> = | ||
| dyn for<'s> Fn( | ||
| ToolCallContext<'s, S>, | ||
| ) | ||
| -> futures::future::LocalBoxFuture<'s, Result<CallToolResult, crate::ErrorData>>; | ||
| // Tool-specific extractor for tool name | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct ToolName(pub Cow<'static, str>); | ||
@@ -198,3 +210,3 @@ | ||
| impl<'s, S> ToolCallContext<'s, S> { | ||
| pub fn invoke<H, A>(self, h: H) -> BoxFuture<'s, Result<CallToolResult, crate::ErrorData>> | ||
| pub fn invoke<H, A>(self, h: H) -> MaybeBoxFuture<'s, Result<CallToolResult, crate::ErrorData>> | ||
| where | ||
@@ -231,8 +243,8 @@ H: CallToolHandler<S, A>, | ||
| )* | ||
| F: FnOnce(&S, $($Tn,)*) -> BoxFuture<'_, R>, | ||
| F: FnOnce(&S, $($Tn,)*) -> MaybeBoxFuture<'_, R>, | ||
| // Need RTN support here(I guess), https://github.com/rust-lang/rust/pull/138424 | ||
| // Fut: Future<Output = R> + Send + 'a, | ||
| R: IntoCallToolResult + Send + 'static, | ||
| S: Send + Sync + 'static, | ||
| R: IntoCallToolResult + MaybeSendFuture + 'static, | ||
| S: MaybeSend + 'static, | ||
| { | ||
@@ -243,3 +255,3 @@ #[allow(unused_variables, non_snake_case, unused_mut)] | ||
| mut context: ToolCallContext<'_, S>, | ||
| ) -> BoxFuture<'_, Result<CallToolResult, crate::ErrorData>>{ | ||
| ) -> MaybeBoxFuture<'_, Result<CallToolResult, crate::ErrorData>>{ | ||
| $( | ||
@@ -249,3 +261,3 @@ let result = $Tn::from_context_part(&mut context); | ||
| Ok(value) => value, | ||
| Err(e) => return std::future::ready(Err(e)).boxed(), | ||
| Err(e) => return Box::pin(std::future::ready(Err(e))), | ||
| }; | ||
@@ -255,6 +267,6 @@ )* | ||
| let fut = self(service, $($Tn,)*); | ||
| async move { | ||
| Box::pin(async move { | ||
| let result = fut.await; | ||
| result.into_call_tool_result() | ||
| }.boxed() | ||
| }) | ||
| } | ||
@@ -268,6 +280,6 @@ } | ||
| )* | ||
| F: FnOnce($($Tn,)*) -> Fut + Send + , | ||
| Fut: Future<Output = R> + Send + 'static, | ||
| R: IntoCallToolResult + Send + 'static, | ||
| S: Send + Sync, | ||
| F: FnOnce($($Tn,)*) -> Fut + MaybeSendFuture, | ||
| Fut: Future<Output = R> + MaybeSendFuture + 'static, | ||
| R: IntoCallToolResult + MaybeSendFuture + 'static, | ||
| S: MaybeSend, | ||
| { | ||
@@ -278,3 +290,3 @@ #[allow(unused_variables, non_snake_case, unused_mut)] | ||
| mut context: ToolCallContext<S>, | ||
| ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>>{ | ||
| ) -> MaybeBoxFuture<'static, Result<CallToolResult, crate::ErrorData>>{ | ||
| $( | ||
@@ -284,10 +296,10 @@ let result = $Tn::from_context_part(&mut context); | ||
| Ok(value) => value, | ||
| Err(e) => return std::future::ready(Err(e)).boxed(), | ||
| Err(e) => return Box::pin(std::future::ready(Err(e))), | ||
| }; | ||
| )* | ||
| let fut = self($($Tn,)*); | ||
| async move { | ||
| Box::pin(async move { | ||
| let result = fut.await; | ||
| result.into_call_tool_result() | ||
| }.boxed() | ||
| }) | ||
| } | ||
@@ -301,5 +313,5 @@ } | ||
| )* | ||
| F: FnOnce(&S, $($Tn,)*) -> R + Send + , | ||
| R: IntoCallToolResult + Send + , | ||
| S: Send + Sync, | ||
| F: FnOnce(&S, $($Tn,)*) -> R + MaybeSendFuture, | ||
| R: IntoCallToolResult + MaybeSendFuture, | ||
| S: MaybeSend, | ||
| { | ||
@@ -310,3 +322,3 @@ #[allow(unused_variables, non_snake_case, unused_mut)] | ||
| mut context: ToolCallContext<S>, | ||
| ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>> { | ||
| ) -> MaybeBoxFuture<'static, Result<CallToolResult, crate::ErrorData>> { | ||
| $( | ||
@@ -316,6 +328,6 @@ let result = $Tn::from_context_part(&mut context); | ||
| Ok(value) => value, | ||
| Err(e) => return std::future::ready(Err(e)).boxed(), | ||
| Err(e) => return Box::pin(std::future::ready(Err(e))), | ||
| }; | ||
| )* | ||
| std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result()).boxed() | ||
| Box::pin(std::future::ready(self(context.service, $($Tn,)*).into_call_tool_result())) | ||
| } | ||
@@ -329,5 +341,5 @@ } | ||
| )* | ||
| F: FnOnce($($Tn,)*) -> R + Send + , | ||
| R: IntoCallToolResult + Send + , | ||
| S: Send + Sync, | ||
| F: FnOnce($($Tn,)*) -> R + MaybeSendFuture, | ||
| R: IntoCallToolResult + MaybeSendFuture, | ||
| S: MaybeSend, | ||
| { | ||
@@ -338,3 +350,3 @@ #[allow(unused_variables, non_snake_case, unused_mut)] | ||
| mut context: ToolCallContext<S>, | ||
| ) -> BoxFuture<'static, Result<CallToolResult, crate::ErrorData>> { | ||
| ) -> MaybeBoxFuture<'static, Result<CallToolResult, crate::ErrorData>> { | ||
| $( | ||
@@ -344,6 +356,6 @@ let result = $Tn::from_context_part(&mut context); | ||
| Ok(value) => value, | ||
| Err(e) => return std::future::ready(Err(e)).boxed(), | ||
| Err(e) => return Box::pin(std::future::ready(Err(e))), | ||
| }; | ||
| )* | ||
| std::future::ready(self($($Tn,)*).into_call_tool_result()).boxed() | ||
| Box::pin(std::future::ready(self($($Tn,)*).into_call_tool_result())) | ||
| } | ||
@@ -350,0 +362,0 @@ } |
@@ -17,2 +17,3 @@ use std::borrow::Cow; | ||
| /// of the tool result rather than the regular `content` field. | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct Json<T>(pub T); | ||
@@ -19,0 +20,0 @@ |
@@ -45,2 +45,3 @@ use schemars::JsonSchema; | ||
| #[serde(transparent)] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct Parameters<P>(pub P); | ||
@@ -47,0 +48,0 @@ |
@@ -42,2 +42,3 @@ use std::ops::{Deref, DerefMut}; | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct Annotated<T: AnnotateAble> { | ||
@@ -44,0 +45,0 @@ #[serde(flatten)] |
@@ -37,2 +37,3 @@ use std::collections::BTreeMap; | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct PromptsCapability { | ||
@@ -46,2 +47,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct ResourcesCapability { | ||
@@ -57,2 +59,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct ToolsCapability { | ||
@@ -66,2 +69,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RootsCapabilities { | ||
@@ -76,2 +80,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct TasksCapability { | ||
@@ -90,2 +95,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct TaskRequestsCapability { | ||
@@ -103,2 +109,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct SamplingTaskCapability { | ||
@@ -112,2 +119,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct ElicitationTaskCapability { | ||
@@ -121,2 +129,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct ToolsTaskCapability { | ||
@@ -202,2 +211,3 @@ #[serde(skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct FormElicitationCapability { | ||
@@ -214,2 +224,3 @@ /// Whether the client supports JSON Schema validation for elicitation responses. | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct UrlElicitationCapability {} | ||
@@ -223,2 +234,3 @@ | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct ElicitationCapability { | ||
@@ -237,2 +249,3 @@ /// Whether client supports form-based elicitation. | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct SamplingCapability { | ||
@@ -326,2 +339,3 @@ /// Support for `tools` and `toolChoice` parameters | ||
| #[derive(Default, Clone, Copy, Debug)] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct [<$Target BuilderState>]< | ||
@@ -331,2 +345,3 @@ $(const [<$f:upper>]: bool = false,)* | ||
| #[derive(Debug, Default)] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct [<$Target Builder>]<S = [<$Target BuilderState>]> { | ||
@@ -333,0 +348,0 @@ $(pub $f: Option<$T>,)* |
@@ -12,2 +12,3 @@ //! Content sent around agents, extensions, and LLMs | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RawTextContent { | ||
@@ -23,2 +24,3 @@ pub text: String, | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RawImageContent { | ||
@@ -37,2 +39,3 @@ /// The base64-encoded image | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RawEmbeddedResource { | ||
@@ -69,2 +72,3 @@ /// Optional protocol-level metadata for this content block | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RawAudioContent { | ||
@@ -152,2 +156,3 @@ pub data: String, | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_enums, reason = "intentionally exhaustive")] | ||
| pub enum RawContent { | ||
@@ -154,0 +159,0 @@ Text(RawTextContent), |
@@ -198,2 +198,3 @@ use std::ops::{Deref, DerefMut}; | ||
| #[serde(transparent)] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct Meta(pub JsonObject); | ||
@@ -200,0 +201,0 @@ const PROGRESS_TOKEN_FIELD: &str = "progressToken"; |
@@ -141,2 +141,3 @@ use serde::{Deserialize, Serialize}; | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_enums, reason = "intentionally exhaustive")] | ||
| pub enum PromptMessageRole { | ||
@@ -151,2 +152,3 @@ User, | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_enums, reason = "intentionally exhaustive")] | ||
| pub enum PromptMessageContent { | ||
@@ -153,0 +155,0 @@ /// Plain text content |
@@ -9,2 +9,3 @@ use serde::{Deserialize, Serialize}; | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RawResource { | ||
@@ -43,2 +44,3 @@ /// URI representing the resource location (e.g., "file:///path/to/file" or "str:///content") | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RawResourceTemplate { | ||
@@ -63,2 +65,3 @@ pub uri_template: String, | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_enums, reason = "intentionally exhaustive")] | ||
| pub enum ResourceContents { | ||
@@ -220,2 +223,3 @@ #[serde(rename_all = "camelCase")] | ||
| use super::*; | ||
| use crate::model::IconTheme; | ||
@@ -272,2 +276,3 @@ #[test] | ||
| sizes: Some(vec!["48x48".to_string()]), | ||
| theme: Some(IconTheme::Light), | ||
| }]), | ||
@@ -280,2 +285,3 @@ }; | ||
| assert_eq!(json["icons"][0]["sizes"][0], "48x48"); | ||
| assert_eq!(json["icons"][0]["theme"], "light"); | ||
| } | ||
@@ -282,0 +288,0 @@ |
+24
-1
@@ -10,2 +10,3 @@ use serde::{Deserialize, Serialize}; | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_enums, reason = "intentionally exhaustive")] | ||
| pub enum TaskStatus { | ||
@@ -114,2 +115,3 @@ /// The receiver accepted the request and is currently working on it. | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct GetTaskResult { | ||
@@ -128,3 +130,3 @@ #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] | ||
| /// serialized as a JSON value. | ||
| #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] | ||
| #[derive(Debug, Clone, PartialEq, Serialize)] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
@@ -141,2 +143,21 @@ #[non_exhaustive] | ||
| // Custom Deserialize that always fails, so that `GetTaskPayloadResult` is skipped | ||
| // during `#[serde(untagged)]` enum deserialization (e.g. `ServerResult`). | ||
| // The payload has the same JSON shape as `CustomResult(Value)`, so they are | ||
| // indistinguishable. `CustomResult` acts as the catch-all instead. | ||
| // `GetTaskPayloadResult` should be constructed programmatically via `::new()`. | ||
| impl<'de> serde::Deserialize<'de> for GetTaskPayloadResult { | ||
| fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
| where | ||
| D: serde::Deserializer<'de>, | ||
| { | ||
| // Consume the value so the deserializer state stays consistent. | ||
| serde::de::IgnoredAny::deserialize(deserializer)?; | ||
| Err(serde::de::Error::custom( | ||
| "GetTaskPayloadResult cannot be deserialized directly; \ | ||
| use CustomResult as the catch-all", | ||
| )) | ||
| } | ||
| } | ||
| /// Response to a `tasks/cancel` request. | ||
@@ -148,2 +169,3 @@ /// | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct CancelTaskResult { | ||
@@ -160,2 +182,3 @@ #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")] | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct TaskList { | ||
@@ -162,0 +185,0 @@ pub tasks: Vec<Task>, |
@@ -54,2 +54,3 @@ use std::{borrow::Cow, sync::Arc}; | ||
| #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] | ||
| #[expect(clippy::exhaustive_enums, reason = "intentionally exhaustive")] | ||
| pub enum TaskSupport { | ||
@@ -56,0 +57,0 @@ /// Clients MUST NOT invoke this tool as a task (default behavior). |
+153
-15
@@ -1,4 +0,45 @@ | ||
| use futures::{FutureExt, future::BoxFuture}; | ||
| use futures::FutureExt; | ||
| #[cfg(not(feature = "local"))] | ||
| use futures::future::BoxFuture; | ||
| #[cfg(feature = "local")] | ||
| use futures::future::LocalBoxFuture; | ||
| use thiserror::Error; | ||
| // --------------------------------------------------------------------------- | ||
| // Conditional Send helpers | ||
| // | ||
| // `MaybeSend` – supertrait alias: `Send + Sync` without `local`, empty with `local` | ||
| // `MaybeSendFuture` – future bound alias: `Send` without `local`, empty with `local` | ||
| // `MaybeBoxFuture` – boxed future type: `BoxFuture` without `local`, `LocalBoxFuture` with `local` | ||
| // --------------------------------------------------------------------------- | ||
| #[cfg(not(feature = "local"))] | ||
| #[doc(hidden)] | ||
| pub trait MaybeSend: Send + Sync {} | ||
| #[cfg(not(feature = "local"))] | ||
| impl<T: Send + Sync> MaybeSend for T {} | ||
| #[cfg(feature = "local")] | ||
| #[doc(hidden)] | ||
| pub trait MaybeSend {} | ||
| #[cfg(feature = "local")] | ||
| impl<T> MaybeSend for T {} | ||
| #[cfg(not(feature = "local"))] | ||
| #[doc(hidden)] | ||
| pub trait MaybeSendFuture: Send {} | ||
| #[cfg(not(feature = "local"))] | ||
| impl<T: Send> MaybeSendFuture for T {} | ||
| #[cfg(feature = "local")] | ||
| #[doc(hidden)] | ||
| pub trait MaybeSendFuture {} | ||
| #[cfg(feature = "local")] | ||
| impl<T> MaybeSendFuture for T {} | ||
| #[cfg(not(feature = "local"))] | ||
| pub(crate) type MaybeBoxFuture<'a, T> = BoxFuture<'a, T>; | ||
| #[cfg(feature = "local")] | ||
| pub(crate) type MaybeBoxFuture<'a, T> = LocalBoxFuture<'a, T>; | ||
| #[cfg(feature = "server")] | ||
@@ -90,2 +131,3 @@ use crate::model::ServerJsonRpcMessage; | ||
| #[cfg(not(feature = "local"))] | ||
| pub trait Service<R: ServiceRole>: Send + Sync + 'static { | ||
@@ -96,3 +138,3 @@ fn handle_request( | ||
| context: RequestContext<R>, | ||
| ) -> impl Future<Output = Result<R::Resp, McpError>> + Send + '_; | ||
| ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_; | ||
| fn handle_notification( | ||
@@ -102,6 +144,21 @@ &self, | ||
| context: NotificationContext<R>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_; | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_; | ||
| fn get_info(&self) -> R::Info; | ||
| } | ||
| #[cfg(feature = "local")] | ||
| pub trait Service<R: ServiceRole>: 'static { | ||
| fn handle_request( | ||
| &self, | ||
| request: R::PeerReq, | ||
| context: RequestContext<R>, | ||
| ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_; | ||
| fn handle_notification( | ||
| &self, | ||
| notification: R::PeerNot, | ||
| context: NotificationContext<R>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_; | ||
| fn get_info(&self) -> R::Info; | ||
| } | ||
| pub trait ServiceExt<R: ServiceRole>: Service<R> + Sized { | ||
@@ -117,3 +174,3 @@ /// Convert this service to a dynamic boxed service | ||
| transport: T, | ||
| ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + Send | ||
| ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture | ||
| where | ||
@@ -130,3 +187,3 @@ T: IntoTransport<R, E, A>, | ||
| ct: CancellationToken, | ||
| ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + Send | ||
| ) -> impl Future<Output = Result<RunningService<R, Self>, R::InitializeError>> + MaybeSendFuture | ||
| where | ||
@@ -143,3 +200,3 @@ T: IntoTransport<R, E, A>, | ||
| context: RequestContext<R>, | ||
| ) -> impl Future<Output = Result<R::Resp, McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<R::Resp, McpError>> + MaybeSendFuture + '_ { | ||
| DynService::handle_request(self.as_ref(), request, context) | ||
@@ -152,3 +209,3 @@ } | ||
| context: NotificationContext<R>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| DynService::handle_notification(self.as_ref(), notification, context) | ||
@@ -162,2 +219,3 @@ } | ||
| #[cfg(not(feature = "local"))] | ||
| pub trait DynService<R: ServiceRole>: Send + Sync { | ||
@@ -168,3 +226,3 @@ fn handle_request( | ||
| context: RequestContext<R>, | ||
| ) -> BoxFuture<'_, Result<R::Resp, McpError>>; | ||
| ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>; | ||
| fn handle_notification( | ||
@@ -174,6 +232,21 @@ &self, | ||
| context: NotificationContext<R>, | ||
| ) -> BoxFuture<'_, Result<(), McpError>>; | ||
| ) -> MaybeBoxFuture<'_, Result<(), McpError>>; | ||
| fn get_info(&self) -> R::Info; | ||
| } | ||
| #[cfg(feature = "local")] | ||
| pub trait DynService<R: ServiceRole> { | ||
| fn handle_request( | ||
| &self, | ||
| request: R::PeerReq, | ||
| context: RequestContext<R>, | ||
| ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>>; | ||
| fn handle_notification( | ||
| &self, | ||
| notification: R::PeerNot, | ||
| context: NotificationContext<R>, | ||
| ) -> MaybeBoxFuture<'_, Result<(), McpError>>; | ||
| fn get_info(&self) -> R::Info; | ||
| } | ||
| impl<R: ServiceRole, S: Service<R>> DynService<R> for S { | ||
@@ -184,3 +257,3 @@ fn handle_request( | ||
| context: RequestContext<R>, | ||
| ) -> BoxFuture<'_, Result<R::Resp, McpError>> { | ||
| ) -> MaybeBoxFuture<'_, Result<R::Resp, McpError>> { | ||
| Box::pin(self.handle_request(request, context)) | ||
@@ -192,3 +265,3 @@ } | ||
| context: NotificationContext<R>, | ||
| ) -> BoxFuture<'_, Result<(), McpError>> { | ||
| ) -> MaybeBoxFuture<'_, Result<(), McpError>> { | ||
| Box::pin(self.handle_notification(notification, context)) | ||
@@ -249,2 +322,3 @@ } | ||
| #[derive(Debug)] | ||
| #[non_exhaustive] | ||
| pub struct RequestHandle<R: ServiceRole> { | ||
@@ -341,2 +415,3 @@ pub rx: tokio::sync::oneshot::Receiver<Result<R::PeerResp, ServiceError>>, | ||
| #[derive(Debug, Default)] | ||
| #[non_exhaustive] | ||
| pub struct PeerRequestOptions { | ||
@@ -592,2 +667,3 @@ pub timeout: Option<Duration>, | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct RequestContext<R: ServiceRole> { | ||
@@ -618,2 +694,3 @@ /// this token will be cancelled when the [`CancelledNotification`] is received. | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct NotificationContext<R: ServiceRole> { | ||
@@ -658,2 +735,24 @@ pub meta: Meta, | ||
| /// Spawn a task that may hold `!Send` state when the `local` feature is active. | ||
| /// | ||
| /// Without the `local` feature this is `tokio::spawn` (requires `Future: Send + 'static`). | ||
| /// With `local` it uses `tokio::task::spawn_local` (requires only `Future: 'static`). | ||
| #[cfg(not(feature = "local"))] | ||
| fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output> | ||
| where | ||
| F: Future + Send + 'static, | ||
| F::Output: Send + 'static, | ||
| { | ||
| tokio::spawn(future) | ||
| } | ||
| #[cfg(feature = "local")] | ||
| fn spawn_service_task<F>(future: F) -> tokio::task::JoinHandle<F::Output> | ||
| where | ||
| F: Future + 'static, | ||
| F::Output: 'static, | ||
| { | ||
| tokio::task::spawn_local(future) | ||
| } | ||
| #[instrument(skip_all)] | ||
@@ -694,6 +793,7 @@ fn serve_inner<R, S, T>( | ||
| let current_span = tracing::Span::current(); | ||
| let handle = tokio::spawn(async move { | ||
| let handle = spawn_service_task(async move { | ||
| let mut transport = transport.into_transport(); | ||
| let mut batch_messages = VecDeque::<RxJsonRpcMessage<R>>::new(); | ||
| let mut send_task_set = tokio::task::JoinSet::<SendTaskResult>::new(); | ||
| let mut response_send_tasks = tokio::task::JoinSet::<()>::new(); | ||
| #[derive(Debug)] | ||
@@ -810,3 +910,3 @@ enum SendTaskResult { | ||
| let current_span = tracing::Span::current(); | ||
| tokio::spawn(async move { | ||
| response_send_tasks.spawn(async move { | ||
| let send_result = send.await; | ||
@@ -882,3 +982,3 @@ if let Err(error) = send_result { | ||
| let current_span = tracing::Span::current(); | ||
| tokio::spawn(async move { | ||
| spawn_service_task(async move { | ||
| let result = service | ||
@@ -930,3 +1030,3 @@ .handle_request(request, context) | ||
| let current_span = tracing::Span::current(); | ||
| tokio::spawn(async move { | ||
| spawn_service_task(async move { | ||
| let result = service.handle_notification(notification, context).await; | ||
@@ -961,2 +1061,40 @@ if let Err(error) = result { | ||
| }; | ||
| // Drain in-flight handler responses before closing the transport. | ||
| // When stdin EOF or cancellation arrives, spawned handler tasks may still | ||
| // be finishing. We need to: | ||
| // 1. Wait for response sends that were already spawned in the main loop | ||
| // 2. Drain any remaining handler responses from the channel | ||
| let drain_timeout = match &quit_reason { | ||
| QuitReason::Closed => Some(Duration::from_secs(5)), | ||
| QuitReason::Cancelled => Some(Duration::from_secs(2)), | ||
| _ => None, | ||
| }; | ||
| if let Some(timeout_duration) = drain_timeout { | ||
| // Drop our sender so the channel closes once all handler task | ||
| // clones finish sending their responses (or are dropped). | ||
| drop(sink_proxy_tx); | ||
| let drain_result = tokio::time::timeout(timeout_duration, async { | ||
| // First, wait for any response sends already dispatched by the | ||
| // main loop (these hold transport write futures). | ||
| while let Some(result) = response_send_tasks.join_next().await { | ||
| if let Err(error) = result { | ||
| tracing::error!(%error, "response send task failed during drain"); | ||
| } | ||
| } | ||
| // Then drain any handler responses still in the channel | ||
| // (handlers that finished after the loop broke). | ||
| while let Some(m) = sink_proxy_rx.recv().await { | ||
| if let Err(error) = transport.send(m).await { | ||
| tracing::error!(%error, "failed to send pending response during drain"); | ||
| break; | ||
| } | ||
| } | ||
| }) | ||
| .await; | ||
| if drain_result.is_err() { | ||
| tracing::warn!("timed out draining in-flight responses"); | ||
| } | ||
| } | ||
| let sink_close_result = transport.close().await; | ||
@@ -963,0 +1101,0 @@ if let Err(e) = sink_close_result { |
@@ -143,2 +143,3 @@ use std::borrow::Cow; | ||
| #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RoleClient; | ||
@@ -166,3 +167,4 @@ | ||
| ct: CancellationToken, | ||
| ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>> + Send | ||
| ) -> impl Future<Output = Result<RunningService<RoleClient, Self>, ClientInitializeError>> | ||
| + MaybeSendFuture | ||
| where | ||
@@ -169,0 +171,0 @@ T: IntoTransport<RoleClient, E, A>, |
@@ -30,2 +30,3 @@ use std::borrow::Cow; | ||
| #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] | ||
| #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] | ||
| pub struct RoleServer; | ||
@@ -99,3 +100,4 @@ | ||
| ct: CancellationToken, | ||
| ) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError>> + Send | ||
| ) -> impl Future<Output = Result<RunningService<RoleServer, Self>, ServerInitializeError>> | ||
| + MaybeSendFuture | ||
| where | ||
@@ -575,2 +577,3 @@ T: IntoTransport<RoleServer, E, A>, | ||
| #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] | ||
| #[non_exhaustive] | ||
| pub enum ElicitationMode { | ||
@@ -577,0 +580,0 @@ Form, |
@@ -6,3 +6,3 @@ use std::{future::poll_fn, marker::PhantomData}; | ||
| use super::NotificationContext; | ||
| use crate::service::{RequestContext, Service, ServiceRole}; | ||
| use crate::service::{MaybeSendFuture, RequestContext, Service, ServiceRole}; | ||
@@ -48,3 +48,3 @@ pub struct TowerHandler<S, R: ServiceRole> { | ||
| _context: NotificationContext<R>, | ||
| ) -> impl Future<Output = Result<(), crate::ErrorData>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), crate::ErrorData>> + MaybeSendFuture + '_ { | ||
| std::future::ready(Ok(())) | ||
@@ -51,0 +51,0 @@ } |
@@ -22,2 +22,3 @@ use std::{any::Any, collections::HashMap, pin::Pin}; | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct OperationDescriptor { | ||
@@ -59,2 +60,3 @@ pub operation_id: String, | ||
| /// Operation message describing a unit of asynchronous work. | ||
| #[non_exhaustive] | ||
| pub struct OperationMessage { | ||
@@ -96,2 +98,3 @@ pub descriptor: OperationDescriptor, | ||
| #[non_exhaustive] | ||
| pub struct TaskResult { | ||
@@ -98,0 +101,0 @@ pub descriptor: OperationDescriptor, |
+6
-2
@@ -10,3 +10,3 @@ //! # Transport | ||
| //! | std IO | [`child_process::TokioChildProcess`] | [`io::stdio`] | | ||
| //! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | [`streamable_http_server::StreamableHttpService`] | | ||
| //! | streamable http | [`streamable_http_client::StreamableHttpClientTransport`] | `streamable_http_server::StreamableHttpService` | | ||
| //! | ||
@@ -111,3 +111,3 @@ //!## Helper Transport Types | ||
| pub mod streamable_http_server; | ||
| #[cfg(feature = "transport-streamable-http-server")] | ||
| #[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))] | ||
| pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHttpService}; | ||
@@ -117,2 +117,4 @@ | ||
| pub mod streamable_http_client; | ||
| #[cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))] | ||
| pub use common::unix_socket::UnixSocketHttpClient; | ||
| #[cfg(feature = "transport-streamable-http-client")] | ||
@@ -157,2 +159,3 @@ pub use streamable_http_client::StreamableHttpClientTransport; | ||
| #[non_exhaustive] | ||
| pub enum TransportAdapterIdentity {} | ||
@@ -238,2 +241,3 @@ impl<R, T, E> IntoTransport<R, E, TransportAdapterIdentity> for T | ||
| #[error("Transport [{transport_name}] error: {error}")] | ||
| #[non_exhaustive] | ||
| pub struct DynamicTransportError { | ||
@@ -240,0 +244,0 @@ pub transport_name: Cow<'static, str>, |
@@ -19,2 +19,3 @@ use std::{marker::PhantomData, sync::Arc}; | ||
| #[non_exhaustive] | ||
| pub enum TransportAdapterAsyncRW {} | ||
@@ -33,2 +34,3 @@ | ||
| #[non_exhaustive] | ||
| pub enum TransportAdapterAsyncCombinedRW {} | ||
@@ -282,2 +284,3 @@ impl<Role, S> IntoTransport<Role, std::io::Error, TransportAdapterAsyncCombinedRW> for S | ||
| #[derive(Debug, Error)] | ||
| #[non_exhaustive] | ||
| pub enum JsonRpcMessageCodecError { | ||
@@ -284,0 +287,0 @@ #[error("max line length exceeded")] |
@@ -17,1 +17,4 @@ #[cfg(feature = "transport-streamable-http-server")] | ||
| pub mod auth; | ||
| #[cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))] | ||
| pub mod unix_socket; |
@@ -20,2 +20,3 @@ use std::{ | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct FixedInterval { | ||
@@ -51,2 +52,3 @@ pub max_times: Option<usize>, | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct ExponentialBackoff { | ||
@@ -82,2 +84,3 @@ pub max_times: Option<usize>, | ||
| #[derive(Debug, Clone, Copy, Default)] | ||
| #[non_exhaustive] | ||
| pub struct NeverRetry; | ||
@@ -175,2 +178,3 @@ | ||
| #[project = SseAutoReconnectStreamStateProj] | ||
| #[non_exhaustive] | ||
| pub enum SseAutoReconnectStreamState<F> { | ||
@@ -177,0 +181,0 @@ Connected { |
@@ -6,1 +6,120 @@ pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; | ||
| pub const JSON_MIME_TYPE: &str = "application/json"; | ||
| /// Reserved headers that must not be overridden by user-supplied custom headers. | ||
| /// `MCP-Protocol-Version` is in this list but is allowed through because the worker | ||
| /// injects it after initialization. | ||
| pub(crate) const RESERVED_HEADERS: &[&str] = &[ | ||
| "accept", | ||
| HEADER_SESSION_ID, | ||
| HEADER_MCP_PROTOCOL_VERSION, // allowed through by validate_custom_header; worker injects it post-init | ||
| HEADER_LAST_EVENT_ID, | ||
| ]; | ||
| /// Checks whether a custom header name is allowed. | ||
| /// Returns `Ok(())` if allowed, `Err(name)` if rejected as reserved. | ||
| /// `MCP-Protocol-Version` is reserved but allowed through (the worker injects it post-init). | ||
| #[cfg(feature = "client-side-sse")] | ||
| pub(crate) fn validate_custom_header(name: &http::HeaderName) -> Result<(), String> { | ||
| if RESERVED_HEADERS | ||
| .iter() | ||
| .any(|&r| name.as_str().eq_ignore_ascii_case(r)) | ||
| { | ||
| if name | ||
| .as_str() | ||
| .eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION) | ||
| { | ||
| return Ok(()); | ||
| } | ||
| return Err(name.to_string()); | ||
| } | ||
| Ok(()) | ||
| } | ||
| /// Extracts the `scope=` parameter from a `WWW-Authenticate` header value. | ||
| /// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms. | ||
| pub(crate) fn extract_scope_from_header(header: &str) -> Option<String> { | ||
| let header_lowercase = header.to_ascii_lowercase(); | ||
| let scope_key = "scope="; | ||
| if let Some(pos) = header_lowercase.find(scope_key) { | ||
| let start = pos + scope_key.len(); | ||
| let value_slice = &header[start..]; | ||
| if let Some(stripped) = value_slice.strip_prefix('"') { | ||
| if let Some(end_quote) = stripped.find('"') { | ||
| return Some(stripped[..end_quote].to_string()); | ||
| } | ||
| } else { | ||
| let end = value_slice | ||
| .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) | ||
| .unwrap_or(value_slice.len()); | ||
| if end > 0 { | ||
| return Some(value_slice[..end].to_string()); | ||
| } | ||
| } | ||
| } | ||
| None | ||
| } | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| #[test] | ||
| fn extract_scope_quoted() { | ||
| let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#; | ||
| assert_eq!( | ||
| extract_scope_from_header(header), | ||
| Some("files:read files:write".to_string()) | ||
| ); | ||
| } | ||
| #[test] | ||
| fn extract_scope_unquoted() { | ||
| let header = r#"Bearer scope=read:data, error="insufficient_scope""#; | ||
| assert_eq!( | ||
| extract_scope_from_header(header), | ||
| Some("read:data".to_string()) | ||
| ); | ||
| } | ||
| #[test] | ||
| fn extract_scope_missing() { | ||
| let header = r#"Bearer error="invalid_token""#; | ||
| assert_eq!(extract_scope_from_header(header), None); | ||
| } | ||
| #[test] | ||
| fn extract_scope_empty_header() { | ||
| assert_eq!(extract_scope_from_header("Bearer"), None); | ||
| } | ||
| #[cfg(feature = "client-side-sse")] | ||
| #[test] | ||
| fn validate_rejects_reserved_accept() { | ||
| let name = http::HeaderName::from_static("accept"); | ||
| assert!(validate_custom_header(&name).is_err()); | ||
| } | ||
| #[cfg(feature = "client-side-sse")] | ||
| #[test] | ||
| fn validate_rejects_reserved_session_id() { | ||
| let name = http::HeaderName::from_static("mcp-session-id"); | ||
| assert!(validate_custom_header(&name).is_err()); | ||
| } | ||
| #[cfg(feature = "client-side-sse")] | ||
| #[test] | ||
| fn validate_allows_mcp_protocol_version() { | ||
| let name = http::HeaderName::from_static("mcp-protocol-version"); | ||
| assert!(validate_custom_header(&name).is_ok()); | ||
| } | ||
| #[cfg(feature = "client-side-sse")] | ||
| #[test] | ||
| fn validate_allows_custom_header() { | ||
| let name = http::HeaderName::from_static("x-custom"); | ||
| assert!(validate_custom_header(&name).is_ok()); | ||
| } | ||
| } |
@@ -9,7 +9,7 @@ use std::{borrow::Cow, collections::HashMap, sync::Arc}; | ||
| use crate::{ | ||
| model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, | ||
| model::{ClientJsonRpcMessage, JsonRpcMessage, ServerJsonRpcMessage}, | ||
| transport::{ | ||
| common::http_header::{ | ||
| EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION, | ||
| HEADER_SESSION_ID, JSON_MIME_TYPE, | ||
| EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, | ||
| extract_scope_from_header, validate_custom_header, | ||
| }, | ||
@@ -26,14 +26,3 @@ streamable_http_client::*, | ||
| /// Reserved headers that must not be overridden by user-supplied custom headers. | ||
| /// `MCP-Protocol-Version` is in this list but is allowed through because the worker | ||
| /// injects it after initialization. | ||
| const RESERVED_HEADERS: &[&str] = &[ | ||
| "accept", | ||
| HEADER_SESSION_ID, | ||
| HEADER_MCP_PROTOCOL_VERSION, | ||
| HEADER_LAST_EVENT_ID, | ||
| ]; | ||
| /// Applies custom headers to a request builder, rejecting reserved headers | ||
| /// except `MCP-Protocol-Version` (which the worker injects after init). | ||
| /// Applies custom headers to a request builder, rejecting reserved headers. | ||
| fn apply_custom_headers( | ||
@@ -44,17 +33,3 @@ mut builder: reqwest::RequestBuilder, | ||
| for (name, value) in custom_headers { | ||
| if RESERVED_HEADERS | ||
| .iter() | ||
| .any(|&r| name.as_str().eq_ignore_ascii_case(r)) | ||
| { | ||
| if name | ||
| .as_str() | ||
| .eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION) | ||
| { | ||
| builder = builder.header(name, value); | ||
| continue; | ||
| } | ||
| return Err(StreamableHttpError::ReservedHeaderConflict( | ||
| name.to_string(), | ||
| )); | ||
| } | ||
| validate_custom_header(&name).map_err(StreamableHttpError::ReservedHeaderConflict)?; | ||
| builder = builder.header(name, value); | ||
@@ -65,2 +40,11 @@ } | ||
| /// Attempts to parse `body` as a JSON-RPC error message. | ||
| /// Returns `None` if the body is not parseable or is not a `JsonRpcMessage::Error`. | ||
| fn parse_json_rpc_error(body: &str) -> Option<ServerJsonRpcMessage> { | ||
| match serde_json::from_str::<ServerJsonRpcMessage>(body) { | ||
| Ok(message @ JsonRpcMessage::Error(_)) => Some(message), | ||
| _ => None, | ||
| } | ||
| } | ||
| impl StreamableHttpClient for reqwest::Client { | ||
@@ -197,2 +181,13 @@ type Error = reqwest::Error; | ||
| } | ||
| let content_type = response | ||
| .headers() | ||
| .get(reqwest::header::CONTENT_TYPE) | ||
| .map(|ct| String::from_utf8_lossy(ct.as_bytes()).to_string()); | ||
| let session_id = response | ||
| .headers() | ||
| .get(HEADER_SESSION_ID) | ||
| .and_then(|v| v.to_str().ok()) | ||
| .map(|s| s.to_string()); | ||
| // Non-success responses may carry valid JSON-RPC error payloads that | ||
| // should be surfaced as McpError rather than lost in TransportSend. | ||
| if !status.is_success() { | ||
@@ -203,2 +198,15 @@ let body = response | ||
| .unwrap_or_else(|_| "<failed to read response body>".to_owned()); | ||
| if content_type | ||
| .as_deref() | ||
| .is_some_and(|ct| ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes())) | ||
| { | ||
| match parse_json_rpc_error(&body) { | ||
| Some(message) => { | ||
| return Ok(StreamableHttpPostResponse::Json(message, session_id)); | ||
| } | ||
| None => tracing::warn!( | ||
| "HTTP {status}: could not parse JSON body as a JSON-RPC error" | ||
| ), | ||
| } | ||
| } | ||
| return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned( | ||
@@ -208,8 +216,3 @@ format!("HTTP {status}: {body}"), | ||
| } | ||
| let content_type = response.headers().get(reqwest::header::CONTENT_TYPE); | ||
| let session_id = response.headers().get(HEADER_SESSION_ID); | ||
| let session_id = session_id | ||
| .and_then(|v| v.to_str().ok()) | ||
| .map(|s| s.to_string()); | ||
| match content_type { | ||
| match content_type.as_deref() { | ||
| Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => { | ||
@@ -236,5 +239,3 @@ let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); | ||
| tracing::error!("unexpected content type: {:?}", content_type); | ||
| Err(StreamableHttpError::UnexpectedContentType( | ||
| content_type.map(|ct| String::from_utf8_lossy(ct.as_bytes()).to_string()), | ||
| )) | ||
| Err(StreamableHttpError::UnexpectedContentType(content_type)) | ||
| } | ||
@@ -291,63 +292,8 @@ } | ||
| /// extract scope parameter from WWW-Authenticate header | ||
| fn extract_scope_from_header(header: &str) -> Option<String> { | ||
| let header_lowercase = header.to_ascii_lowercase(); | ||
| let scope_key = "scope="; | ||
| if let Some(pos) = header_lowercase.find(scope_key) { | ||
| let start = pos + scope_key.len(); | ||
| let value_slice = &header[start..]; | ||
| if let Some(stripped) = value_slice.strip_prefix('"') { | ||
| if let Some(end_quote) = stripped.find('"') { | ||
| return Some(stripped[..end_quote].to_string()); | ||
| } | ||
| } else { | ||
| let end = value_slice | ||
| .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) | ||
| .unwrap_or(value_slice.len()); | ||
| if end > 0 { | ||
| return Some(value_slice[..end].to_string()); | ||
| } | ||
| } | ||
| } | ||
| None | ||
| } | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::extract_scope_from_header; | ||
| use crate::transport::streamable_http_client::InsufficientScopeError; | ||
| use super::parse_json_rpc_error; | ||
| use crate::{model::JsonRpcMessage, transport::streamable_http_client::InsufficientScopeError}; | ||
| #[test] | ||
| fn extract_scope_quoted() { | ||
| let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#; | ||
| assert_eq!( | ||
| extract_scope_from_header(header), | ||
| Some("files:read files:write".to_string()) | ||
| ); | ||
| } | ||
| #[test] | ||
| fn extract_scope_unquoted() { | ||
| let header = r#"Bearer scope=read:data, error="insufficient_scope""#; | ||
| assert_eq!( | ||
| extract_scope_from_header(header), | ||
| Some("read:data".to_string()) | ||
| ); | ||
| } | ||
| #[test] | ||
| fn extract_scope_missing() { | ||
| let header = r#"Bearer error="invalid_token""#; | ||
| assert_eq!(extract_scope_from_header(header), None); | ||
| } | ||
| #[test] | ||
| fn extract_scope_empty_header() { | ||
| assert_eq!(extract_scope_from_header("Bearer"), None); | ||
| } | ||
| #[test] | ||
| fn insufficient_scope_error_can_upgrade() { | ||
@@ -368,2 +314,34 @@ let with_scope = InsufficientScopeError { | ||
| } | ||
| #[test] | ||
| fn parse_json_rpc_error_returns_error_variant() { | ||
| let body = | ||
| r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#; | ||
| assert!(matches!( | ||
| parse_json_rpc_error(body), | ||
| Some(JsonRpcMessage::Error(_)) | ||
| )); | ||
| } | ||
| #[test] | ||
| fn parse_json_rpc_error_rejects_non_error_request() { | ||
| // A valid JSON-RPC request (method + id) must not be accepted as an error. | ||
| let body = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; | ||
| assert!(parse_json_rpc_error(body).is_none()); | ||
| } | ||
| #[test] | ||
| fn parse_json_rpc_error_rejects_notification() { | ||
| // A notification (method, no id) must not be accepted as an error. | ||
| let body = | ||
| r#"{"jsonrpc":"2.0","method":"notifications/cancelled","params":{"requestId":1}}"#; | ||
| assert!(parse_json_rpc_error(body).is_none()); | ||
| } | ||
| #[test] | ||
| fn parse_json_rpc_error_rejects_malformed_json() { | ||
| assert!(parse_json_rpc_error("not json at all").is_none()); | ||
| assert!(parse_json_rpc_error("").is_none()); | ||
| assert!(parse_json_rpc_error(r#"{"broken":"#).is_none()); | ||
| } | ||
| } |
@@ -61,2 +61,3 @@ #![allow(dead_code)] | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct ServerSseMessage { | ||
@@ -63,0 +64,0 @@ /// The event ID for this message. When set, clients can use this ID |
@@ -53,2 +53,3 @@ use std::sync::Arc; | ||
| #[non_exhaustive] | ||
| pub enum TransportAdapterSinkStream {} | ||
@@ -68,2 +69,3 @@ | ||
| #[non_exhaustive] | ||
| pub enum TransportAdapterAsyncCombinedRW {} | ||
@@ -70,0 +72,0 @@ impl<Role, S> IntoTransport<Role, S::Error, TransportAdapterAsyncCombinedRW> for S |
@@ -27,2 +27,3 @@ use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration}; | ||
| #[derive(Debug)] | ||
| #[non_exhaustive] | ||
| pub struct AuthRequiredError { | ||
@@ -33,2 +34,3 @@ pub www_authenticate_header: String, | ||
| #[derive(Debug)] | ||
| #[non_exhaustive] | ||
| pub struct InsufficientScopeError { | ||
@@ -217,2 +219,3 @@ pub www_authenticate_header: String, | ||
| #[non_exhaustive] | ||
| pub struct RetryConfig { | ||
@@ -259,2 +262,3 @@ pub max_times: Option<usize>, | ||
| #[derive(Debug, Clone, Default)] | ||
| #[non_exhaustive] | ||
| pub struct StreamableHttpClientWorker<C: StreamableHttpClient> { | ||
@@ -607,44 +611,47 @@ pub client: C, | ||
| Err(StreamableHttpError::SessionExpired) => { | ||
| // The server discarded the session (HTTP 404). Perform a | ||
| // fresh handshake once and replay the original message. | ||
| tracing::info!( | ||
| "session expired (HTTP 404), attempting transparent re-initialization" | ||
| ); | ||
| match Self::perform_reinitialization( | ||
| self.client.clone(), | ||
| saved_init_request.clone(), | ||
| config.uri.clone(), | ||
| config.auth_header.clone(), | ||
| config.custom_headers.clone(), | ||
| ) | ||
| .await | ||
| { | ||
| Ok((new_session_id, new_protocol_headers)) => { | ||
| // Old streams hold the stale session ID; abort them | ||
| // so the new standalone SSE stream takes over. | ||
| streams.abort_all(); | ||
| if !config.reinit_on_expired_session { | ||
| Err(StreamableHttpError::SessionExpired) | ||
| } else { | ||
| // The server discarded the session (HTTP 404). Perform a | ||
| // fresh handshake once and replay the original message. | ||
| tracing::info!( | ||
| "session expired (HTTP 404), attempting transparent re-initialization" | ||
| ); | ||
| match Self::perform_reinitialization( | ||
| self.client.clone(), | ||
| saved_init_request.clone(), | ||
| config.uri.clone(), | ||
| config.auth_header.clone(), | ||
| config.custom_headers.clone(), | ||
| ) | ||
| .await | ||
| { | ||
| Ok((new_session_id, new_protocol_headers)) => { | ||
| // Old streams hold the stale session ID; abort them | ||
| // so the new standalone SSE stream takes over. | ||
| streams.abort_all(); | ||
| session_id = new_session_id; | ||
| protocol_headers = new_protocol_headers; | ||
| session_cleanup_info = | ||
| session_id.as_ref().map(|sid| SessionCleanupInfo { | ||
| client: self.client.clone(), | ||
| uri: config.uri.clone(), | ||
| session_id: sid.clone(), | ||
| auth_header: config.auth_header.clone(), | ||
| protocol_headers: protocol_headers.clone(), | ||
| }); | ||
| session_id = new_session_id; | ||
| protocol_headers = new_protocol_headers; | ||
| session_cleanup_info = | ||
| session_id.as_ref().map(|sid| SessionCleanupInfo { | ||
| client: self.client.clone(), | ||
| uri: config.uri.clone(), | ||
| session_id: sid.clone(), | ||
| auth_header: config.auth_header.clone(), | ||
| protocol_headers: protocol_headers.clone(), | ||
| }); | ||
| if let Some(new_sid) = &session_id { | ||
| let client = self.client.clone(); | ||
| let uri = config.uri.clone(); | ||
| let new_sid = new_sid.clone(); | ||
| let auth_header = config.auth_header.clone(); | ||
| let retry_config = self.config.retry_config.clone(); | ||
| let sse_tx = sse_worker_tx.clone(); | ||
| let task_ct = transport_task_ct.clone(); | ||
| let config_uri = config.uri.clone(); | ||
| let config_auth = config.auth_header.clone(); | ||
| let spawn_headers = protocol_headers.clone(); | ||
| streams.spawn(async move { | ||
| if let Some(new_sid) = &session_id { | ||
| let client = self.client.clone(); | ||
| let uri = config.uri.clone(); | ||
| let new_sid = new_sid.clone(); | ||
| let auth_header = config.auth_header.clone(); | ||
| let retry_config = self.config.retry_config.clone(); | ||
| let sse_tx = sse_worker_tx.clone(); | ||
| let task_ct = transport_task_ct.clone(); | ||
| let config_uri = config.uri.clone(); | ||
| let config_auth = config.auth_header.clone(); | ||
| let spawn_headers = protocol_headers.clone(); | ||
| streams.spawn(async move { | ||
| match client | ||
@@ -694,47 +701,48 @@ .get_stream( | ||
| }); | ||
| } | ||
| } | ||
| let retry_response = self | ||
| .client | ||
| .post_message( | ||
| config.uri.clone(), | ||
| message, | ||
| session_id.clone(), | ||
| config.auth_header.clone(), | ||
| protocol_headers.clone(), | ||
| ) | ||
| .await; | ||
| match retry_response { | ||
| Err(e) => Err(e), | ||
| Ok(StreamableHttpPostResponse::Accepted) => { | ||
| tracing::trace!( | ||
| "client message accepted after re-init" | ||
| ); | ||
| Ok(()) | ||
| } | ||
| Ok(StreamableHttpPostResponse::Json(msg, ..)) => { | ||
| context.send_to_handler(msg).await?; | ||
| Ok(()) | ||
| } | ||
| Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { | ||
| if let Some(sid) = &session_id { | ||
| let sse_stream = SseAutoReconnectStream::new( | ||
| stream, | ||
| StreamableHttpClientReconnect { | ||
| client: self.client.clone(), | ||
| session_id: sid.clone(), | ||
| uri: config.uri.clone(), | ||
| auth_header: config.auth_header.clone(), | ||
| custom_headers: protocol_headers.clone(), | ||
| }, | ||
| self.config.retry_config.clone(), | ||
| let retry_response = self | ||
| .client | ||
| .post_message( | ||
| config.uri.clone(), | ||
| message, | ||
| session_id.clone(), | ||
| config.auth_header.clone(), | ||
| protocol_headers.clone(), | ||
| ) | ||
| .await; | ||
| match retry_response { | ||
| Err(e) => Err(e), | ||
| Ok(StreamableHttpPostResponse::Accepted) => { | ||
| tracing::trace!( | ||
| "client message accepted after re-init" | ||
| ); | ||
| streams.spawn(Self::execute_sse_stream( | ||
| sse_stream, | ||
| sse_worker_tx.clone(), | ||
| true, | ||
| transport_task_ct.child_token(), | ||
| )); | ||
| } else { | ||
| let sse_stream = | ||
| Ok(()) | ||
| } | ||
| Ok(StreamableHttpPostResponse::Json(msg, ..)) => { | ||
| context.send_to_handler(msg).await?; | ||
| Ok(()) | ||
| } | ||
| Ok(StreamableHttpPostResponse::Sse(stream, ..)) => { | ||
| if let Some(sid) = &session_id { | ||
| let sse_stream = SseAutoReconnectStream::new( | ||
| stream, | ||
| StreamableHttpClientReconnect { | ||
| client: self.client.clone(), | ||
| session_id: sid.clone(), | ||
| uri: config.uri.clone(), | ||
| auth_header: config.auth_header.clone(), | ||
| custom_headers: protocol_headers | ||
| .clone(), | ||
| }, | ||
| self.config.retry_config.clone(), | ||
| ); | ||
| streams.spawn(Self::execute_sse_stream( | ||
| sse_stream, | ||
| sse_worker_tx.clone(), | ||
| true, | ||
| transport_task_ct.child_token(), | ||
| )); | ||
| } else { | ||
| let sse_stream = | ||
| SseAutoReconnectStream::never_reconnect( | ||
@@ -744,16 +752,17 @@ stream, | ||
| ); | ||
| streams.spawn(Self::execute_sse_stream( | ||
| sse_stream, | ||
| sse_worker_tx.clone(), | ||
| true, | ||
| transport_task_ct.child_token(), | ||
| )); | ||
| streams.spawn(Self::execute_sse_stream( | ||
| sse_stream, | ||
| sse_worker_tx.clone(), | ||
| true, | ||
| transport_task_ct.child_token(), | ||
| )); | ||
| } | ||
| tracing::trace!("got new sse stream after re-init"); | ||
| Ok(()) | ||
| } | ||
| tracing::trace!("got new sse stream after re-init"); | ||
| Ok(()) | ||
| } | ||
| } | ||
| Err(reinit_err) => Err(reinit_err), | ||
| } | ||
| Err(reinit_err) => Err(reinit_err), | ||
| } | ||
| } // else enable_reinit_on_expired_session | ||
| } | ||
@@ -1051,2 +1060,3 @@ Err(e) => Err(e), | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct StreamableHttpClientTransportConfig { | ||
@@ -1062,2 +1072,12 @@ pub uri: Arc<str>, | ||
| pub custom_headers: HashMap<HeaderName, HeaderValue>, | ||
| /// Enables transparent recovery when the server reports an expired session (`HTTP 404`). | ||
| /// | ||
| /// When enabled, the transport performs one automatic recovery attempt: | ||
| /// 1. Replays the original `initialize` handshake to create a new session. | ||
| /// 2. Re-establishes streaming state for that session. | ||
| /// 3. Retries the in-flight request that failed with `SessionExpired`. | ||
| /// | ||
| /// This recovery is best-effort and bounded to a single attempt. If recovery fails, | ||
| /// the original failure path is preserved and the error is returned to the caller. | ||
| pub reinit_on_expired_session: bool, | ||
| } | ||
@@ -1110,2 +1130,15 @@ | ||
| } | ||
| /// Set whether the transport should attempt transparent re-initialization on session expiration | ||
| /// See [`Self::reinit_on_expired_session`] for details. | ||
| /// # Example | ||
| /// ```rust,no_run | ||
| /// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; | ||
| /// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000") | ||
| /// .reinit_on_expired_session(true); | ||
| /// ``` | ||
| pub fn reinit_on_expired_session(mut self, enable: bool) -> Self { | ||
| self.reinit_on_expired_session = enable; | ||
| self | ||
| } | ||
| } | ||
@@ -1122,4 +1155,5 @@ | ||
| custom_headers: HashMap::new(), | ||
| reinit_on_expired_session: true, | ||
| } | ||
| } | ||
| } |
| pub mod session; | ||
| #[cfg(feature = "transport-streamable-http-server")] | ||
| #[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))] | ||
| pub mod tower; | ||
| pub use session::{SessionId, SessionManager}; | ||
| #[cfg(feature = "transport-streamable-http-server")] | ||
| #[cfg(all(feature = "transport-streamable-http-server", not(feature = "local")))] | ||
| pub use tower::{StreamableHttpServerConfig, StreamableHttpService}; |
@@ -36,3 +36,3 @@ //! Session management for the Streamable HTTP transport. | ||
| /// | ||
| /// The [`StreamableHttpService`](super::StreamableHttpService) calls into this | ||
| /// The `StreamableHttpService` calls into this | ||
| /// trait for every HTTP request that carries (or should carry) a session ID. | ||
@@ -39,0 +39,0 @@ /// |
@@ -32,2 +32,3 @@ use std::{ | ||
| #[derive(Debug, Default)] | ||
| #[non_exhaustive] | ||
| pub struct LocalSessionManager { | ||
@@ -39,2 +40,3 @@ pub sessions: tokio::sync::RwLock<HashMap<SessionId, LocalSessionHandle>>, | ||
| #[derive(Debug, Error)] | ||
| #[non_exhaustive] | ||
| pub enum LocalSessionManagerError { | ||
@@ -153,2 +155,3 @@ #[error("Session not found: {0}")] | ||
| #[derive(Debug, Clone, Error)] | ||
| #[non_exhaustive] | ||
| pub enum EventIdParseError { | ||
@@ -316,2 +319,3 @@ #[error("Invalid index: {0}")] | ||
| #[derive(Debug, Error)] | ||
| #[non_exhaustive] | ||
| pub enum SessionError { | ||
@@ -346,2 +350,3 @@ #[error("Invalid request id: {0}")] | ||
| #[derive(Debug)] | ||
| #[non_exhaustive] | ||
| pub struct StreamableHttpMessageReceiver { | ||
@@ -665,2 +670,3 @@ pub http_request_id: Option<HttpRequestId>, | ||
| #[derive(Debug)] | ||
| #[non_exhaustive] | ||
| pub enum SessionEvent { | ||
@@ -1068,2 +1074,3 @@ ClientMessage { | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct SessionConfig { | ||
@@ -1070,0 +1077,0 @@ /// the capacity of the channel for the session. Default is 16. |
@@ -13,5 +13,8 @@ use futures::Stream; | ||
| #[error("Session management is not supported")] | ||
| #[non_exhaustive] | ||
| pub struct ErrorSessionManagementNotSupported; | ||
| #[derive(Debug, Clone, Default)] | ||
| #[non_exhaustive] | ||
| pub struct NeverSessionManager {} | ||
| #[non_exhaustive] | ||
| pub enum NeverTransport {} | ||
@@ -18,0 +21,0 @@ impl Transport<RoleServer> for NeverTransport { |
@@ -33,2 +33,3 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; | ||
| #[derive(Debug, Clone)] | ||
| #[non_exhaustive] | ||
| pub struct StreamableHttpServerConfig { | ||
@@ -66,2 +67,29 @@ /// The ping message duration for SSE connections. | ||
| impl StreamableHttpServerConfig { | ||
| pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self { | ||
| self.sse_keep_alive = duration; | ||
| self | ||
| } | ||
| pub fn with_sse_retry(mut self, duration: Option<Duration>) -> Self { | ||
| self.sse_retry = duration; | ||
| self | ||
| } | ||
| pub fn with_stateful_mode(mut self, stateful: bool) -> Self { | ||
| self.stateful_mode = stateful; | ||
| self | ||
| } | ||
| pub fn with_json_response(mut self, json_response: bool) -> Self { | ||
| self.json_response = json_response; | ||
| self | ||
| } | ||
| pub fn with_cancellation_token(mut self, token: CancellationToken) -> Self { | ||
| self.cancellation_token = token; | ||
| self | ||
| } | ||
| } | ||
| #[expect( | ||
@@ -190,3 +218,3 @@ clippy::result_large_err, | ||
| /// ``` | ||
| pub struct StreamableHttpService<S, M = super::session::local::LocalSessionManager> { | ||
| pub struct StreamableHttpService<S, M> { | ||
| pub config: StreamableHttpServerConfig, | ||
@@ -210,3 +238,3 @@ session_manager: Arc<M>, | ||
| RequestBody: Body + Send + 'static, | ||
| S: crate::Service<RoleServer>, | ||
| S: crate::Service<RoleServer> + Send + 'static, | ||
| M: SessionManager, | ||
@@ -213,0 +241,0 @@ RequestBody::Error: Display, |
@@ -56,2 +56,3 @@ use std::borrow::Cow; | ||
| #[non_exhaustive] | ||
| pub struct WorkerSendRequest<W: Worker> { | ||
@@ -70,2 +71,3 @@ pub message: TxJsonRpcMessage<W::Role>, | ||
| #[non_exhaustive] | ||
| pub struct WorkerConfig { | ||
@@ -84,2 +86,3 @@ pub name: Option<String>, | ||
| } | ||
| #[non_exhaustive] | ||
| pub enum WorkerAdapter {} | ||
@@ -149,2 +152,3 @@ | ||
| #[non_exhaustive] | ||
| pub struct SendRequest<W: Worker> { | ||
@@ -155,2 +159,3 @@ pub message: TxJsonRpcMessage<W::Role>, | ||
| #[non_exhaustive] | ||
| pub struct WorkerContext<W: Worker> { | ||
@@ -157,0 +162,0 @@ pub to_handler_tx: tokio::sync::mpsc::Sender<RxJsonRpcMessage<W::Role>>, |
@@ -10,3 +10,7 @@ use std::{ | ||
| use rmcp::{ClientHandler, RoleClient}; | ||
| use rmcp::{ErrorData as McpError, RoleServer, ServerHandler, model::*, service::RequestContext}; | ||
| use rmcp::{ | ||
| ErrorData as McpError, RoleServer, ServerHandler, | ||
| model::*, | ||
| service::{MaybeSendFuture, RequestContext}, | ||
| }; | ||
| #[cfg(feature = "client")] | ||
@@ -89,3 +93,3 @@ use serde_json::json; | ||
| _context: NotificationContext<RoleClient>, | ||
| ) -> impl Future<Output = ()> + Send + '_ { | ||
| ) -> impl Future<Output = ()> + MaybeSendFuture + '_ { | ||
| let receive_signal = self.receive_signal.clone(); | ||
@@ -121,3 +125,3 @@ let received_messages = self.received_messages.clone(); | ||
| context: RequestContext<RoleServer>, | ||
| ) -> impl Future<Output = Result<(), McpError>> + Send + '_ { | ||
| ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ { | ||
| let peer = context.peer; | ||
@@ -124,0 +128,0 @@ async move { |
| // cargo test --features "server client" --package rmcp test_client_initialization | ||
| #![cfg(feature = "client")] | ||
| #![cfg(all(feature = "client", not(feature = "local")))] | ||
@@ -4,0 +4,0 @@ mod common; |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| //cargo test --test test_close_connection --features "client server" | ||
@@ -2,0 +3,0 @@ |
@@ -0,1 +1,3 @@ | ||
| #![allow(clippy::exhaustive_structs, clippy::exhaustive_enums)] | ||
| use rmcp::{ | ||
@@ -2,0 +4,0 @@ ErrorData as McpError, handler::server::wrapper::Parameters, model::*, schemars, tool, |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use std::collections::HashMap; | ||
@@ -2,0 +3,0 @@ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use std::sync::Arc; | ||
@@ -2,0 +3,0 @@ |
@@ -16,1 +16,120 @@ use rmcp::model::{JsonRpcResponse, ServerJsonRpcMessage, ServerResult}; | ||
| } | ||
| /// Regression tests for `#[serde(untagged)]` deserialization of `ServerResult`. | ||
| /// | ||
| /// `ServerResult` is an untagged enum, so serde tries each variant in declaration | ||
| /// order. `GetTaskPayloadResult` has a custom `Deserialize` impl that always fails | ||
| /// so it is skipped, and `CustomResult(Value)` acts as the catch-all. If variant | ||
| /// ordering changes or the custom impl is removed, these tests will catch the | ||
| /// regression. | ||
| mod untagged_server_result { | ||
| use rmcp::model::{CallToolResult, JsonRpcResponse, ServerJsonRpcMessage, ServerResult}; | ||
| use serde_json::json; | ||
| /// Helper: wrap a result value in a JSON-RPC response envelope. | ||
| fn wrap_response(result: serde_json::Value) -> serde_json::Value { | ||
| json!({ | ||
| "jsonrpc": "2.0", | ||
| "id": 1, | ||
| "result": result | ||
| }) | ||
| } | ||
| /// Parse a JSON-RPC response and return the inner `ServerResult`. | ||
| fn parse_result(json: serde_json::Value) -> ServerResult { | ||
| let msg: ServerJsonRpcMessage = serde_json::from_value(json).unwrap(); | ||
| match msg { | ||
| ServerJsonRpcMessage::Response(JsonRpcResponse { result, .. }) => result, | ||
| other => panic!("expected Response, got {other:?}"), | ||
| } | ||
| } | ||
| #[test] | ||
| fn initialize_result_deserializes_to_correct_variant() { | ||
| let result = parse_result(wrap_response(json!({ | ||
| "protocolVersion": "2025-03-26", | ||
| "capabilities": {}, | ||
| "serverInfo": { | ||
| "name": "test-server", | ||
| "version": "1.0.0" | ||
| } | ||
| }))); | ||
| assert!( | ||
| matches!(result, ServerResult::InitializeResult(_)), | ||
| "expected InitializeResult, got {result:?}" | ||
| ); | ||
| } | ||
| #[test] | ||
| fn call_tool_result_deserializes_to_correct_variant() { | ||
| let result = parse_result(wrap_response(json!({ | ||
| "content": [ | ||
| { "type": "text", "text": "hello" } | ||
| ] | ||
| }))); | ||
| assert!( | ||
| matches!(result, ServerResult::CallToolResult(_)), | ||
| "expected CallToolResult, got {result:?}" | ||
| ); | ||
| } | ||
| #[test] | ||
| fn empty_object_deserializes_to_empty_result() { | ||
| let result = parse_result(wrap_response(json!({}))); | ||
| assert!( | ||
| matches!(result, ServerResult::EmptyResult(_)), | ||
| "expected EmptyResult, got {result:?}" | ||
| ); | ||
| } | ||
| #[test] | ||
| fn unknown_shape_falls_through_to_custom_result() { | ||
| // A value that doesn't match any known result type should land in | ||
| // CustomResult, NOT GetTaskPayloadResult. | ||
| let result = parse_result(wrap_response(json!({ | ||
| "some_unknown_field": "some_value", | ||
| "number": 42 | ||
| }))); | ||
| assert!( | ||
| matches!(result, ServerResult::CustomResult(_)), | ||
| "expected CustomResult, got {result:?}" | ||
| ); | ||
| } | ||
| #[test] | ||
| fn arbitrary_json_value_does_not_deserialize_as_get_task_payload_result() { | ||
| // GetTaskPayloadResult wraps a bare Value, but its custom Deserialize | ||
| // always fails so serde skips it during untagged resolution. | ||
| // Any JSON value must fall through to CustomResult instead. | ||
| for value in [json!(42), json!("hello"), json!(null), json!([1, 2, 3])] { | ||
| let result = parse_result(wrap_response(value.clone())); | ||
| assert!( | ||
| matches!(result, ServerResult::CustomResult(_)), | ||
| "value {value} should deserialize as CustomResult, got {result:?}" | ||
| ); | ||
| } | ||
| } | ||
| #[test] | ||
| fn round_trip_initialize_result_preserves_variant() { | ||
| let json = json!({ | ||
| "protocolVersion": "2025-03-26", | ||
| "capabilities": {}, | ||
| "serverInfo": { "name": "test", "version": "1.0" } | ||
| }); | ||
| // Parse as ServerResult, serialize back, parse again — must stay InitializeResult. | ||
| let result = parse_result(wrap_response(json.clone())); | ||
| assert!(matches!(&result, ServerResult::InitializeResult(_))); | ||
| let reserialized = serde_json::to_value(&result).unwrap(); | ||
| let result2 = parse_result(wrap_response(reserialized)); | ||
| assert!(matches!(result2, ServerResult::InitializeResult(_))); | ||
| } | ||
| #[test] | ||
| fn round_trip_call_tool_result_preserves_variant() { | ||
| let original = CallToolResult::success(vec![rmcp::model::Content::text("hello world")]); | ||
| let json = serde_json::to_value(&original).unwrap(); | ||
| let result = parse_result(wrap_response(json)); | ||
| assert!(matches!(result, ServerResult::CallToolResult(_))); | ||
| } | ||
| } |
@@ -0,1 +1,2 @@ | ||
| #![allow(clippy::exhaustive_structs)] | ||
| //cargo test --test test_json_schema_detection --features "client server macros" | ||
@@ -2,0 +3,0 @@ use rmcp::{ |
| // cargo test --features "server client" --package rmcp test_logging | ||
| #![cfg(not(feature = "local"))] | ||
| mod common; | ||
@@ -3,0 +4,0 @@ |
| //cargo test --test test_message_protocol --features "client server" | ||
| #![cfg(not(feature = "local"))] | ||
@@ -10,3 +11,2 @@ mod common; | ||
| }; | ||
| use tokio_util::sync::CancellationToken; | ||
@@ -51,9 +51,3 @@ // Tests start here | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(1), | ||
| meta: Default::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(1), client.peer().clone()), | ||
| ) | ||
@@ -89,9 +83,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(2), | ||
| meta: Default::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(2), client.peer().clone()), | ||
| ) | ||
@@ -127,9 +115,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(3), | ||
| meta: Default::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(3), client.peer().clone()), | ||
| ) | ||
@@ -185,9 +167,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(1), | ||
| meta: Meta::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(1), client.peer().clone()), | ||
| ) | ||
@@ -248,9 +224,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(1), | ||
| meta: Meta::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(1), client.peer().clone()), | ||
| ) | ||
@@ -315,9 +285,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(1), | ||
| meta: Meta::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(1), client.peer().clone()), | ||
| ) | ||
@@ -339,9 +303,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(2), | ||
| meta: Meta::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(2), client.peer().clone()), | ||
| ) | ||
@@ -380,9 +338,3 @@ .await; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(1), | ||
| meta: Meta::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(1), client.peer().clone()), | ||
| ) | ||
@@ -416,9 +368,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(2), | ||
| meta: Meta::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(2), client.peer().clone()), | ||
| ) | ||
@@ -469,9 +415,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(1), | ||
| meta: Meta::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(1), client.peer().clone()), | ||
| ) | ||
@@ -478,0 +418,0 @@ .await?; |
@@ -723,2 +723,13 @@ { | ||
| "type": "string" | ||
| }, | ||
| "theme": { | ||
| "description": "Optional specifier for the theme this icon is designed for\nIf not provided, the client should assume the icon can be used with any theme.", | ||
| "anyOf": [ | ||
| { | ||
| "$ref": "#/definitions/IconTheme" | ||
| }, | ||
| { | ||
| "type": "null" | ||
| } | ||
| ] | ||
| } | ||
@@ -730,2 +741,17 @@ }, | ||
| }, | ||
| "IconTheme": { | ||
| "description": "Icon themes supported by the MCP specification", | ||
| "oneOf": [ | ||
| { | ||
| "description": "Indicates the icon is designed to be used with a light background", | ||
| "type": "string", | ||
| "const": "light" | ||
| }, | ||
| { | ||
| "description": "Indicates the icon is designed to be used with a dark background", | ||
| "type": "string", | ||
| "const": "dark" | ||
| } | ||
| ] | ||
| }, | ||
| "Implementation": { | ||
@@ -732,0 +758,0 @@ "type": "object", |
@@ -723,2 +723,13 @@ { | ||
| "type": "string" | ||
| }, | ||
| "theme": { | ||
| "description": "Optional specifier for the theme this icon is designed for\nIf not provided, the client should assume the icon can be used with any theme.", | ||
| "anyOf": [ | ||
| { | ||
| "$ref": "#/definitions/IconTheme" | ||
| }, | ||
| { | ||
| "type": "null" | ||
| } | ||
| ] | ||
| } | ||
@@ -730,2 +741,17 @@ }, | ||
| }, | ||
| "IconTheme": { | ||
| "description": "Icon themes supported by the MCP specification", | ||
| "oneOf": [ | ||
| { | ||
| "description": "Indicates the icon is designed to be used with a light background", | ||
| "type": "string", | ||
| "const": "light" | ||
| }, | ||
| { | ||
| "description": "Indicates the icon is designed to be used with a dark background", | ||
| "type": "string", | ||
| "const": "dark" | ||
| } | ||
| ] | ||
| }, | ||
| "Implementation": { | ||
@@ -732,0 +758,0 @@ "type": "object", |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use std::sync::Arc; | ||
@@ -2,0 +3,0 @@ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use futures::StreamExt; | ||
@@ -2,0 +3,0 @@ use rmcp::{ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| //cargo test --test test_prompt_macros --features "client server" | ||
@@ -2,0 +3,0 @@ #![allow(dead_code)] |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use std::collections::HashMap; | ||
@@ -2,0 +3,0 @@ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| mod common; | ||
@@ -10,3 +11,2 @@ | ||
| }; | ||
| use tokio_util::sync::CancellationToken; | ||
@@ -129,9 +129,3 @@ #[tokio::test] | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(1), | ||
| meta: Default::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(1), client.peer().clone()), | ||
| ) | ||
@@ -193,9 +187,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(2), | ||
| meta: Default::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(2), client.peer().clone()), | ||
| ) | ||
@@ -258,9 +246,3 @@ .await?; | ||
| request.clone(), | ||
| RequestContext { | ||
| peer: client.peer().clone(), | ||
| ct: CancellationToken::new(), | ||
| id: NumberOrString::Number(3), | ||
| meta: Default::default(), | ||
| extensions: Default::default(), | ||
| }, | ||
| RequestContext::new(NumberOrString::Number(3), client.peer().clone()), | ||
| ) | ||
@@ -267,0 +249,0 @@ .await; |
| // cargo test --features "client" --package rmcp -- server_init | ||
| #![cfg(feature = "client")] | ||
| #![cfg(all(feature = "client", not(feature = "local")))] | ||
| mod common; | ||
@@ -4,0 +4,0 @@ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| /// Tests for concurrent SSE stream handling (shadow channels) | ||
@@ -76,9 +77,3 @@ /// | ||
| Arc::new(LocalSessionManager::default()), | ||
| StreamableHttpServerConfig { | ||
| stateful_mode: true, | ||
| sse_keep_alive: Some(Duration::from_secs(15)), | ||
| sse_retry: Some(Duration::from_secs(3)), | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }, | ||
| StreamableHttpServerConfig::default().with_cancellation_token(ct.child_token()), | ||
| ); | ||
@@ -85,0 +80,0 @@ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use rmcp::transport::streamable_http_server::{ | ||
@@ -39,9 +40,9 @@ StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, | ||
| let ct = CancellationToken::new(); | ||
| let (client, url, ct) = spawn_server(StreamableHttpServerConfig { | ||
| stateful_mode: false, | ||
| json_response: true, | ||
| sse_keep_alive: None, | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }) | ||
| let (client, url, ct) = spawn_server( | ||
| StreamableHttpServerConfig::default() | ||
| .with_stateful_mode(false) | ||
| .with_json_response(true) | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ) | ||
| .await; | ||
@@ -82,9 +83,8 @@ | ||
| let ct = CancellationToken::new(); | ||
| let (client, url, ct) = spawn_server(StreamableHttpServerConfig { | ||
| stateful_mode: false, | ||
| json_response: false, | ||
| sse_keep_alive: None, | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }) | ||
| let (client, url, ct) = spawn_server( | ||
| StreamableHttpServerConfig::default() | ||
| .with_stateful_mode(false) | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ) | ||
| .await; | ||
@@ -126,9 +126,8 @@ | ||
| // json_response: true has no effect when stateful_mode: true — server still uses SSE | ||
| let (client, url, ct) = spawn_server(StreamableHttpServerConfig { | ||
| stateful_mode: true, | ||
| json_response: true, | ||
| sse_keep_alive: None, | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }) | ||
| let (client, url, ct) = spawn_server( | ||
| StreamableHttpServerConfig::default() | ||
| .with_json_response(true) | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ) | ||
| .await; | ||
@@ -135,0 +134,0 @@ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use std::time::Duration; | ||
@@ -20,8 +21,5 @@ | ||
| Default::default(), | ||
| StreamableHttpServerConfig { | ||
| stateful_mode: true, | ||
| sse_keep_alive: None, | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }, | ||
| StreamableHttpServerConfig::default() | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ); | ||
@@ -90,8 +88,5 @@ | ||
| session_manager.clone(), | ||
| StreamableHttpServerConfig { | ||
| stateful_mode: true, | ||
| sse_keep_alive: None, | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }, | ||
| StreamableHttpServerConfig::default() | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ); | ||
@@ -98,0 +93,0 @@ |
| #![cfg(all( | ||
| feature = "transport-streamable-http-client", | ||
| feature = "transport-streamable-http-client-reqwest", | ||
| feature = "transport-streamable-http-server" | ||
| feature = "transport-streamable-http-server", | ||
| not(feature = "local") | ||
| ))] | ||
@@ -10,3 +11,3 @@ | ||
| use rmcp::{ | ||
| ServiceExt, | ||
| ServiceError, ServiceExt, | ||
| model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, | ||
@@ -35,8 +36,5 @@ transport::{ | ||
| Default::default(), | ||
| StreamableHttpServerConfig { | ||
| stateful_mode: true, | ||
| sse_keep_alive: None, | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }, | ||
| StreamableHttpServerConfig::default() | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ); | ||
@@ -107,8 +105,5 @@ | ||
| session_manager.clone(), | ||
| StreamableHttpServerConfig { | ||
| stateful_mode: true, | ||
| sse_keep_alive: None, | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }, | ||
| StreamableHttpServerConfig::default() | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ); | ||
@@ -131,3 +126,4 @@ | ||
| let transport = StreamableHttpClientTransport::from_config( | ||
| StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")), | ||
| StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")) | ||
| .reinit_on_expired_session(true), | ||
| ); | ||
@@ -177,1 +173,64 @@ let client = ().serve(transport).await?; | ||
| } | ||
| /// Verify that when `reinit_on_expired_session` is false and the server loses the session, | ||
| /// the client receives a `SessionExpired` transport error instead of retrying. | ||
| #[tokio::test] | ||
| async fn test_session_expired_error_when_reinit_disabled() -> anyhow::Result<()> { | ||
| let ct = CancellationToken::new(); | ||
| let session_manager = Arc::new(LocalSessionManager::default()); | ||
| let service = StreamableHttpService::new( | ||
| || Ok(Calculator::new()), | ||
| session_manager.clone(), | ||
| StreamableHttpServerConfig::default() | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ); | ||
| let router = axum::Router::new().nest_service("/mcp", service); | ||
| let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; | ||
| let addr = listener.local_addr()?; | ||
| let server_handle = tokio::spawn({ | ||
| let ct = ct.clone(); | ||
| async move { | ||
| let _ = axum::serve(listener, router) | ||
| .with_graceful_shutdown(async move { ct.cancelled_owned().await }) | ||
| .await; | ||
| } | ||
| }); | ||
| let transport = StreamableHttpClientTransport::from_config( | ||
| StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")) | ||
| .reinit_on_expired_session(false), | ||
| ); | ||
| let client = ().serve(transport).await?; | ||
| // Verify the session is established | ||
| let _resources = client.list_all_resources().await?; | ||
| // Force session expiry by removing all sessions from the server-side manager | ||
| { | ||
| let mut sessions = session_manager.sessions.write().await; | ||
| sessions.clear(); | ||
| } | ||
| // This call should fail with a SessionExpired transport error | ||
| let result = client.list_all_resources().await; | ||
| match result { | ||
| Err(ServiceError::TransportSend(ref dyn_err)) => { | ||
| let err_msg = format!("{dyn_err}"); | ||
| assert!( | ||
| err_msg.contains("Session expired"), | ||
| "expected 'Session expired' in error message, got: {err_msg}" | ||
| ); | ||
| } | ||
| other => panic!("expected TransportSend(SessionExpired), got: {other:?}"), | ||
| } | ||
| let _ = client.cancel().await; | ||
| ct.cancel(); | ||
| server_handle.await?; | ||
| Ok(()) | ||
| } |
@@ -0,1 +1,2 @@ | ||
| #![allow(clippy::exhaustive_structs)] | ||
| //cargo test --test test_structured_output --features "client server macros" | ||
@@ -301,14 +302,16 @@ use rmcp::{ | ||
| #[tokio::test] | ||
| async fn test_missing_content_is_rejected() { | ||
| #[test] | ||
| fn test_missing_content_defaults_to_empty() { | ||
| let raw = json!({ "isError": false }); | ||
| let result: Result<CallToolResult, _> = serde_json::from_value(raw); | ||
| assert!(result.is_err()); | ||
| let result: CallToolResult = serde_json::from_value(raw).unwrap(); | ||
| assert!(result.content.is_empty()); | ||
| assert_eq!(result.is_error, Some(false)); | ||
| } | ||
| #[tokio::test] | ||
| async fn test_missing_content_with_structured_content_is_rejected() { | ||
| #[test] | ||
| fn test_missing_content_with_structured_content_deserializes() { | ||
| let raw = json!({ "structuredContent": {"key": "value"}, "isError": false }); | ||
| let result: Result<CallToolResult, _> = serde_json::from_value(raw); | ||
| assert!(result.is_err()); | ||
| let result: CallToolResult = serde_json::from_value(raw).unwrap(); | ||
| assert!(result.content.is_empty()); | ||
| assert_eq!(result.structured_content.unwrap()["key"], "value"); | ||
| } | ||
@@ -337,1 +340,11 @@ | ||
| } | ||
| #[test] | ||
| fn test_call_tool_result_deserialize_without_content() { | ||
| let json = json!({ | ||
| "structuredContent": {"message": "Hello"} | ||
| }); | ||
| let result: CallToolResult = serde_json::from_value(json).unwrap(); | ||
| assert!(result.content.is_empty()); | ||
| assert!(result.structured_content.is_some()); | ||
| } |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| //! Tests for task support validation in tool calls. | ||
@@ -2,0 +3,0 @@ //! |
@@ -0,1 +1,2 @@ | ||
| #![allow(clippy::exhaustive_structs)] | ||
| //cargo test --test test_tool_builder_methods --features "client server macros" | ||
@@ -2,0 +3,0 @@ use rmcp::model::{JsonObject, Tool}; |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| //! Test tool macros, including documentation for generated fns. | ||
@@ -2,0 +3,0 @@ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use std::collections::HashMap; | ||
@@ -2,0 +3,0 @@ |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use rmcp::{ | ||
@@ -71,8 +72,5 @@ ServiceExt, | ||
| Default::default(), | ||
| StreamableHttpServerConfig { | ||
| stateful_mode: true, | ||
| sse_keep_alive: None, | ||
| cancellation_token: ct.child_token(), | ||
| ..Default::default() | ||
| }, | ||
| StreamableHttpServerConfig::default() | ||
| .with_sse_keep_alive(None) | ||
| .with_cancellation_token(ct.child_token()), | ||
| ); | ||
@@ -79,0 +77,0 @@ let router = axum::Router::new().nest_service("/mcp", service); |
@@ -0,1 +1,2 @@ | ||
| #![cfg(not(feature = "local"))] | ||
| use std::process::Stdio; | ||
@@ -2,0 +3,0 @@ |
-216
| The MCP project is undergoing a licensing transition from the MIT License to the Apache License, Version 2.0 ("Apache-2.0"). All new code and specification contributions to the project are licensed under Apache-2.0. Documentation contributions (excluding specifications) are licensed under CC-BY-4.0. | ||
| Contributions for which relicensing consent has been obtained are licensed under Apache-2.0. Contributions made by authors who originally licensed their work under the MIT License and who have not yet granted explicit permission to relicense remain licensed under the MIT License. | ||
| No rights beyond those granted by the applicable original license are conveyed for such contributions. | ||
| --- | ||
| Apache License | ||
| Version 2.0, January 2004 | ||
| http://www.apache.org/licenses/ | ||
| TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||
| 1. Definitions. | ||
| "License" shall mean the terms and conditions for use, reproduction, | ||
| and distribution as defined by Sections 1 through 9 of this document. | ||
| "Licensor" shall mean the copyright owner or entity authorized by | ||
| the copyright owner that is granting the License. | ||
| "Legal Entity" shall mean the union of the acting entity and all | ||
| other entities that control, are controlled by, or are under common | ||
| control with that entity. For the purposes of this definition, | ||
| "control" means (i) the power, direct or indirect, to cause the | ||
| direction or management of such entity, whether by contract or | ||
| otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||
| outstanding shares, or (iii) beneficial ownership of such entity. | ||
| "You" (or "Your") shall mean an individual or Legal Entity | ||
| exercising permissions granted by this License. | ||
| "Source" form shall mean the preferred form for making modifications, | ||
| including but not limited to software source code, documentation | ||
| source, and configuration files. | ||
| "Object" form shall mean any form resulting from mechanical | ||
| transformation or translation of a Source form, including but | ||
| not limited to compiled object code, generated documentation, | ||
| and conversions to other media types. | ||
| "Work" shall mean the work of authorship, whether in Source or | ||
| Object form, made available under the License, as indicated by a | ||
| copyright notice that is included in or attached to the work | ||
| (an example is provided in the Appendix below). | ||
| "Derivative Works" shall mean any work, whether in Source or Object | ||
| form, that is based on (or derived from) the Work and for which the | ||
| editorial revisions, annotations, elaborations, or other modifications | ||
| represent, as a whole, an original work of authorship. For the purposes | ||
| of this License, Derivative Works shall not include works that remain | ||
| separable from, or merely link (or bind by name) to the interfaces of, | ||
| the Work and Derivative Works thereof. | ||
| "Contribution" shall mean any work of authorship, including | ||
| the original version of the Work and any modifications or additions | ||
| to that Work or Derivative Works thereof, that is intentionally | ||
| submitted to the Licensor for inclusion in the Work by the copyright | ||
| owner or by an individual or Legal Entity authorized to submit on behalf | ||
| of the copyright owner. For the purposes of this definition, "submitted" | ||
| means any form of electronic, verbal, or written communication sent | ||
| to the Licensor or its representatives, including but not limited to | ||
| communication on electronic mailing lists, source code control systems, | ||
| and issue tracking systems that are managed by, or on behalf of, the | ||
| Licensor for the purpose of discussing and improving the Work, but | ||
| excluding communication that is conspicuously marked or otherwise | ||
| designated in writing by the copyright owner as "Not a Contribution." | ||
| "Contributor" shall mean Licensor and any individual or Legal Entity | ||
| on behalf of whom a Contribution has been received by Licensor and | ||
| subsequently incorporated within the Work. | ||
| 2. Grant of Copyright License. Subject to the terms and conditions of | ||
| this License, each Contributor hereby grants to You a perpetual, | ||
| worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||
| copyright license to reproduce, prepare Derivative Works of, | ||
| publicly display, publicly perform, sublicense, and distribute the | ||
| Work and such Derivative Works in Source or Object form. | ||
| 3. Grant of Patent License. Subject to the terms and conditions of | ||
| this License, each Contributor hereby grants to You a perpetual, | ||
| worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||
| (except as stated in this section) patent license to make, have made, | ||
| use, offer to sell, sell, import, and otherwise transfer the Work, | ||
| where such license applies only to those patent claims licensable | ||
| by such Contributor that are necessarily infringed by their | ||
| Contribution(s) alone or by combination of their Contribution(s) | ||
| with the Work to which such Contribution(s) was submitted. If You | ||
| institute patent litigation against any entity (including a | ||
| cross-claim or counterclaim in a lawsuit) alleging that the Work | ||
| or a Contribution incorporated within the Work constitutes direct | ||
| or contributory patent infringement, then any patent licenses | ||
| granted to You under this License for that Work shall terminate | ||
| as of the date such litigation is filed. | ||
| 4. Redistribution. You may reproduce and distribute copies of the | ||
| Work or Derivative Works thereof in any medium, with or without | ||
| modifications, and in Source or Object form, provided that You | ||
| meet the following conditions: | ||
| (a) You must give any other recipients of the Work or | ||
| Derivative Works a copy of this License; and | ||
| (b) You must cause any modified files to carry prominent notices | ||
| stating that You changed the files; and | ||
| (c) You must retain, in the Source form of any Derivative Works | ||
| that You distribute, all copyright, patent, trademark, and | ||
| attribution notices from the Source form of the Work, | ||
| excluding those notices that do not pertain to any part of | ||
| the Derivative Works; and | ||
| (d) If the Work includes a "NOTICE" text file as part of its | ||
| distribution, then any Derivative Works that You distribute must | ||
| include a readable copy of the attribution notices contained | ||
| within such NOTICE file, excluding those notices that do not | ||
| pertain to any part of the Derivative Works, in at least one | ||
| of the following places: within a NOTICE text file distributed | ||
| as part of the Derivative Works; within the Source form or | ||
| documentation, if provided along with the Derivative Works; or, | ||
| within a display generated by the Derivative Works, if and | ||
| wherever such third-party notices normally appear. The contents | ||
| of the NOTICE file are for informational purposes only and | ||
| do not modify the License. You may add Your own attribution | ||
| notices within Derivative Works that You distribute, alongside | ||
| or as an addendum to the NOTICE text from the Work, provided | ||
| that such additional attribution notices cannot be construed | ||
| as modifying the License. | ||
| You may add Your own copyright statement to Your modifications and | ||
| may provide additional or different license terms and conditions | ||
| for use, reproduction, or distribution of Your modifications, or | ||
| for any such Derivative Works as a whole, provided Your use, | ||
| reproduction, and distribution of the Work otherwise complies with | ||
| the conditions stated in this License. | ||
| 5. Submission of Contributions. Unless You explicitly state otherwise, | ||
| any Contribution intentionally submitted for inclusion in the Work | ||
| by You to the Licensor shall be under the terms and conditions of | ||
| this License, without any additional terms or conditions. | ||
| Notwithstanding the above, nothing herein shall supersede or modify | ||
| the terms of any separate license agreement you may have executed | ||
| with Licensor regarding such Contributions. | ||
| 6. Trademarks. This License does not grant permission to use the trade | ||
| names, trademarks, service marks, or product names of the Licensor, | ||
| except as required for reasonable and customary use in describing the | ||
| origin of the Work and reproducing the content of the NOTICE file. | ||
| 7. Disclaimer of Warranty. Unless required by applicable law or | ||
| agreed to in writing, Licensor provides the Work (and each | ||
| Contributor provides its Contributions) on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||
| implied, including, without limitation, any warranties or conditions | ||
| of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||
| PARTICULAR PURPOSE. You are solely responsible for determining the | ||
| appropriateness of using or redistributing the Work and assume any | ||
| risks associated with Your exercise of permissions under this License. | ||
| 8. Limitation of Liability. In no event and under no legal theory, | ||
| whether in tort (including negligence), contract, or otherwise, | ||
| unless required by applicable law (such as deliberate and grossly | ||
| negligent acts) or agreed to in writing, shall any Contributor be | ||
| liable to You for damages, including any direct, indirect, special, | ||
| incidental, or consequential damages of any character arising as a | ||
| result of this License or out of the use or inability to use the | ||
| Work (including but not limited to damages for loss of goodwill, | ||
| work stoppage, computer failure or malfunction, or any and all | ||
| other commercial damages or losses), even if such Contributor | ||
| has been advised of the possibility of such damages. | ||
| 9. Accepting Warranty or Additional Liability. While redistributing | ||
| the Work or Derivative Works thereof, You may choose to offer, | ||
| and charge a fee for, acceptance of support, warranty, indemnity, | ||
| or other liability obligations and/or rights consistent with this | ||
| License. However, in accepting such obligations, You may act only | ||
| on Your own behalf and on Your sole responsibility, not on behalf | ||
| of any other Contributor, and only if You agree to indemnify, | ||
| defend, and hold each Contributor harmless for any liability | ||
| incurred by, or claims asserted against, such Contributor by reason | ||
| of your accepting any such warranty or additional liability. | ||
| END OF TERMS AND CONDITIONS | ||
| --- | ||
| MIT License | ||
| Copyright (c) 2024-2025 Model Context Protocol a Series of LF Projects, LLC. | ||
| Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| of this software and associated documentation files (the "Software"), to deal | ||
| in the Software without restriction, including without limitation the rights | ||
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| copies of the Software, and to permit persons to whom the Software is | ||
| furnished to do so, subject to the following conditions: | ||
| The above copyright notice and this permission notice shall be included in all | ||
| copies or substantial portions of the Software. | ||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
| SOFTWARE. | ||
| --- | ||
| Creative Commons Attribution 4.0 International (CC-BY-4.0) | ||
| Documentation in this project (excluding specifications) is licensed under | ||
| CC-BY-4.0. See https://creativecommons.org/licenses/by/4.0/legalcode for | ||
| the full license text. |
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is not supported yet
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display
Sorry, the diff of this file is too big to display