diff --git a/src/adcp/server/translate.py b/src/adcp/server/translate.py new file mode 100644 index 000000000..c4091c3b0 --- /dev/null +++ b/src/adcp/server/translate.py @@ -0,0 +1,347 @@ +"""Error translation and request normalization for proxy and custom-transport servers. + +Standard servers using ``serve()`` or ``ADCPAgentExecutor`` do not need these +helpers — the framework handles error translation and request normalization +internally. + +These are for **proxy servers** that catch ``ADCPError`` from a downstream +agent call and need to format it for their own transport, or custom +multi-transport servers that bypass the standard framework. + +Not exported from ``adcp.server`` — import directly:: + + from adcp.server.translate import translate_error, normalize_request + + # In a proxy catching errors from a downstream agent: + try: + result = await downstream_client.create_media_buy(params) + except ADCPError as e: + raise translate_error(e, protocol="a2a") + # Raises: ServerError(InternalError(message="...", data={...})) + + # Normalize deprecated field names from older callers: + params = normalize_request(params, task_name="create_media_buy") +""" + +from __future__ import annotations + +from typing import Any, Literal +from urllib.parse import urlparse + +from a2a.types import InternalError, InvalidParamsError +from a2a.utils.errors import ServerError +from mcp.server.fastmcp.exceptions import ToolError + +from adcp.exceptions import ( + ADCPAuthenticationError, + ADCPConnectionError, + ADCPError, + ADCPTaskError, + ADCPTimeoutError, +) +from adcp.server.helpers import STANDARD_ERROR_CODES +from adcp.types import Error +from adcp.types.core import Protocol + +# ============================================================================ +# Error Translation +# ============================================================================ + +# Maps Python exception types to ADCP standard error codes. +_EXCEPTION_CODE_MAP: dict[type[ADCPError], str] = { + ADCPAuthenticationError: "AUTH_REQUIRED", + ADCPTimeoutError: "SERVICE_UNAVAILABLE", + ADCPConnectionError: "SERVICE_UNAVAILABLE", +} + +# A2A JSON-RPC error codes for correctable vs non-correctable errors. +_A2A_CORRECTABLE_CODE = -32602 # InvalidParamsError +_A2A_INTERNAL_CODE = -32603 # InternalError + + +def _error_code_for_exception(exc: ADCPError) -> str: + """Derive a structured ADCP error code from an exception type.""" + # ADCPTaskError carries the original error codes from the response + if isinstance(exc, ADCPTaskError) and exc.error_codes: + return str(exc.error_codes[0]) + return _EXCEPTION_CODE_MAP.get(type(exc), "INTERNAL_ERROR") + + +def _recovery_for_code(code: str) -> str: + """Look up recovery classification for an error code.""" + std = STANDARD_ERROR_CODES.get(code) + if std: + return std["recovery"] + return "terminal" + + +def _build_error_data( + code: str, + message: str, + *, + recovery: str | None = None, + suggestion: str | None = None, + details: dict[str, Any] | None = None, + errors: list[Any] | None = None, +) -> dict[str, Any]: + """Build the structured data payload for protocol error responses.""" + data: dict[str, Any] = { + "error_code": code, + "recovery": recovery or _recovery_for_code(code), + } + if suggestion: + data["suggestion"] = suggestion + if details: + data["details"] = details + if errors: + data["errors"] = [ + e.model_dump(exclude_none=True) if hasattr(e, "model_dump") else e + for e in errors + ] + return data + + +def translate_error( + exc: ADCPError | Error, + protocol: Literal["mcp", "a2a"] | Protocol, +) -> ToolError | ServerError: + """Translate an AdCP error to a protocol SDK error type. + + Returns an error that can be directly raised in a protocol handler:: + + try: + result = await handler.create_media_buy(params) + except ADCPError as e: + raise translate_error(e, protocol="mcp") + + For MCP, returns ``ToolError`` (from ``mcp.server.fastmcp``). + For A2A, returns ``ServerError`` wrapping ``InvalidParamsError`` + (for correctable errors) or ``InternalError`` (for transient/terminal). + + The ``data`` field on A2A errors preserves recovery classification, + error_code, suggestion, and details so buyer agents can make + retry/fix/abandon decisions. + + Args: + exc: An ADCPError exception or an Error Pydantic model. + protocol: Target protocol - ``"mcp"`` or ``"a2a"``. + + Returns: + ``ToolError`` for MCP, ``ServerError`` for A2A. Raise the result. + + Raises: + ValueError: If protocol is not ``"mcp"`` or ``"a2a"``. + + Warning: + Error details are passed through to the caller. Do not include + internal state (stack traces, SQL queries, internal URLs) in + Error objects passed to this function. + """ + proto = protocol.value if isinstance(protocol, Protocol) else str(protocol) + proto = proto.lower() + if proto not in ("mcp", "a2a"): + raise ValueError(f"protocol must be 'mcp' or 'a2a', got {protocol!r}") + + # Extract structured fields from the input + if isinstance(exc, Error): + code = exc.code + message = exc.message + suggestion = exc.suggestion + details = exc.details + recovery = _recovery_for_code(code) + errors = None + elif isinstance(exc, ADCPError): + code = _error_code_for_exception(exc) + message = exc.message + suggestion = exc.suggestion + recovery = _recovery_for_code(code) + details = None + errors = getattr(exc, "errors", None) + else: + raise TypeError(f"Expected ADCPError or Error, got {type(exc).__name__}") + + if proto == "mcp": + return _to_mcp(code, message, suggestion=suggestion) + return _to_a2a( + code, message, + recovery=recovery, suggestion=suggestion, + details=details, errors=errors, + ) + + +def _to_mcp( + code: str, + message: str, + *, + suggestion: str | None = None, +) -> ToolError: + """Format error as a ToolError for MCP servers.""" + text = f"{code}: {message}" + if suggestion: + text += f"\nSuggestion: {suggestion}" + return ToolError(text) + + +def _to_a2a( + code: str, + message: str, + *, + recovery: str | None = None, + suggestion: str | None = None, + details: dict[str, Any] | None = None, + errors: list[Any] | None = None, +) -> ServerError: + """Format error as a ServerError for A2A servers.""" + data = _build_error_data( + code, message, + recovery=recovery, suggestion=suggestion, + details=details, errors=errors, + ) + + # Use InvalidParamsError for correctable errors (client can fix), + # InternalError for transient/terminal (server-side or unfixable). + effective_recovery = recovery or _recovery_for_code(code) + if effective_recovery == "correctable": + return ServerError(InvalidParamsError(message=message, data=data)) + return ServerError(InternalError(message=message, data=data)) + + +# ============================================================================ +# Request Normalization +# ============================================================================ + +# Global field renames (apply to all task types). +_GLOBAL_RENAMES: dict[str, str] = { + "promoted_offerings": "catalogs", +} + +# Tool-scoped field renames (apply only to specific task types). +_TOOL_RENAMES: dict[str, dict[str, str]] = { + "create_media_buy": { + "campaign_ref": "buyer_campaign_ref", + }, +} + + +def _normalize_account(params: dict[str, Any]) -> None: + """Reshape account_id string to account object. + + Old format: ``account_id: "123"`` + New format: ``account: {account_id: "123"}`` + """ + if "account_id" not in params: + return + if "account" not in params: + params["account"] = {"account_id": params.pop("account_id")} + else: + del params["account_id"] + + +def _normalize_brand_manifest(params: dict[str, Any]) -> None: + """Reshape brand_manifest URL string to brand object. + + Old format: ``brand_manifest: "https://example.com/brand.json"`` + New format: ``brand: {domain: "example.com"}`` + """ + if "brand_manifest" not in params: + return + if "brand" not in params: + manifest = params.pop("brand_manifest") + if isinstance(manifest, str): + parsed = urlparse(manifest) + params["brand"] = {"domain": parsed.hostname or manifest} + else: + # Already an object, just rename the key + params["brand"] = manifest + else: + del params["brand_manifest"] + + +def _normalize_packages(params: dict[str, Any]) -> None: + """Normalize package-level fields: scalar-to-array wraps. + + - ``optimization_goal`` (str) → ``optimization_goals`` (list[str]) + - ``catalog`` (str) → ``catalogs`` (list[str]) + """ + packages = params.get("packages") + if not isinstance(packages, list): + return + for pkg in packages: + if not isinstance(pkg, dict): + continue + # optimization_goal → optimization_goals + if "optimization_goal" in pkg and "optimization_goals" not in pkg: + pkg["optimization_goals"] = [pkg.pop("optimization_goal")] + elif "optimization_goal" in pkg: + del pkg["optimization_goal"] + # catalog → catalogs + if "catalog" in pkg and "catalogs" not in pkg: + pkg["catalogs"] = [pkg.pop("catalog")] + elif "catalog" in pkg: + del pkg["catalog"] + + +def normalize_request( + params: dict[str, Any], + task_name: str | None = None, +) -> dict[str, Any]: + """Normalize deprecated field names and structures in request params. + + Applies known transforms so servers can accept both old and new field + formats without duplicating normalization logic in every handler. + + Transforms applied: + + - ``account_id: "123"`` → ``account: {account_id: "123"}`` (structural) + - ``brand_manifest: "https://..."`` → ``brand: {domain: "..."}`` (URL parse) + - ``promoted_offerings`` → ``catalogs`` (rename) + - ``campaign_ref`` → ``buyer_campaign_ref`` (create_media_buy only) + - Package-level ``optimization_goal`` → ``optimization_goals`` (scalar→array) + - Package-level ``catalog`` → ``catalogs`` (scalar→array) + + If both the deprecated and current field name are present, the current + name takes precedence and the deprecated name is removed. + + Args: + params: Request parameters dict. + task_name: ADCP task/tool name (e.g. ``"create_media_buy"``). + Enables tool-scoped renames when provided. + + Returns: + New dict with deprecated field names replaced by current names. + Original dict is not mutated (top-level copy; packages list is + copied if package-level transforms apply). + """ + result = dict(params) + + # Structural transforms + _normalize_account(result) + _normalize_brand_manifest(result) + + # Package-level transforms (deep copy the packages list) + if "packages" in result and isinstance(result["packages"], list): + result["packages"] = [ + dict(pkg) if isinstance(pkg, dict) else pkg + for pkg in result["packages"] + ] + _normalize_packages(result) + + # Global renames + for old_name, new_name in _GLOBAL_RENAMES.items(): + if old_name in result: + if new_name not in result: + result[new_name] = result.pop(old_name) + else: + del result[old_name] + + # Tool-scoped renames + if task_name: + tool_renames = _TOOL_RENAMES.get(task_name, {}) + for old_name, new_name in tool_renames.items(): + if old_name in result: + if new_name not in result: + result[new_name] = result.pop(old_name) + else: + del result[old_name] + + return result diff --git a/tests/test_translate.py b/tests/test_translate.py new file mode 100644 index 000000000..c9c133ede --- /dev/null +++ b/tests/test_translate.py @@ -0,0 +1,373 @@ +"""Tests for error translation and request normalization helpers.""" + +from __future__ import annotations + +import pytest +from a2a.types import InternalError, InvalidParamsError +from a2a.utils.errors import ServerError +from mcp.server.fastmcp.exceptions import ToolError + +from adcp.exceptions import ( + ADCPAuthenticationError, + ADCPConnectionError, + ADCPError, + ADCPTaskError, + ADCPTimeoutError, +) +from adcp.server.translate import normalize_request, translate_error +from adcp.types import Error +from adcp.types.core import Protocol + +# ============================================================================ +# translate_error → MCP +# ============================================================================ + + +class TestTranslateErrorToMCP: + """Test translate_error with protocol='mcp'.""" + + def test_returns_tool_error(self): + """MCP translation returns a ToolError instance.""" + exc = ADCPError("something went wrong") + result = translate_error(exc, protocol="mcp") + assert isinstance(result, ToolError) + + def test_includes_code_and_message(self): + """ToolError text contains the error code and message.""" + exc = ADCPError("something went wrong") + result = translate_error(exc, protocol="mcp") + assert "INTERNAL_ERROR" in str(result) + assert "something went wrong" in str(result) + + def test_error_model_uses_its_code(self): + """Error Pydantic model produces ToolError with its own code.""" + err = Error(code="VALIDATION_ERROR", message="Missing field 'packages'") + result = translate_error(err, protocol="mcp") + assert "VALIDATION_ERROR" in str(result) + assert "packages" in str(result) + + def test_preserves_suggestion(self): + """Suggestion from ADCPError appears in ToolError text.""" + exc = ADCPError("bad request", suggestion="Set the budget field") + result = translate_error(exc, protocol="mcp") + assert "Set the budget field" in str(result) + + def test_auth_error_maps_to_auth_required(self): + """ADCPAuthenticationError maps to AUTH_REQUIRED code.""" + exc = ADCPAuthenticationError("Invalid token", agent_id="test-agent") + result = translate_error(exc, protocol="mcp") + assert "AUTH_REQUIRED" in str(result) + + def test_timeout_error_maps_to_service_unavailable(self): + """ADCPTimeoutError maps to SERVICE_UNAVAILABLE code.""" + exc = ADCPTimeoutError("Request timed out", timeout=30.0) + result = translate_error(exc, protocol="mcp") + assert "SERVICE_UNAVAILABLE" in str(result) + + def test_connection_error_maps_to_service_unavailable(self): + """ADCPConnectionError maps to SERVICE_UNAVAILABLE code.""" + exc = ADCPConnectionError("Cannot reach upstream") + result = translate_error(exc, protocol="mcp") + assert "SERVICE_UNAVAILABLE" in str(result) + + def test_task_error_uses_original_code(self): + """ADCPTaskError preserves the original error code from the response.""" + err = Error(code="BUDGET_TOO_LOW", message="Budget below minimum") + exc = ADCPTaskError("create_media_buy", [err]) + result = translate_error(exc, protocol="mcp") + assert "BUDGET_TOO_LOW" in str(result) + + +# ============================================================================ +# translate_error → A2A +# ============================================================================ + + +class TestTranslateErrorToA2A: + """Test translate_error with protocol='a2a'.""" + + def test_returns_server_error(self): + """A2A translation returns a ServerError instance.""" + exc = ADCPError("something went wrong") + result = translate_error(exc, protocol="a2a") + assert isinstance(result, ServerError) + + def test_internal_error_wraps_internal(self): + """Generic ADCPError wraps InternalError (terminal/transient).""" + exc = ADCPError("something went wrong") + result = translate_error(exc, protocol="a2a") + assert isinstance(result.error, InternalError) + assert result.error.message == "something went wrong" + + def test_correctable_error_wraps_invalid_params(self): + """Error with correctable code wraps InvalidParamsError.""" + err = Error(code="VALIDATION_ERROR", message="Missing field") + result = translate_error(err, protocol="a2a") + assert isinstance(result.error, InvalidParamsError) + + def test_data_includes_recovery(self): + """A2A error data includes recovery classification.""" + exc = ADCPConnectionError("Cannot reach upstream") + result = translate_error(exc, protocol="a2a") + assert result.error.data["recovery"] == "transient" + + def test_data_includes_error_code(self): + """A2A error data includes the ADCP error code.""" + err = Error(code="BUDGET_TOO_LOW", message="Budget below minimum") + result = translate_error(err, protocol="a2a") + assert result.error.data["error_code"] == "BUDGET_TOO_LOW" + + def test_data_includes_suggestion(self): + """A2A error data includes suggestion when present.""" + exc = ADCPError("bad request", suggestion="Check the budget field") + result = translate_error(exc, protocol="a2a") + assert result.error.data["suggestion"] == "Check the budget field" + + def test_data_includes_details(self): + """A2A error data includes details from Error model.""" + err = Error( + code="BUDGET_EXCEEDED", + message="Budget exceeded", + details={"max_budget": 10000, "requested": 15000}, + ) + result = translate_error(err, protocol="a2a") + assert result.error.data["details"] == {"max_budget": 10000, "requested": 15000} + + def test_task_error_preserves_original_errors(self): + """ADCPTaskError passes through the original error list.""" + err1 = Error(code="BUDGET_TOO_LOW", message="Budget below minimum") + err2 = Error(code="AUDIENCE_TOO_SMALL", message="Audience too small") + exc = ADCPTaskError("create_media_buy", [err1, err2]) + result = translate_error(exc, protocol="a2a") + errors = result.error.data["errors"] + assert len(errors) == 2 + assert errors[0]["code"] == "BUDGET_TOO_LOW" + assert errors[1]["code"] == "AUDIENCE_TOO_SMALL" + + def test_auth_error_is_terminal(self): + """ADCPAuthenticationError gets terminal recovery.""" + exc = ADCPAuthenticationError("Forbidden") + result = translate_error(exc, protocol="a2a") + assert result.error.data["recovery"] == "terminal" + + def test_timeout_error_is_transient(self): + """ADCPTimeoutError gets transient recovery.""" + exc = ADCPTimeoutError("Timed out", timeout=30.0) + result = translate_error(exc, protocol="a2a") + assert result.error.data["recovery"] == "transient" + + +# ============================================================================ +# translate_error validation +# ============================================================================ + + +class TestTranslateErrorValidation: + """Test translate_error input validation.""" + + def test_rejects_unknown_protocol(self): + """Unknown protocol raises ValueError.""" + with pytest.raises(ValueError, match="protocol"): + translate_error(ADCPError("err"), protocol="grpc") # type: ignore[arg-type] + + def test_accepts_protocol_enum(self): + """Protocol enum values work.""" + err = Error(code="TEST", message="test") + result_mcp = translate_error(err, protocol=Protocol.MCP) + assert isinstance(result_mcp, ToolError) + + result_a2a = translate_error(err, protocol=Protocol.A2A) + assert isinstance(result_a2a, ServerError) + + def test_accepts_uppercase_protocol_string(self): + """Protocol strings are case-insensitive.""" + err = Error(code="TEST", message="test") + result = translate_error(err, protocol="MCP") # type: ignore[arg-type] + assert isinstance(result, ToolError) + + +# ============================================================================ +# normalize_request — structural transforms +# ============================================================================ + + +class TestNormalizeAccountId: + """Test account_id → account structural reshape.""" + + def test_reshapes_account_id_to_nested_object(self): + """account_id string becomes account: {account_id: "..."}.""" + params = {"account_id": "acct-123", "name": "Test"} + result = normalize_request(params) + + assert result["account"] == {"account_id": "acct-123"} + assert "account_id" not in result + + def test_does_not_overwrite_existing_account(self): + """If account already present, account_id is dropped.""" + params = {"account_id": "old", "account": {"account_id": "current"}} + result = normalize_request(params) + + assert result["account"] == {"account_id": "current"} + assert "account_id" not in result + + def test_no_account_id_is_noop(self): + """Params without account_id pass through unchanged.""" + params = {"account": {"account_id": "123"}} + result = normalize_request(params) + assert result == params + + +class TestNormalizeBrandManifest: + """Test brand_manifest URL → brand object.""" + + def test_parses_url_to_domain(self): + """brand_manifest URL is parsed to brand: {domain: hostname}.""" + params = {"brand_manifest": "https://example.com/brand.json"} + result = normalize_request(params) + + assert result["brand"] == {"domain": "example.com"} + assert "brand_manifest" not in result + + def test_does_not_overwrite_existing_brand(self): + """If brand already present, brand_manifest is dropped.""" + params = {"brand_manifest": "https://old.com", "brand": {"domain": "new.com"}} + result = normalize_request(params) + + assert result["brand"] == {"domain": "new.com"} + assert "brand_manifest" not in result + + def test_non_string_manifest_renamed(self): + """Non-string brand_manifest is passed through as brand.""" + params = {"brand_manifest": {"url": "https://example.com"}} + result = normalize_request(params) + + assert result["brand"] == {"url": "https://example.com"} + + +class TestNormalizePackages: + """Test package-level scalar-to-array transforms.""" + + def test_optimization_goal_to_array(self): + """optimization_goal string wraps to optimization_goals array.""" + params = {"packages": [{"optimization_goal": "cpa", "name": "pkg1"}]} + result = normalize_request(params) + + assert result["packages"][0]["optimization_goals"] == ["cpa"] + assert "optimization_goal" not in result["packages"][0] + + def test_catalog_to_array(self): + """catalog string wraps to catalogs array.""" + params = {"packages": [{"catalog": "retail"}]} + result = normalize_request(params) + + assert result["packages"][0]["catalogs"] == ["retail"] + assert "catalog" not in result["packages"][0] + + def test_does_not_overwrite_existing_array(self): + """If optimization_goals already present, scalar is dropped.""" + params = {"packages": [{"optimization_goal": "cpa", "optimization_goals": ["roas"]}]} + result = normalize_request(params) + + assert result["packages"][0]["optimization_goals"] == ["roas"] + assert "optimization_goal" not in result["packages"][0] + + def test_does_not_mutate_original_packages(self): + """Package dicts in the original params are not mutated.""" + pkg = {"optimization_goal": "cpa"} + params = {"packages": [pkg]} + normalize_request(params) + + assert "optimization_goal" in pkg # original unchanged + + +# ============================================================================ +# normalize_request — renames +# ============================================================================ + + +class TestNormalizeRenames: + """Test field renames (global and tool-scoped).""" + + def test_promoted_offerings_to_catalogs(self): + """promoted_offerings renames to catalogs.""" + params = {"promoted_offerings": ["offer-1"]} + result = normalize_request(params) + + assert result["catalogs"] == ["offer-1"] + assert "promoted_offerings" not in result + + def test_campaign_ref_scoped_to_create_media_buy(self): + """campaign_ref → buyer_campaign_ref only on create_media_buy.""" + params = {"campaign_ref": "camp-456"} + + result_scoped = normalize_request(params, task_name="create_media_buy") + assert result_scoped["buyer_campaign_ref"] == "camp-456" + assert "campaign_ref" not in result_scoped + + def test_campaign_ref_not_renamed_for_other_tasks(self): + """campaign_ref passes through for non-create_media_buy tasks.""" + params = {"campaign_ref": "camp-456"} + + result_other = normalize_request(params, task_name="update_media_buy") + assert result_other["campaign_ref"] == "camp-456" + assert "buyer_campaign_ref" not in result_other + + def test_campaign_ref_not_renamed_without_task_name(self): + """campaign_ref passes through when no task_name provided.""" + params = {"campaign_ref": "camp-456"} + result = normalize_request(params) + + assert result["campaign_ref"] == "camp-456" + + +# ============================================================================ +# normalize_request — general behavior +# ============================================================================ + + +class TestNormalizeGeneral: + """Test general normalize_request behavior.""" + + def test_returns_new_dict(self): + """normalize_request returns a copy, does not mutate input.""" + params = {"account_id": "acct-123"} + result = normalize_request(params) + + assert result is not params + assert "account_id" in params # original unchanged + + def test_empty_params(self): + """Empty params return empty dict.""" + result = normalize_request({}) + assert result == {} + + def test_unknown_fields_pass_through(self): + """Fields not in any rename map pass through unchanged.""" + params = {"custom_field": "value", "account_id": "acct-123"} + result = normalize_request(params) + + assert result["custom_field"] == "value" + assert result["account"] == {"account_id": "acct-123"} + + def test_all_transforms_combined(self): + """Multiple transforms apply in a single call.""" + params = { + "account_id": "acct-1", + "brand_manifest": "https://brand.co/manifest.json", + "promoted_offerings": ["offer-1"], + "campaign_ref": "camp-1", + "packages": [{"optimization_goal": "cpa", "catalog": "retail"}], + } + result = normalize_request(params, task_name="create_media_buy") + + assert result["account"] == {"account_id": "acct-1"} + assert result["brand"] == {"domain": "brand.co"} + assert result["catalogs"] == ["offer-1"] + assert result["buyer_campaign_ref"] == "camp-1" + assert result["packages"][0]["optimization_goals"] == ["cpa"] + assert result["packages"][0]["catalogs"] == ["retail"] + # Old names removed + assert "account_id" not in result + assert "brand_manifest" not in result + assert "promoted_offerings" not in result + assert "campaign_ref" not in result