Skip to content
Merged
63 changes: 62 additions & 1 deletion apps/sim/app/api/mcp/oauth/start/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ vi.mock('@/lib/auth/hybrid', () => hybridAuthMock)
vi.mock('@/lib/workspaces/permissions/utils', () => permissionsMock)
vi.mock('@/lib/mcp/oauth', () => mcpOauthMock)

import { GET } from './route'
import { GET, surfaceOauthError } from './route'

describe('MCP OAuth start route', () => {
beforeEach(() => {
Expand Down Expand Up @@ -134,4 +134,65 @@ describe('MCP OAuth start route', () => {
expect(body.error).toBe('OAuth authorization already in progress for this server')
expect(mockMcpAuth).not.toHaveBeenCalled()
})

it('does not leak non-OAuth internal error details to the client', async () => {
mcpOauthMockFns.mockGetOrCreateOauthRow.mockRejectedValueOnce(
new Error('connect ECONNREFUSED 10.0.0.5:5432 (internal-db-host)')
)
const request = new NextRequest(
'http://localhost:3000/api/mcp/oauth/start?workspaceId=workspace-1&serverId=server-1'
)

const response = await GET(request)
const body = await response.json()

expect(response.status).toBe(500)
expect(body.error).toBe('Failed to start OAuth flow')
expect(body.error).not.toContain('ECONNREFUSED')
expect(body.error).not.toContain('internal-db-host')
})
})

describe('surfaceOauthError', () => {
it('uses typed OAuthError errorCode and message for spec-compliant errors', async () => {
const { InvalidGrantError } = await import('@modelcontextprotocol/sdk/server/auth/errors.js')
const err = new InvalidGrantError('Refresh token expired')
expect(surfaceOauthError(err)).toBe('invalid_grant: Refresh token expired')
})

it('parses Raw body envelope for ServerError fallbacks (non-spec vendors)', async () => {
const { ServerError } = await import('@modelcontextprotocol/sdk/server/auth/errors.js')
const err = new ServerError(
'HTTP 400: Invalid OAuth error response: zod error. Raw body: {"code":400,"message":"redirect URI https://example.com/cb is not allowed","retryable":false}'
)
expect(surfaceOauthError(err)).toBe(
'Authorization server: redirect URI https://example.com/cb is not allowed'
)
})

it('prefers error_description over message over error in fallback envelope', async () => {
const { ServerError } = await import('@modelcontextprotocol/sdk/server/auth/errors.js')
const err = new ServerError(
'HTTP 400: Invalid OAuth error response: zod. Raw body: {"error":"invalid_grant","error_description":"the description","message":"the message"}'
)
expect(surfaceOauthError(err)).toBe('Authorization server: the description')
})

it('returns first line of generic errors', () => {
const err = new Error('Network blip\n at fetch (...)')
expect(surfaceOauthError(err)).toBe('Network blip')
})

it('truncates messages longer than 250 chars with ellipsis', async () => {
const { InvalidGrantError } = await import('@modelcontextprotocol/sdk/server/auth/errors.js')
const longMessage = 'x'.repeat(300)
const result = surfaceOauthError(new InvalidGrantError(longMessage))
expect(result.endsWith('…')).toBe(true)
expect(result.length).toBe(251)
})

it('returns generic fallback for non-Error values', () => {
expect(surfaceOauthError(null)).toBe('Failed to start OAuth flow')
expect(surfaceOauthError(undefined)).toBe('Failed to start OAuth flow')
})
})
40 changes: 39 additions & 1 deletion apps/sim/app/api/mcp/oauth/start/route.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { auth as mcpAuth } from '@modelcontextprotocol/sdk/client/auth.js'
import { OAuthError, ServerError } from '@modelcontextprotocol/sdk/server/auth/errors.js'
import { db } from '@sim/db'
import { mcpServers } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
Expand All @@ -23,6 +24,39 @@ import { createMcpErrorResponse } from '@/lib/mcp/utils'

const logger = createLogger('McpOauthStartAPI')
const OAUTH_START_TTL_MS = 10 * 60 * 1000
const MAX_SURFACED_ERROR_LENGTH = 250

export function surfaceOauthError(error: unknown): string {
// Spec-compliant OAuth servers throw typed subclasses with clean RFC 6749 fields.
if (error instanceof OAuthError && !(error instanceof ServerError)) {
return truncate(`${error.errorCode}: ${error.message}`)
}

// ServerError wraps non-spec response bodies as "HTTP N: Invalid OAuth error
// response: ... Raw body: {...}". Dig the vendor message out of the JSON tail.
if (error instanceof Error) {
const rawBodyMatch = error.message.match(/Raw body:\s*(\{[\s\S]*\})\s*$/)
if (rawBodyMatch) {
try {
const body = JSON.parse(rawBodyMatch[1]) as Record<string, unknown>
const vendorMessage =
(typeof body.error_description === 'string' && body.error_description) ||
(typeof body.message === 'string' && body.message) ||
(typeof body.error === 'string' && body.error) ||
null
if (vendorMessage) return truncate(`Authorization server: ${vendorMessage}`)
} catch {}
}
return truncate(error.message.split('\n')[0] || 'Failed to start OAuth flow')
}
return 'Failed to start OAuth flow'
}

function truncate(message: string): string {
return message.length > MAX_SURFACED_ERROR_LENGTH
? `${message.slice(0, MAX_SURFACED_ERROR_LENGTH)}…`
: message
}

export const dynamic = 'force-dynamic'

Expand Down Expand Up @@ -116,7 +150,11 @@ export const GET = withRouteHandler(
}
} catch (error) {
logger.error('Error starting MCP OAuth flow:', error)
return createMcpErrorResponse(toError(error), 'Failed to start OAuth flow', 500)
// Only surface OAuth-flow errors verbatim; everything else (DB, decryption,
// network) gets a generic message to avoid leaking internal details.
const userMessage =
error instanceof OAuthError ? surfaceOauthError(error) : 'Failed to start OAuth flow'
return createMcpErrorResponse(toError(error), userMessage, 500)
}
})
)
241 changes: 241 additions & 0 deletions apps/sim/lib/mcp/service.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
/**
* @vitest-environment node
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'

const {
MockMcpClient,
mockListTools,
mockConnect,
mockDisconnect,
mockGetWorkspaceServersRows,
mockResolveEnvVars,
mockValidateDomain,
mockValidateSsrf,
mockIsDomainAllowed,
} = vi.hoisted(() => {
const mockListTools = vi.fn()
const mockConnect = vi.fn()
const mockDisconnect = vi.fn()
return {
MockMcpClient: vi.fn().mockImplementation(() => ({
connect: mockConnect,
disconnect: mockDisconnect,
listTools: mockListTools,
hasListChangedCapability: vi.fn(() => false),
onClose: vi.fn(),
getNegotiatedVersion: vi.fn(() => '2025-06-18'),
})),
mockListTools,
mockConnect,
mockDisconnect,
mockGetWorkspaceServersRows: vi.fn(),
mockResolveEnvVars: vi.fn(),
mockValidateDomain: vi.fn(),
mockValidateSsrf: vi.fn(),
mockIsDomainAllowed: vi.fn(() => true),
}
})

vi.mock('@sim/db', () => {
const setter = vi.fn().mockReturnValue({ where: vi.fn().mockResolvedValue(undefined) })
return {
db: {
select: vi.fn().mockReturnValue({
from: vi.fn().mockReturnValue({
where: (...args: unknown[]) => mockGetWorkspaceServersRows(...args),
}),
}),
update: vi.fn().mockReturnValue({ set: setter }),
insert: vi.fn(),
delete: vi.fn(),
},
}
})

vi.mock('@/lib/mcp/client', () => ({
McpClient: MockMcpClient,
}))

vi.mock('@/lib/mcp/connection-manager', () => ({
mcpConnectionManager: null,
}))

vi.mock('@/lib/mcp/domain-check', () => ({
isMcpDomainAllowed: (...args: unknown[]) => mockIsDomainAllowed(...args),
validateMcpDomain: (...args: unknown[]) => mockValidateDomain(...args),
validateMcpServerSsrf: (...args: unknown[]) => mockValidateSsrf(...args),
}))

vi.mock('@/lib/mcp/oauth', () => ({
getOrCreateOauthRow: vi.fn(),
loadPreregisteredClient: vi.fn(),
SimMcpOauthProvider: vi.fn(),
withMcpOauthRefreshLock: vi.fn(),
}))

vi.mock('@/lib/mcp/resolve-config', () => ({
resolveMcpConfigEnvVars: (...args: unknown[]) => mockResolveEnvVars(...args),
}))

import { mcpService } from '@/lib/mcp/service'
import { McpOauthAuthorizationRequiredError } from '@/lib/mcp/types'

const WORKSPACE_ID = 'workspace-test'
const USER_ID = 'user-test'

function dbRow(id: string, name: string, overrides: Record<string, unknown> = {}) {
return {
id,
name,
description: null,
transport: 'streamable-http',
url: `https://${id}.example.com/mcp`,
authType: 'headers',
workspaceId: WORKSPACE_ID,
headers: {},
timeout: 30000,
retries: 3,
enabled: true,
deletedAt: null,
createdAt: new Date('2026-01-01T00:00:00Z'),
updatedAt: new Date('2026-01-01T00:00:00Z'),
...overrides,
}
}

function tool(name: string, serverId: string) {
return {
name,
description: name,
inputSchema: { type: 'object' },
serverId,
serverName: serverId,
}
}

describe('McpService.discoverTools per-server caching', () => {
beforeEach(async () => {
vi.clearAllMocks()
mockIsDomainAllowed.mockReturnValue(true)
mockValidateSsrf.mockResolvedValue('1.2.3.4')
mockValidateDomain.mockImplementation(() => undefined)
mockResolveEnvVars.mockImplementation((config: { url: string }) =>
Promise.resolve({ config: { ...config, url: config.url }, missingVars: [] })
)
mockConnect.mockResolvedValue(undefined)
mockDisconnect.mockResolvedValue(undefined)
// The McpService singleton holds cache state across imports.
await mcpService.clearCache()
})

it('caches each server independently after first discovery', async () => {
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')])
mockListTools
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
.mockResolvedValueOnce([tool('b1', 'mcp-b')])

const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(first.map((t) => t.name).sort()).toEqual(['a1', 'b1'])
expect(mockListTools).toHaveBeenCalledTimes(2)

mockListTools.mockClear()
const second = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(second.map((t) => t.name).sort()).toEqual(['a1', 'b1'])
expect(mockListTools).not.toHaveBeenCalled()
})

it("one server failing does not poison another server's cache", async () => {
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')])
mockListTools
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
.mockRejectedValueOnce(new Error('Request timed out'))

const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(first.map((t) => t.name)).toEqual(['a1'])

mockListTools.mockClear()
mockListTools.mockResolvedValueOnce([tool('b1', 'mcp-b')])

const second = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(second.map((t) => t.name).sort()).toEqual(['a1', 'b1'])
expect(mockListTools).toHaveBeenCalledTimes(1)
})

it("forceRefresh bypasses every server's cache", async () => {
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')])
mockListTools
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
.mockResolvedValueOnce([tool('b1', 'mcp-b')])

await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(mockListTools).toHaveBeenCalledTimes(2)

mockListTools.mockClear()
mockListTools
.mockResolvedValueOnce([tool('a2', 'mcp-a')])
.mockResolvedValueOnce([tool('b2', 'mcp-b')])

const refreshed = await mcpService.discoverTools(USER_ID, WORKSPACE_ID, true)
expect(refreshed.map((t) => t.name).sort()).toEqual(['a2', 'b2'])
expect(mockListTools).toHaveBeenCalledTimes(2)
})

it('OAuth-pending is treated as a soft skip without poisoning cache', async () => {
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A'), dbRow('mcp-b', 'B')])
mockListTools
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
.mockRejectedValueOnce(new McpOauthAuthorizationRequiredError('mcp-b', 'B'))

const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(first.map((t) => t.name)).toEqual(['a1'])

mockListTools.mockClear()
mockListTools.mockRejectedValueOnce(new McpOauthAuthorizationRequiredError('mcp-b', 'B'))

await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(mockListTools).toHaveBeenCalledTimes(1)
})

it('returns empty array immediately when workspace has no servers', async () => {
mockGetWorkspaceServersRows.mockResolvedValue([])

const result = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(result).toEqual([])
expect(mockListTools).not.toHaveBeenCalled()
expect(MockMcpClient).not.toHaveBeenCalled()
})

it('clearCache(workspaceId) drops cached tools so next call re-fetches', async () => {
mockGetWorkspaceServersRows.mockResolvedValue([dbRow('mcp-a', 'A')])
mockListTools.mockResolvedValueOnce([tool('a1', 'mcp-a')])

await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(mockListTools).toHaveBeenCalledTimes(1)

await mcpService.clearCache(WORKSPACE_ID)

mockListTools.mockClear()
mockListTools.mockResolvedValueOnce([tool('a1', 'mcp-a')])
await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
expect(mockListTools).toHaveBeenCalledTimes(1)
})

it('isolates caches across workspaces', async () => {
const otherWorkspaceId = 'workspace-other'
mockGetWorkspaceServersRows
.mockResolvedValueOnce([dbRow('mcp-a', 'A')])
.mockResolvedValueOnce([dbRow('mcp-a', 'A', { workspaceId: otherWorkspaceId })])

mockListTools
.mockResolvedValueOnce([tool('a1', 'mcp-a')])
.mockResolvedValueOnce([tool('a-other', 'mcp-a')])

const first = await mcpService.discoverTools(USER_ID, WORKSPACE_ID)
const second = await mcpService.discoverTools(USER_ID, otherWorkspaceId)

expect(first.map((t) => t.name)).toEqual(['a1'])
expect(second.map((t) => t.name)).toEqual(['a-other'])
expect(mockListTools).toHaveBeenCalledTimes(2)
})
})
Loading
Loading