🚀. Socket Launch Week Day 3:Socket Firewall Now Blocks Malicious VS Code and Open VSX Extensions.Learn more
Sign In

rmcp

Package Overview
Dependencies
Maintainers
1
Versions
46
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

rmcp - cargo Package Compare versions

Comparing version
1.3.0
to
1.4.0
+1
-1
.cargo_vcs_info.json
{
"git": {
"sha1": "ac749e3cedfc036a5b77960337669c7cf2338035"
"sha1": "4628720f89d27a01d4a126ea9f82f0775df9ed52"
},
"path_in_vcs": "crates/rmcp"
}

@@ -15,3 +15,3 @@ # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO

name = "rmcp"
version = "1.3.0"
version = "1.4.0"
build = "build.rs"

@@ -32,3 +32,28 @@ autolib = false

[package.metadata.docs.rs]
all-features = true
features = [
"auth",
"auth-client-credentials-jwt",
"base64",
"client",
"client-side-sse",
"elicitation",
"macros",
"reqwest",
"reqwest-native-tls",
"reqwest-tls-no-provider",
"schemars",
"server",
"server-side-http",
"tower",
"transport-async-rw",
"transport-child-process",
"transport-io",
"transport-streamable-http-client",
"transport-streamable-http-client-reqwest",
"transport-streamable-http-client-unix-socket",
"transport-streamable-http-server",
"transport-streamable-http-server-session",
"transport-worker",
"uuid",
]
rustdoc-args = [

@@ -137,2 +162,6 @@ "--cfg",

transport-worker = ["dep:tokio-stream"]
which-command = [
"transport-child-process",
"dep:which",
]

@@ -494,3 +523,3 @@ [lib]

[dependencies.rmcp-macros]
version = "1.3.0"
version = "1.4.0"
optional = true

@@ -552,2 +581,6 @@

[dependencies.which]
version = "7"
optional = true
[dev-dependencies.anyhow]

@@ -554,0 +587,0 @@ version = "1.0"

@@ -10,2 +10,24 @@ # Changelog

## [1.4.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v1.3.0...rmcp-v1.4.0) - 2026-04-09
### Added
- add Default and constructors to ServerSseMessage ([#794](https://github.com/modelcontextprotocol/rust-sdk/pull/794))
- add meta to elicitation results ([#792](https://github.com/modelcontextprotocol/rust-sdk/pull/792))
- *(macros)* auto-generate get_info and default router ([#785](https://github.com/modelcontextprotocol/rust-sdk/pull/785))
- *(transport)* add which_command for cross-platform executable resolution ([#774](https://github.com/modelcontextprotocol/rust-sdk/pull/774))
- *(auth)* add StoredCredentials::new() constructor ([#778](https://github.com/modelcontextprotocol/rust-sdk/pull/778))
### Fixed
- *(server)* remove initialized notification gate to support Streamable HTTP ([#788](https://github.com/modelcontextprotocol/rust-sdk/pull/788))
- default session keep_alive to 5 minutes ([#780](https://github.com/modelcontextprotocol/rust-sdk/pull/780))
- *(http)* add host check ([#764](https://github.com/modelcontextprotocol/rust-sdk/pull/764))
- exclude local feature from docs.rs build ([#782](https://github.com/modelcontextprotocol/rust-sdk/pull/782))
### Other
- update Rust toolchain to 1.92 ([#797](https://github.com/modelcontextprotocol/rust-sdk/pull/797))
- unify IntoCallToolResult Result impls ([#787](https://github.com/modelcontextprotocol/rust-sdk/pull/787))
## [1.3.0](https://github.com/modelcontextprotocol/rust-sdk/compare/rmcp-v1.2.0...rmcp-v1.3.0) - 2026-03-24

@@ -12,0 +34,0 @@

@@ -149,2 +149,3 @@ pub mod progress;

/// content: Some(user_input),
/// meta: None,
/// })

@@ -158,2 +159,3 @@ /// }

/// content: None,
/// meta: None,
/// })

@@ -176,2 +178,3 @@ /// }

content: None,
meta: None,
}))

@@ -178,0 +181,0 @@ }

@@ -13,5 +13,4 @@ //! Tools for MCP servers.

//! # use serde::{Serialize, Deserialize};
//! struct Server {
//! tool_router: ToolRouter<Self>,
//! }
//! struct Server;
//!
//! #[derive(Deserialize, schemars::JsonSchema, Default)]

@@ -26,3 +25,3 @@ //! struct AddParameter {

//! }
//! #[tool_router]
//! #[tool_router(server_handler)]
//! impl Server {

@@ -39,2 +38,7 @@ //! #[tool(name = "adder", description = "Modular add two integers")]

//!
//! The `server_handler` flag emits `#[tool_handler]` for you (tools-only servers). For custom
//! `#[tool_handler(...)]` options or multiple handler macros on one `impl ServerHandler`, write
//! `#[tool_router]` and `#[tool_handler] impl ServerHandler for ...` explicitly—see
//! [`tool_router`][crate::tool_router] and [`tool_handler`][crate::tool_handler].
//!
//! Using the macro-based code pattern above is suitable for small MCP servers with simple interfaces.

@@ -41,0 +45,0 @@ //! When the business logic become larger, it is recommended that each tool should reside

@@ -88,16 +88,25 @@ use std::{

impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
impl IntoCallToolResult for CallToolResult {
fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
match self {
Ok(value) => Ok(CallToolResult::success(value.into_contents())),
Err(error) => Ok(CallToolResult::error(error.into_contents())),
}
Ok(self)
}
}
impl<T: IntoCallToolResult> IntoCallToolResult for Result<T, crate::ErrorData> {
impl IntoCallToolResult for crate::ErrorData {
fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
Err(self)
}
}
impl<T: IntoCallToolResult, E: IntoCallToolResult> IntoCallToolResult for Result<T, E> {
fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
match self {
Ok(value) => value.into_call_tool_result(),
Err(error) => Err(error),
Err(error) => match error.into_call_tool_result() {
Ok(mut result) => {
result.is_error = Some(true);
Ok(result)
}
Err(e) => Err(e),
},
}

@@ -143,8 +152,2 @@ }

impl IntoCallToolResult for Result<CallToolResult, crate::ErrorData> {
fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
self
}
}
pub trait CallToolHandler<S, A> {

@@ -151,0 +154,0 @@ fn call(

@@ -6,6 +6,3 @@ use std::borrow::Cow;

use crate::{
handler::server::tool::IntoCallToolResult,
model::{CallToolResult, IntoContents},
};
use crate::{handler::server::tool::IntoCallToolResult, model::CallToolResult};

@@ -45,13 +42,1 @@ /// Json wrapper for structured output

}
// Implementation for Result<Json<T>, E>
impl<T: Serialize + JsonSchema + 'static, E: IntoContents> IntoCallToolResult
for Result<Json<T>, E>
{
fn into_call_tool_result(self) -> Result<CallToolResult, crate::ErrorData> {
match self {
Ok(value) => value.into_call_tool_result(),
Err(error) => Ok(CallToolResult::error(error.into_contents())),
}
}
}

@@ -56,2 +56,6 @@ use std::borrow::Cow;

#[deprecated(
since = "1.4.0",
note = "The server no longer gates on the initialized notification. This variant is never constructed and will be removed in a future major release."
)]
#[error("expect initialized notification, but received: {0:?}")]

@@ -247,45 +251,8 @@ ExpectedInitializedNotification(Option<ClientJsonRpcMessage>),

// Wait for initialized notification. The MCP spec permits logging/setLevel and ping
// before initialized; VS Code sends setLevel immediately after the initialize response.
let notification = loop {
let msg = expect_next_message(&mut transport, "initialize notification").await?;
match msg {
ClientJsonRpcMessage::Notification(n)
if matches!(
n.notification,
ClientNotification::InitializedNotification(_)
) =>
{
break n.notification;
}
ClientJsonRpcMessage::Request(req)
if matches!(
req.request,
ClientRequest::SetLevelRequest(_) | ClientRequest::PingRequest(_)
) =>
{
transport
.send(ServerJsonRpcMessage::response(
ServerResult::EmptyResult(EmptyResult {}),
req.id,
))
.await
.map_err(|error| {
ServerInitializeError::transport::<T>(error, "sending pre-init response")
})?;
}
other => {
return Err(ServerInitializeError::ExpectedInitializedNotification(
Some(other),
));
}
}
};
let context = NotificationContext {
meta: notification.get_meta().clone(),
extensions: notification.extensions().clone(),
peer: peer.clone(),
};
let _ = service.handle_notification(notification, context).await;
// Continue processing service
// Enter the main service loop immediately after sending InitializeResult.
// The initialized notification will be handled as a regular notification by serve_inner.
// This matches the TypeScript SDK behavior: no init gate, no waiting for initialized.
// Streamable HTTP has no ordering guarantee between POSTs, and the MCP spec uses
// SHOULD NOT (not MUST NOT) for pre-initialized messages, so any request arriving
// before initialized is processed normally.
Ok(serve_inner(service, transport, peer, peer_rx, ct))

@@ -292,0 +259,0 @@ }

@@ -86,2 +86,4 @@ //! # Transport

pub mod child_process;
#[cfg(feature = "which-command")]
pub use child_process::which_command;
#[cfg(feature = "transport-child-process")]

@@ -88,0 +90,0 @@ pub use child_process::{ConfigureCommandExt, TokioChildProcess};

@@ -236,2 +236,52 @@ use std::process::Stdio;

/// Resolve the absolute path to an executable using the system `PATH`,
/// then return a [`tokio::process::Command`] pointing at it.
///
/// This is especially useful on Windows where `.cmd` / `.exe` shim scripts
/// (e.g. `npx.cmd`) are not reliably found by [`tokio::process::Command`]
/// without a fully-qualified path.
///
/// # Example
/// ```rust,no_run
/// use rmcp::transport::{which_command, ConfigureCommandExt};
///
/// # fn example() -> std::io::Result<()> {
/// let cmd = which_command("npx")?
/// .configure(|cmd| {
/// cmd.arg("-y").arg("@modelcontextprotocol/server-everything");
/// });
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "which-command")]
pub fn which_command(
name: impl AsRef<std::ffi::OsStr>,
) -> std::io::Result<tokio::process::Command> {
let resolved = which::which(name.as_ref())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::NotFound, e))?;
Ok(tokio::process::Command::new(resolved))
}
#[cfg(feature = "which-command")]
#[cfg(test)]
mod tests_which {
#[test]
fn which_command_resolves_known_binary() {
// `ls` exists on every Unix system, `cmd` on Windows
#[cfg(unix)]
let result = super::which_command("ls");
#[cfg(windows)]
let result = super::which_command("cmd");
assert!(result.is_ok());
}
#[test]
fn which_command_fails_for_nonexistent() {
let result = super::which_command("this_binary_definitely_does_not_exist_12345");
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::NotFound);
}
}
#[cfg(unix)]

@@ -238,0 +288,0 @@ #[cfg(test)]

@@ -10,2 +10,3 @@ pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id";

/// injects it after initialization.
#[allow(dead_code)]
pub(crate) const RESERVED_HEADERS: &[&str] = &[

@@ -40,2 +41,3 @@ "accept",

/// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms.
#[cfg(feature = "client-side-sse")]
pub(crate) fn extract_scope_from_header(header: &str) -> Option<String> {

@@ -68,4 +70,6 @@ let header_lowercase = header.to_ascii_lowercase();

mod tests {
#[cfg(feature = "client-side-sse")]
use super::*;
#[cfg(feature = "client-side-sse")]
#[test]

@@ -80,2 +84,3 @@ fn extract_scope_quoted() {

#[cfg(feature = "client-side-sse")]
#[test]

@@ -90,2 +95,3 @@ fn extract_scope_unquoted() {

#[cfg(feature = "client-side-sse")]
#[test]

@@ -97,2 +103,3 @@ fn extract_scope_missing() {

#[cfg(feature = "client-side-sse")]
#[test]

@@ -99,0 +106,0 @@ fn extract_scope_empty_header() {

@@ -60,3 +60,3 @@ #![allow(dead_code)]

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
#[non_exhaustive]

@@ -75,2 +75,33 @@ pub struct ServerSseMessage {

impl ServerSseMessage {
/// Create a message carrying a JSON-RPC response/notification with an event ID.
pub fn new(event_id: impl Into<String>, message: ServerJsonRpcMessage) -> Self {
Self {
event_id: Some(event_id.into()),
message: Some(Arc::new(message)),
retry: None,
}
}
/// Wrap a JSON-RPC message without an event ID or retry hint.
pub fn from_message(message: ServerJsonRpcMessage) -> Self {
Self {
event_id: None,
message: Some(Arc::new(message)),
retry: None,
}
}
/// Create a priming event that tells the client to reconnect after `retry`
/// if the connection drops.
/// See [SEP-1699](https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1699).
pub fn priming(event_id: impl Into<String>, retry: Duration) -> Self {
Self {
event_id: Some(event_id.into()),
message: None,
retry: Some(retry),
}
}
}
pub(crate) fn sse_stream_response(

@@ -174,1 +205,47 @@ stream: impl futures::Stream<Item = ServerSseMessage> + Send + Sync + 'static,

}
#[cfg(test)]
mod tests {
use super::*;
use crate::model::{EmptyResult, JsonRpcResponse, JsonRpcVersion2_0, RequestId, ServerResult};
fn dummy_message() -> ServerJsonRpcMessage {
ServerJsonRpcMessage::Response(JsonRpcResponse {
jsonrpc: JsonRpcVersion2_0,
id: RequestId::Number(1),
result: ServerResult::EmptyResult(EmptyResult {}),
})
}
#[test]
fn default_has_all_none() {
let msg = ServerSseMessage::default();
assert!(msg.event_id.is_none());
assert!(msg.message.is_none());
assert!(msg.retry.is_none());
}
#[test]
fn new_sets_event_id_and_message() {
let msg = ServerSseMessage::new("42", dummy_message());
assert_eq!(msg.event_id.as_deref(), Some("42"));
assert!(msg.message.is_some());
assert!(msg.retry.is_none());
}
#[test]
fn from_message_has_no_event_id() {
let msg = ServerSseMessage::from_message(dummy_message());
assert!(msg.event_id.is_none());
assert!(msg.message.is_some());
assert!(msg.retry.is_none());
}
#[test]
fn priming_sets_event_id_and_retry() {
let msg = ServerSseMessage::priming("0", Duration::from_secs(5));
assert_eq!(msg.event_id.as_deref(), Some("0"));
assert!(msg.message.is_none());
assert_eq!(msg.retry, Some(Duration::from_secs(5)));
}
}
use std::{
collections::{HashMap, HashSet, VecDeque},
num::ParseIntError,
sync::Arc,
time::Duration,

@@ -225,7 +224,3 @@ };

let event_id = self.next_event_id();
let message = ServerSseMessage {
event_id: Some(event_id.to_string()),
message: Some(Arc::new(message)),
retry: None,
};
let message = ServerSseMessage::new(event_id.to_string(), message);
self.cache_and_send(message).await;

@@ -236,7 +231,3 @@ }

let event_id = self.next_event_id();
let message = ServerSseMessage {
event_id: Some(event_id.to_string()),
message: None,
retry: Some(retry),
};
let message = ServerSseMessage::priming(event_id.to_string(), retry);
self.cache_and_send(message).await;

@@ -1074,3 +1065,13 @@ }

pub channel_capacity: usize,
/// if set, the session will be closed after this duration of inactivity.
/// The session will be closed after this duration of inactivity.
///
/// This serves as a safety net for cleaning up sessions whose HTTP
/// connections have silently dropped (e.g., due to an HTTP/2
/// `RST_STREAM`). Without a timeout, such sessions become zombies:
/// the session worker keeps running indefinitely because the session
/// handle's sender is still held in the session manager, preventing
/// the worker's event channel from closing.
///
/// Defaults to 5 minutes. Set to `None` to disable (not recommended
/// for long-running servers behind proxies).
pub keep_alive: Option<Duration>,

@@ -1081,2 +1082,3 @@ }

pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
}

@@ -1088,3 +1090,3 @@

channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
keep_alive: None,
keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
}

@@ -1091,0 +1093,0 @@ }

@@ -5,3 +5,3 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};

use futures::{StreamExt, future::BoxFuture};
use http::{Method, Request, Response, header::ALLOW};
use http::{HeaderMap, Method, Request, Response, header::ALLOW};
use http_body::Body;

@@ -33,4 +33,4 @@ use http_body_util::{BodyExt, Full, combinators::BoxBody};

#[non_exhaustive]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct StreamableHttpServerConfig {

@@ -54,2 +54,12 @@ /// The ping message duration for SSE connections.

pub cancellation_token: CancellationToken,
/// Allowed hostnames or `host:port` authorities for inbound `Host` validation.
///
/// By default, Streamable HTTP servers only accept loopback hosts to
/// prevent DNS rebinding attacks against locally running servers. Public
/// deployments should override this list with their own hostnames.
/// examples:
/// allowed_hosts = ["localhost", "127.0.0.1", "0.0.0.0"]
/// or with ports:
/// allowed_hosts = ["example.com", "example.com:8080"]
pub allowed_hosts: Vec<String>,
}

@@ -65,2 +75,3 @@

cancellation_token: CancellationToken::new(),
allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
}

@@ -71,2 +82,14 @@ }

impl StreamableHttpServerConfig {
pub fn with_allowed_hosts(
mut self,
allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
self
}
/// Disable allowed hosts. This will allow requests with any `Host` header, which is NOT recommended for public deployments.
pub fn disable_allowed_hosts(mut self) -> Self {
self.allowed_hosts.clear();
self
}
pub fn with_sse_keep_alive(mut self, duration: Option<Duration>) -> Self {

@@ -138,2 +161,93 @@ self.sse_keep_alive = duration;

fn forbidden_response(message: impl Into<String>) -> BoxResponse {
Response::builder()
.status(http::StatusCode::FORBIDDEN)
.body(Full::new(Bytes::from(message.into())).boxed())
.expect("valid response")
}
fn normalize_host(host: &str) -> String {
host.trim_matches('[')
.trim_matches(']')
.to_ascii_lowercase()
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct NormalizedAuthority {
host: String,
port: Option<u16>,
}
fn normalize_authority(host: &str, port: Option<u16>) -> NormalizedAuthority {
NormalizedAuthority {
host: normalize_host(host),
port,
}
}
fn parse_allowed_authority(allowed: &str) -> Option<NormalizedAuthority> {
let allowed = allowed.trim();
if allowed.is_empty() {
return None;
}
if let Ok(authority) = http::uri::Authority::try_from(allowed) {
return Some(normalize_authority(authority.host(), authority.port_u16()));
}
Some(normalize_authority(allowed, None))
}
fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool {
if allowed_hosts.is_empty() {
// If the allowed hosts list is empty, allow all hosts (not recommended).
return true;
}
allowed_hosts
.iter()
.filter_map(|allowed| parse_allowed_authority(allowed))
.any(|allowed| {
allowed.host == host.host
&& match allowed.port {
Some(port) => host.port == Some(port),
None => true,
}
})
}
fn bad_request_response(message: &str) -> BoxResponse {
let body = Full::from(message.to_string()).boxed();
http::Response::builder()
.status(http::StatusCode::BAD_REQUEST)
.header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
.body(body)
.expect("failed to build bad request response")
}
fn parse_host_header(headers: &HeaderMap) -> Result<NormalizedAuthority, BoxResponse> {
let Some(host) = headers.get(http::header::HOST) else {
return Err(bad_request_response("Bad Request: missing Host header"));
};
let host = host
.to_str()
.map_err(|_| bad_request_response("Bad Request: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host)
.map_err(|_| bad_request_response("Bad Request: Invalid Host header"))?;
Ok(normalize_authority(authority.host(), authority.port_u16()))
}
fn validate_dns_rebinding_headers(
headers: &HeaderMap,
config: &StreamableHttpServerConfig,
) -> Result<(), BoxResponse> {
let host = parse_host_header(headers)?;
if !host_is_allowed(&host, &config.allowed_hosts) {
return Err(forbidden_response("Forbidden: Host header is not allowed"));
}
Ok(())
}
/// # Streamable HTTP server

@@ -288,2 +402,5 @@ ///

{
if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
return response;
}
let method = request.method().clone();

@@ -392,7 +509,3 @@ let allowed_methods = match self.config.stateful_mode {

let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })

@@ -503,7 +616,3 @@ .chain(stream)

let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })

@@ -582,16 +691,7 @@ .chain(stream)

.map_err(internal_error_response("create stream"))?;
let stream = futures::stream::once(async move {
ServerSseMessage {
event_id: None,
message: Some(Arc::new(response)),
retry: None,
}
});
let stream =
futures::stream::once(async move { ServerSseMessage::from_message(response) });
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage {
event_id: Some("0".into()),
message: None,
retry: Some(retry),
};
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })

@@ -670,7 +770,3 @@ .chain(stream)

tracing::trace!(?message);
ServerSseMessage {
event_id: None,
message: Some(Arc::new(message)),
retry: None,
}
ServerSseMessage::from_message(message)
});

@@ -677,0 +773,0 @@ Ok(sse_stream_response(

@@ -764,2 +764,3 @@ #![cfg(not(feature = "local"))]

.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.body(Full::new(Bytes::from(init_body.to_string())))

@@ -789,2 +790,3 @@ .unwrap();

.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("mcp-session-id", &session_id)

@@ -807,2 +809,3 @@ .header("mcp-protocol-version", "2025-03-26")

.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("mcp-session-id", &session_id)

@@ -829,2 +832,3 @@ .header("mcp-protocol-version", "2025-03-26")

.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("mcp-session-id", &session_id)

@@ -851,2 +855,3 @@ .header("mcp-protocol-version", "9999-01-01")

.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("mcp-session-id", &session_id)

@@ -878,1 +883,154 @@ .body(Full::new(Bytes::from(no_version_body.to_string())))

}
/// Integration test: Verify server validates only the Host header for DNS rebinding protection
#[tokio::test]
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
async fn test_server_validates_host_header_for_dns_rebinding_protection() {
use std::sync::Arc;
use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;
#[derive(Clone)]
struct TestHandler;
impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}
let service = StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
);
let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
});
let allowed_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("Origin", "http://localhost:8080")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
let response = service.handle(allowed_request).await;
assert_eq!(response.status(), http::StatusCode::OK);
let bad_host_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "attacker.example")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
let response = service.handle(bad_host_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
let ignored_origin_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("Origin", "http://attacker.example")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
let response = service.handle(ignored_origin_request).await;
assert_eq!(response.status(), http::StatusCode::OK);
}
/// Integration test: Verify server can enforce an allowed Host port when configured
#[tokio::test]
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
async fn test_server_validates_host_header_port_for_dns_rebinding_protection() {
use std::sync::Arc;
use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;
#[derive(Clone)]
struct TestHandler;
impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}
let service = StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default().with_allowed_hosts(["localhost:8080"]),
);
let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
});
let allowed_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
let response = service.handle(allowed_request).await;
assert_eq!(response.status(), http::StatusCode::OK);
let wrong_port_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:3000")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();
let response = service.handle(wrong_port_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}

@@ -1,2 +0,2 @@

#![cfg(not(feature = "local"))]
#![cfg(all(feature = "client", feature = "server", not(feature = "local")))]
// cargo test --test test_inflight_response_drain --features "client server"

@@ -3,0 +3,0 @@

@@ -421,2 +421,10 @@ {

"properties": {
"_meta": {
"description": "Optional protocol-level metadata for this result.",
"type": [
"object",
"null"
],
"additionalProperties": true
},
"action": {

@@ -423,0 +431,0 @@ "description": "The user's decision on how to handle the elicitation request",

@@ -421,2 +421,10 @@ {

"properties": {
"_meta": {
"description": "Optional protocol-level metadata for this result.",
"type": [
"object",
"null"
],
"additionalProperties": true
},
"action": {

@@ -423,0 +431,0 @@ "description": "The user's decision on how to handle the elicitation request",

@@ -33,3 +33,3 @@ //cargo test --test test_prompt_handler --features "client server"

#[prompt_handler]
#[prompt_handler(router = self.prompt_router)]
impl ServerHandler for TestPromptServer {}

@@ -84,3 +84,3 @@

#[prompt_handler]
#[prompt_handler(router = self.prompt_router)]
impl<T: Send + Sync + 'static> ServerHandler for GenericPromptServer<T> {}

@@ -153,3 +153,3 @@

#[prompt_handler]
#[prompt_handler(router = self.prompt_router)]
impl ServerHandler for NestedServer {}

@@ -156,0 +156,0 @@

@@ -9,3 +9,2 @@ // cargo test --features "client" --package rmcp -- server_init

model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult},
service::ServerInitializeError,
transport::{IntoTransport, Transport},

@@ -58,3 +57,3 @@ };

// Server responds with EmptyResult to setLevel received before initialized.
// Server handles setLevel sent before initialized notification (processed by serve_inner).
#[tokio::test]

@@ -69,3 +68,10 @@ async fn server_init_set_level_response_is_empty_result() {

let response = client.receive().await.unwrap();
// The handler may send logging notifications before the response;
// skip notifications to find the EmptyResult response.
let response = loop {
let msg = client.receive().await.unwrap();
if matches!(msg, ServerJsonRpcMessage::Response(_)) {
break msg;
}
};
assert!(

@@ -91,3 +97,9 @@ matches!(

client.send(set_level_request(2)).await.unwrap();
let _response = client.receive().await.unwrap();
// Skip notifications until we get the response
loop {
let msg = client.receive().await.unwrap();
if matches!(msg, ServerJsonRpcMessage::Response(_)) {
break;
}
}
client.send(initialized_notification()).await.unwrap();

@@ -186,5 +198,5 @@

// Server returns ExpectedInitializedNotification for any other message before initialized.
// Server buffers tools/list sent before initialized and processes it after initialization.
#[tokio::test]
async fn server_init_rejects_unexpected_message_before_initialized() {
async fn server_init_buffers_request_before_initialized() {
let (server_transport, client_transport) = tokio::io::duplex(4096);

@@ -196,12 +208,55 @@ let server_handle =

do_initialize(&mut client).await;
// Send tools/list before initialized notification
client.send(list_tools_request(2)).await.unwrap();
// Now send initialized notification
client.send(initialized_notification()).await.unwrap();
// The buffered tools/list should be processed — expect a response
let response = client.receive().await.unwrap();
assert!(
matches!(response, ServerJsonRpcMessage::Response(_)),
"expected response for buffered tools/list, got: {response:?}"
);
let result = server_handle.await.unwrap();
assert!(
matches!(
result,
Err(ServerInitializeError::ExpectedInitializedNotification(_))
),
"expected ExpectedInitializedNotification error"
result.is_ok(),
"server should initialize successfully when buffering pre-init messages"
);
result.unwrap().cancel().await.unwrap();
}
// Server buffers multiple requests before initialized and processes them in order.
#[tokio::test]
async fn server_init_buffers_multiple_requests_before_initialized() {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server_handle =
tokio::spawn(async move { TestServer::new().serve(server_transport).await });
let mut client = IntoTransport::<rmcp::RoleClient, _, _>::into_transport(client_transport);
do_initialize(&mut client).await;
// Send two requests before initialized
client.send(list_tools_request(2)).await.unwrap();
client.send(ping_request(3)).await.unwrap();
// Now send initialized notification
client.send(initialized_notification()).await.unwrap();
// Both buffered messages should get responses
let response1 = client.receive().await.unwrap();
let response2 = client.receive().await.unwrap();
assert!(
matches!(response1, ServerJsonRpcMessage::Response(_)),
"expected response for first buffered message, got: {response1:?}"
);
assert!(
matches!(response2, ServerJsonRpcMessage::Response(_)),
"expected response for second buffered message, got: {response2:?}"
);
let result = server_handle.await.unwrap();
assert!(
result.is_ok(),
"server should initialize successfully with multiple buffered messages"
);
result.unwrap().cancel().await.unwrap();
}

@@ -22,3 +22,3 @@ #[cfg(test)]

}
#[tool_handler]
#[tool_handler(router = self.tool_router)]
impl ServerHandler for AnnotatedServer {}

@@ -25,0 +25,0 @@

@@ -13,3 +13,3 @@ #![cfg(not(feature = "local"))]

handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::{CallToolRequestParams, ClientInfo},
model::{CallToolRequestParams, ClientInfo, ServerCapabilities, ServerInfo},
tool, tool_handler, tool_router,

@@ -369,1 +369,209 @@ };

}
// --- Tests for field-free minimal server pattern (issue #711) ---
/// Minimal server: no tool_router field, no new(), no get_info().
#[derive(Debug, Clone)]
pub struct MinimalServer;
#[tool_router]
impl MinimalServer {
#[tool(description = "Say hello")]
fn hello(&self) -> String {
"hello".to_string()
}
}
#[tool_handler]
impl ServerHandler for MinimalServer {}
#[test]
fn test_minimal_server_get_info_auto_generated() {
let server = MinimalServer;
let info = server.get_info();
assert!(
info.capabilities.tools.is_some(),
"tools capability should be enabled"
);
assert!(
info.capabilities.prompts.is_none(),
"prompts should not be auto-enabled"
);
assert!(
info.capabilities.tasks.is_none(),
"tasks should not be auto-enabled"
);
assert!(
!info.server_info.name.is_empty(),
"server name should not be empty"
);
assert!(
!info.server_info.version.is_empty(),
"server version should not be empty"
);
assert!(
info.instructions.is_none(),
"instructions should be None by default"
);
}
#[tokio::test]
async fn test_minimal_server_tool_call() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server_handle = tokio::spawn(async move {
MinimalServer
.serve(server_transport)
.await?
.waiting()
.await?;
anyhow::Ok(())
});
let client = DummyClientHandler::default()
.serve(client_transport)
.await?;
let result = client
.call_tool(CallToolRequestParams::new("hello"))
.await?;
let text = result
.content
.first()
.and_then(|c| c.raw.as_text())
.map(|t| t.text.as_str())
.expect("Expected text content");
assert_eq!(text, "hello");
client.cancel().await?;
server_handle.await??;
Ok(())
}
/// Same minimal pattern as [`MinimalServer`], but `#[tool_handler]` is omitted using
/// `#[tool_router(server_handler)]` (emits `#[tool_handler]` for a second macro pass).
#[derive(Debug, Clone)]
pub struct ElidedToolHandlerServer;
#[tool_router(server_handler)]
impl ElidedToolHandlerServer {
#[tool(description = "Say hi")]
fn hi(&self) -> String {
"hi".to_string()
}
}
#[test]
fn test_tool_router_server_handler_flag_matches_minimal_server_get_info() {
let server = ElidedToolHandlerServer;
let info = server.get_info();
assert!(info.capabilities.tools.is_some());
assert!(
info.capabilities.prompts.is_none(),
"prompts should not be auto-enabled"
);
}
#[tokio::test]
async fn test_tool_router_server_handler_flag_end_to_end_tool_call() -> anyhow::Result<()> {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server_handle = tokio::spawn(async move {
ElidedToolHandlerServer
.serve(server_transport)
.await?
.waiting()
.await?;
anyhow::Ok(())
});
let client = DummyClientHandler::default()
.serve(client_transport)
.await?;
let result = client.call_tool(CallToolRequestParams::new("hi")).await?;
let text = result
.content
.first()
.and_then(|c| c.raw.as_text())
.map(|t| t.text.as_str())
.expect("Expected text content");
assert_eq!(text, "hi");
client.cancel().await?;
server_handle.await??;
Ok(())
}
/// Server with custom name/version/instructions via tool_handler attributes.
#[derive(Debug, Clone)]
pub struct CustomInfoServer;
#[tool_router]
impl CustomInfoServer {
#[tool(description = "Ping")]
fn ping(&self) -> String {
"pong".to_string()
}
}
#[tool_handler(
name = "my-custom-server",
version = "2.0.0",
instructions = "A custom server"
)]
impl ServerHandler for CustomInfoServer {}
#[test]
fn test_custom_info_server() {
let server = CustomInfoServer;
let info = server.get_info();
assert_eq!(info.server_info.name, "my-custom-server");
assert_eq!(info.server_info.version, "2.0.0");
assert_eq!(info.instructions.as_deref(), Some("A custom server"));
assert!(info.capabilities.tools.is_some());
}
/// Server that provides its own get_info() — macro should not override it.
#[derive(Debug, Clone)]
pub struct ManualInfoServer;
#[tool_router]
impl ManualInfoServer {
#[tool(description = "Noop")]
fn noop(&self) {}
}
#[tool_handler]
impl ServerHandler for ManualInfoServer {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(
ServerCapabilities::builder()
.enable_tools()
.enable_resources()
.build(),
)
.with_server_info(rmcp::model::Implementation::new("manual", "9.9.9"))
}
}
#[test]
fn test_manual_get_info_not_overridden() {
let server = ManualInfoServer;
let info = server.get_info();
assert_eq!(info.server_info.name, "manual");
assert_eq!(info.server_info.version, "9.9.9");
assert!(info.capabilities.tools.is_some());
assert!(
info.capabilities.resources.is_some(),
"manual resources should be preserved"
);
}

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display