package com.vincent.rsf.server.ai.service.mcp;
|
|
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.vincent.rsf.server.ai.constant.AiMcpConstants;
|
import com.vincent.rsf.server.ai.model.AiDiagnosticToolResult;
|
import com.vincent.rsf.server.ai.model.AiMcpToolDescriptor;
|
import com.vincent.rsf.server.system.entity.AiMcpMount;
|
import org.slf4j.Logger;
|
import org.slf4j.LoggerFactory;
|
import org.springframework.stereotype.Component;
|
|
import javax.annotation.Resource;
|
import java.io.BufferedReader;
|
import java.io.InputStream;
|
import java.io.InputStreamReader;
|
import java.io.OutputStream;
|
import java.net.HttpURLConnection;
|
import java.net.URI;
|
import java.net.URL;
|
import java.nio.charset.StandardCharsets;
|
import java.util.ArrayList;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.UUID;
|
import java.util.concurrent.BlockingQueue;
|
import java.util.concurrent.LinkedBlockingQueue;
|
import java.util.concurrent.TimeUnit;
|
|
@Component
|
public class AiMcpSseClient {
|
|
private static final Logger logger = LoggerFactory.getLogger(AiMcpSseClient.class);
|
|
@Resource
|
private ObjectMapper objectMapper;
|
@Resource
|
private AiMcpPayloadMapper aiMcpPayloadMapper;
|
|
/**
|
* 通过 SSE + message endpoint 协议加载远程 MCP 工具目录。
|
*/
|
public List<AiMcpToolDescriptor> listTools(AiMcpMount mount) {
|
try (SseSession session = openSession(mount)) {
|
logger.info("AI MCP SSE listTools start: mountCode={}, url={}", mount.getMountCode(), mount.getUrl());
|
session.initialize();
|
JsonNode result = session.request("tools/list", new LinkedHashMap<String, Object>());
|
List<AiMcpToolDescriptor> output = new ArrayList<>();
|
JsonNode toolsNode = result.path("tools");
|
if (!toolsNode.isArray()) {
|
logger.warn("AI MCP SSE listTools no tools array: mountCode={}, url={}", mount.getMountCode(), mount.getUrl());
|
return output;
|
}
|
for (JsonNode item : toolsNode) {
|
AiMcpToolDescriptor descriptor = aiMcpPayloadMapper.toExternalToolDescriptor(mount, item);
|
if (descriptor != null) {
|
output.add(descriptor);
|
}
|
}
|
logger.info("AI MCP SSE listTools success: mountCode={}, url={}, toolCount={}",
|
mount.getMountCode(), mount.getUrl(), output.size());
|
return output;
|
} catch (Exception e) {
|
logger.warn("AI MCP SSE listTools failed: mountCode={}, url={}, message={}",
|
mount.getMountCode(), mount.getUrl(), e.getMessage());
|
throw new IllegalStateException("SSE MCP工具加载失败: " + e.getMessage(), e);
|
}
|
}
|
|
/**
|
* 通过 SSE + message endpoint 协议执行远程 MCP 工具。
|
*/
|
public AiDiagnosticToolResult callTool(AiMcpMount mount, String toolName, Map<String, Object> arguments) {
|
try (SseSession session = openSession(mount)) {
|
logger.info("AI MCP SSE callTool start: mountCode={}, url={}, toolName={}",
|
mount.getMountCode(), mount.getUrl(), toolName);
|
session.initialize();
|
Map<String, Object> params = new LinkedHashMap<>();
|
params.put("name", toolName);
|
params.put("arguments", arguments == null ? new LinkedHashMap<String, Object>() : arguments);
|
JsonNode result = session.request("tools/call", params);
|
AiDiagnosticToolResult toolResult = aiMcpPayloadMapper.toExternalToolResult(mount, toolName, result);
|
logger.info("AI MCP SSE callTool success: mountCode={}, url={}, toolName={}, isError={}, summaryLength={}",
|
mount.getMountCode(), mount.getUrl(), toolName,
|
"WARN".equalsIgnoreCase(toolResult.getSeverity()),
|
toolResult.getSummaryText() == null ? 0 : toolResult.getSummaryText().length());
|
return toolResult;
|
} catch (Exception e) {
|
logger.warn("AI MCP SSE callTool failed: mountCode={}, url={}, toolName={}, message={}",
|
mount.getMountCode(), mount.getUrl(), toolName, e.getMessage());
|
throw new IllegalStateException("SSE MCP工具调用失败: " + e.getMessage(), e);
|
}
|
}
|
|
/**
|
* 打开远程 SSE 流并创建会话包装对象。
|
*/
|
private SseSession openSession(AiMcpMount mount) throws Exception {
|
logger.info("AI MCP SSE opening stream: mountCode={}, url={}", mount.getMountCode(), mount.getUrl());
|
int timeoutMs = mount.getTimeoutMs() == null || mount.getTimeoutMs() <= 0 ? 10000 : mount.getTimeoutMs();
|
HttpURLConnection connection = (HttpURLConnection) new URL(mount.getUrl()).openConnection();
|
connection.setRequestMethod("GET");
|
connection.setDoInput(true);
|
connection.setConnectTimeout(timeoutMs);
|
connection.setReadTimeout(timeoutMs);
|
connection.setRequestProperty("Accept", "text/event-stream");
|
applyAuthHeaders(connection, mount);
|
InputStream inputStream = connection.getInputStream();
|
logger.info("AI MCP SSE stream connected: mountCode={}, url={}, responseCode={}",
|
mount.getMountCode(), mount.getUrl(), connection.getResponseCode());
|
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
|
SseSession session = new SseSession(mount, connection, reader);
|
session.start();
|
return session;
|
}
|
|
private class SseSession implements AutoCloseable {
|
private final AiMcpMount mount;
|
private final HttpURLConnection connection;
|
private final BufferedReader reader;
|
private final BlockingQueue<SseEvent> events = new LinkedBlockingQueue<>();
|
private volatile boolean closed;
|
private Thread worker;
|
private String messageEndpoint;
|
|
private SseSession(AiMcpMount mount, HttpURLConnection connection, BufferedReader reader) {
|
this.mount = mount;
|
this.connection = connection;
|
this.reader = reader;
|
}
|
|
/**
|
* 启动后台读取线程,并等待远程服务返回 endpoint 事件。
|
*/
|
private void start() throws Exception {
|
worker = new Thread(this::readLoop, "ai-mcp-sse-client-" + mount.getMountCode());
|
worker.setDaemon(true);
|
worker.start();
|
logger.info("AI MCP SSE waiting endpoint event: mountCode={}, url={}", mount.getMountCode(), mount.getUrl());
|
SseEvent endpointEvent = waitEvent("endpoint");
|
messageEndpoint = resolveEndpoint(endpointEvent == null ? null : endpointEvent.getData());
|
logger.info("AI MCP SSE endpoint event received: mountCode={}, url={}, rawEndpoint={}, resolvedEndpoint={}",
|
mount.getMountCode(), mount.getUrl(),
|
endpointEvent == null ? null : endpointEvent.getData(),
|
messageEndpoint);
|
if (messageEndpoint == null || messageEndpoint.trim().isEmpty()) {
|
throw new IllegalStateException("SSE MCP未返回 message endpoint");
|
}
|
}
|
|
/**
|
* 完成一次 MCP initialize 握手。
|
*/
|
private void initialize() throws Exception {
|
logger.info("AI MCP SSE initialize start: mountCode={}, url={}, messageEndpoint={}",
|
mount.getMountCode(), mount.getUrl(), messageEndpoint);
|
Map<String, Object> params = new LinkedHashMap<>();
|
params.put("protocolVersion", AiMcpConstants.PROTOCOL_VERSION);
|
params.put("capabilities", new LinkedHashMap<String, Object>());
|
Map<String, Object> clientInfo = new LinkedHashMap<>();
|
clientInfo.put("name", "rsf-server");
|
clientInfo.put("version", AiMcpConstants.SERVER_VERSION);
|
params.put("clientInfo", clientInfo);
|
request("initialize", params);
|
notifyInitialized();
|
logger.info("AI MCP SSE initialize success: mountCode={}, url={}, messageEndpoint={}",
|
mount.getMountCode(), mount.getUrl(), messageEndpoint);
|
}
|
|
/**
|
* 向远程 MCP 发送 initialized 通知。
|
*/
|
private void notifyInitialized() throws Exception {
|
Map<String, Object> body = new LinkedHashMap<>();
|
body.put("jsonrpc", "2.0");
|
body.put("method", "notifications/initialized");
|
body.put("params", new LinkedHashMap<String, Object>());
|
postMessage(body, false);
|
logger.info("AI MCP SSE initialized notification sent: mountCode={}, messageEndpoint={}",
|
mount.getMountCode(), messageEndpoint);
|
}
|
|
/**
|
* 通过 message endpoint 发送一次 JSON-RPC 请求,并等待对应 message 事件响应。
|
*/
|
private JsonNode request(String method, Object params) throws Exception {
|
String id = UUID.randomUUID().toString().replace("-", "");
|
Map<String, Object> body = new LinkedHashMap<>();
|
body.put("jsonrpc", "2.0");
|
body.put("id", id);
|
body.put("method", method);
|
body.put("params", params == null ? new LinkedHashMap<String, Object>() : params);
|
logger.info("AI MCP SSE request send: mountCode={}, method={}, requestId={}, messageEndpoint={}",
|
mount.getMountCode(), method, id, messageEndpoint);
|
postMessage(body, true);
|
SseEvent response = waitEvent("message");
|
if (response == null || response.getData() == null || response.getData().trim().isEmpty()) {
|
throw new IllegalStateException("SSE MCP未返回响应消息");
|
}
|
logger.info("AI MCP SSE response received: mountCode={}, method={}, requestId={}, dataLength={}",
|
mount.getMountCode(), method, id, response.getData().length());
|
JsonNode root = objectMapper.readTree(response.getData());
|
if (!id.equals(root.path("id").asText(""))) {
|
logger.warn("AI MCP SSE response id mismatch: mountCode={}, method={}, requestId={}, responseId={}",
|
mount.getMountCode(), method, id, root.path("id").asText(""));
|
throw new IllegalStateException("SSE MCP响应ID不匹配");
|
}
|
if (root.has("error") && !root.get("error").isNull()) {
|
logger.warn("AI MCP SSE response error: mountCode={}, method={}, requestId={}, message={}",
|
mount.getMountCode(), method, id, root.path("error").path("message").asText(""));
|
throw new IllegalStateException(root.path("error").path("message").asText("MCP调用失败"));
|
}
|
return root.path("result");
|
}
|
|
/**
|
* 向 message endpoint 提交一条 HTTP POST 消息。
|
*/
|
private void postMessage(Map<String, Object> body, boolean expectSuccess) throws Exception {
|
HttpURLConnection post = null;
|
try {
|
post = (HttpURLConnection) new URL(messageEndpoint).openConnection();
|
post.setRequestMethod("POST");
|
post.setDoOutput(true);
|
post.setConnectTimeout(mount.getTimeoutMs() == null ? 10000 : mount.getTimeoutMs());
|
post.setReadTimeout(mount.getTimeoutMs() == null ? 10000 : mount.getTimeoutMs());
|
post.setRequestProperty("Content-Type", "application/json");
|
post.setRequestProperty("Accept", "application/json");
|
applyAuthHeaders(post, mount);
|
logger.info("AI MCP SSE post message: mountCode={}, endpoint={}, method={}",
|
mount.getMountCode(), messageEndpoint, body.get("method"));
|
try (OutputStream outputStream = post.getOutputStream()) {
|
outputStream.write(objectMapper.writeValueAsBytes(body));
|
outputStream.flush();
|
}
|
int statusCode = post.getResponseCode();
|
logger.info("AI MCP SSE post response: mountCode={}, endpoint={}, method={}, statusCode={}",
|
mount.getMountCode(), messageEndpoint, body.get("method"), statusCode);
|
if (expectSuccess && statusCode >= 400) {
|
throw new IllegalStateException("SSE MCP消息提交失败,状态码=" + statusCode);
|
}
|
} finally {
|
if (post != null) {
|
post.disconnect();
|
}
|
}
|
}
|
|
/**
|
* 在限定时间内等待某类 SSE 事件。
|
*/
|
private SseEvent waitEvent(String targetName) throws Exception {
|
long timeoutMs = mount.getTimeoutMs() == null ? 10000L : mount.getTimeoutMs().longValue();
|
long deadline = System.currentTimeMillis() + timeoutMs;
|
while (System.currentTimeMillis() < deadline) {
|
long remain = deadline - System.currentTimeMillis();
|
SseEvent event = events.poll(remain <= 0 ? 1L : remain, TimeUnit.MILLISECONDS);
|
if (event == null) {
|
continue;
|
}
|
logger.info("AI MCP SSE event dequeued: mountCode={}, target={}, actual={}, dataLength={}",
|
mount.getMountCode(), targetName, event.getName(), event.getData() == null ? 0 : event.getData().length());
|
if ("error".equals(event.getName())) {
|
throw new IllegalStateException("SSE MCP事件读取失败: " + event.getData());
|
}
|
if (targetName.equals(event.getName())) {
|
return event;
|
}
|
}
|
logger.warn("AI MCP SSE wait event timeout: mountCode={}, target={}, timeoutMs={}",
|
mount.getMountCode(), targetName, timeoutMs);
|
throw new IllegalStateException("等待SSE事件超时: " + targetName);
|
}
|
|
/**
|
* 后台持续读取 SSE 流,并把事件转发到队列供主线程消费。
|
*/
|
private void readLoop() {
|
String eventName = "message";
|
StringBuilder dataBuilder = new StringBuilder();
|
try {
|
String line;
|
while (!closed && (line = reader.readLine()) != null) {
|
if (line.startsWith("event:")) {
|
eventName = line.substring(6).trim();
|
continue;
|
}
|
if (line.startsWith("data:")) {
|
dataBuilder.append(line.substring(5).trim()).append('\n');
|
continue;
|
}
|
if (line.trim().isEmpty()) {
|
if (dataBuilder.length() > 0) {
|
logger.info("AI MCP SSE raw event read: mountCode={}, event={}, dataLength={}",
|
mount.getMountCode(), eventName, dataBuilder.length());
|
events.offer(new SseEvent(eventName, dataBuilder.toString().trim()));
|
}
|
eventName = "message";
|
dataBuilder.setLength(0);
|
}
|
}
|
} catch (Exception e) {
|
if (!closed) {
|
logger.warn("AI MCP SSE read loop failed: mountCode={}, url={}, message={}",
|
mount.getMountCode(), mount.getUrl(), e.getMessage());
|
events.offer(new SseEvent("error", e.getMessage()));
|
}
|
} finally {
|
try {
|
reader.close();
|
} catch (Exception ignore) {
|
}
|
}
|
}
|
|
/**
|
* 解析远程返回的 message endpoint。
|
* 当对方错误返回回环地址时,会重写成当前挂载 URL 的主机地址。
|
*/
|
private String resolveEndpoint(String rawEndpoint) {
|
if (rawEndpoint == null || rawEndpoint.trim().isEmpty()) {
|
return null;
|
}
|
try {
|
URI baseUri = new URI(mount.getUrl());
|
URI endpointUri = baseUri.resolve(rawEndpoint.trim());
|
String host = endpointUri.getHost();
|
if (isLoopbackHost(host) && !sameHost(host, baseUri.getHost())) {
|
endpointUri = new URI(
|
baseUri.getScheme(),
|
endpointUri.getUserInfo(),
|
baseUri.getHost(),
|
endpointUri.getPort() > 0 ? endpointUri.getPort() : baseUri.getPort(),
|
endpointUri.getPath(),
|
endpointUri.getQuery(),
|
endpointUri.getFragment());
|
logger.info("AI MCP SSE endpoint rewritten: mountCode={}, rawEndpoint={}, rewrittenEndpoint={}",
|
mount.getMountCode(), rawEndpoint, endpointUri);
|
}
|
return endpointUri.toString();
|
} catch (Exception e) {
|
return rawEndpoint.trim();
|
}
|
}
|
|
/**
|
* 判断地址是否为本机回环地址。
|
*/
|
private boolean isLoopbackHost(String host) {
|
return "127.0.0.1".equals(host) || "localhost".equalsIgnoreCase(host) || "::1".equals(host);
|
}
|
|
/**
|
* 判断两个 host 是否等价。
|
*/
|
private boolean sameHost(String left, String right) {
|
if (left == null || right == null) {
|
return false;
|
}
|
return left.equalsIgnoreCase(right);
|
}
|
|
/**
|
* 关闭 SSE 会话,并异步清理底层连接资源。
|
*/
|
@Override
|
public void close() throws Exception {
|
closed = true;
|
logger.info("AI MCP SSE closing session: mountCode={}, url={}", mount.getMountCode(), mount.getUrl());
|
if (worker != null) {
|
worker.interrupt();
|
}
|
Thread cleanup = new Thread(() -> {
|
try {
|
connection.disconnect();
|
} catch (Exception ignore) {
|
}
|
try {
|
reader.close();
|
} catch (Exception ignore) {
|
}
|
logger.info("AI MCP SSE cleanup finished: mountCode={}, url={}", mount.getMountCode(), mount.getUrl());
|
}, "ai-mcp-sse-cleanup-" + mount.getMountCode());
|
cleanup.setDaemon(true);
|
cleanup.start();
|
}
|
}
|
|
/**
|
* 按挂载配置写入鉴权请求头。
|
*/
|
private void applyAuthHeaders(HttpURLConnection connection, AiMcpMount mount) {
|
if (mount == null || mount.getAuthType() == null || mount.getAuthValue() == null || mount.getAuthValue().trim().isEmpty()) {
|
return;
|
}
|
String authType = mount.getAuthType().trim().toUpperCase();
|
if (AiMcpConstants.AUTH_TYPE_BEARER.equals(authType)) {
|
connection.setRequestProperty("Authorization", "Bearer " + mount.getAuthValue().trim());
|
} else if (AiMcpConstants.AUTH_TYPE_API_KEY.equals(authType)) {
|
connection.setRequestProperty("X-API-Key", mount.getAuthValue().trim());
|
}
|
}
|
|
private static class SseEvent {
|
private final String name;
|
private final String data;
|
|
/**
|
* 封装一条 SSE 事件。
|
*/
|
private SseEvent(String name, String data) {
|
this.name = name;
|
this.data = data;
|
}
|
|
/**
|
* 返回事件名。
|
*/
|
public String getName() {
|
return name;
|
}
|
|
/**
|
* 返回事件数据文本。
|
*/
|
public String getData() {
|
return data;
|
}
|
}
|
}
|