zhou zhou
8 小时以前 1d0ab9996661fdc66037870d4b98037f2dfa079a
#AI.工具调用可视化
1个文件已添加
2个文件已修改
286 ■■■■■ 已修改文件
rsf-admin/src/layout/AiChatDrawer.jsx 127 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
rsf-server/src/main/java/com/vincent/rsf/server/ai/dto/AiChatToolEventDto.java 29 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java 130 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
rsf-admin/src/layout/AiChatDrawer.jsx
@@ -6,6 +6,7 @@
    Box,
    Button,
    Chip,
    Collapse,
    Dialog,
    DialogActions,
    DialogContent,
@@ -36,6 +37,8 @@
import PushPinOutlinedIcon from "@mui/icons-material/PushPinOutlined";
import PushPinIcon from "@mui/icons-material/PushPin";
import SearchOutlinedIcon from "@mui/icons-material/SearchOutlined";
import ExpandMoreOutlinedIcon from "@mui/icons-material/ExpandMoreOutlined";
import ExpandLessOutlinedIcon from "@mui/icons-material/ExpandLessOutlined";
import { clearAiSessionMemory, getAiRuntime, getAiSessions, pinAiSession, removeAiSession, renameAiSession, retainAiSessionLatestRound, streamAiChat } from "@/api/ai/chat";
// Prompt code constant ("home.default"). NOTE(review): presumably the prompt used when none is chosen explicitly —
// its consumers are outside this diff view; confirm before relying on that.
const DEFAULT_PROMPT_CODE = "home.default";
@@ -56,6 +59,8 @@
    const [sessions, setSessions] = useState([]);
    const [persistedMessages, setPersistedMessages] = useState([]);
    const [messages, setMessages] = useState([]);
    const [toolEvents, setToolEvents] = useState([]);
    const [expandedToolIds, setExpandedToolIds] = useState([]);
    const [input, setInput] = useState("");
    const [loadingRuntime, setLoadingRuntime] = useState(false);
    const [streaming, setStreaming] = useState(false);
@@ -91,6 +96,8 @@
    }, []);
    const initializeDrawer = async (targetSessionId = null) => {
        setToolEvents([]);
        setExpandedToolIds([]);
        await Promise.all([
            loadRuntime(targetSessionId),
            loadSessions(sessionKeyword),
@@ -132,6 +139,8 @@
        setSessionId(null);
        setPersistedMessages([]);
        setMessages([]);
        setToolEvents([]);
        setExpandedToolIds([]);
        setUsage(null);
        setDrawerError("");
    };
@@ -151,6 +160,8 @@
            return;
        }
        setUsage(null);
        setToolEvents([]);
        setExpandedToolIds([]);
        await loadRuntime(targetSessionId);
    };
@@ -288,6 +299,32 @@
        return next;
    };
    const upsertToolEvent = (payload) => {
        if (!payload?.toolCallId) {
            return;
        }
        setToolEvents((prev) => {
            const index = prev.findIndex((item) => item.toolCallId === payload.toolCallId);
            if (index < 0) {
                return [...prev, payload];
            }
            const next = [...prev];
            next[index] = { ...next[index], ...payload };
            return next;
        });
    };
    const toggleToolEventExpanded = (toolCallId) => {
        if (!toolCallId) {
            return;
        }
        setExpandedToolIds((prev) => (
            prev.includes(toolCallId)
                ? prev.filter((item) => item !== toolCallId)
                : [...prev, toolCallId]
        ));
    };
    const handleSend = async () => {
        const content = input.trim();
        if (!content || streaming) {
@@ -298,6 +335,8 @@
        setInput("");
        setUsage(null);
        setDrawerError("");
        setToolEvents([]);
        setExpandedToolIds([]);
        setMessages(ensureAssistantPlaceholder(nextMessages));
        setStreaming(true);
@@ -329,6 +368,9 @@
                        }
                        if (eventName === "delta") {
                            appendAssistantDelta(payload?.content || "");
                        }
                        if (eventName === "tool_start" || eventName === "tool_result" || eventName === "tool_error") {
                            upsertToolEvent(payload);
                        }
                        if (eventName === "done") {
                            setUsage(payload);
@@ -382,7 +424,7 @@
                "& .MuiDrawer-paper": {
                    top: 0,
                    height: "100vh",
                    width: { xs: "100vw", md: "50vw" },
                    width: { xs: "100vw", md: "70vw" },
                },
            }}
        >
@@ -502,6 +544,89 @@
                        </Box>
                    </Box>
                    <Box
                        width={{ xs: "100%", md: 280 }}
                        borderRight={{ xs: "none", md: "1px solid rgba(224, 224, 224, 1)" }}
                        borderBottom={{ xs: "1px solid rgba(224, 224, 224, 1)", md: "none" }}
                        display="flex"
                        flexDirection="column"
                        minHeight={0}
                    >
                        <Box px={2} py={1.5} display="flex" flexDirection="column" minHeight={0}>
                            <Typography variant="subtitle2" mb={1}>
                                工具调用轨迹
                            </Typography>
                            <Paper variant="outlined" sx={{ flex: 1, minHeight: { xs: 140, md: 0 }, overflow: "hidden", bgcolor: "grey.50" }}>
                                {!toolEvents.length ? (
                                    <Box px={1.5} py={1.25}>
                                        <Typography variant="body2" color="text.secondary">
                                            当前轮未触发工具调用
                                        </Typography>
                                    </Box>
                                ) : (
                                    <Stack spacing={1} sx={{ p: 1.25, maxHeight: { xs: 220, md: "calc(100vh - 180px)" }, overflow: "auto" }}>
                                        {toolEvents.map((item) => (
                                            <Paper
                                                key={item.toolCallId}
                                                variant="outlined"
                                                sx={{
                                                    p: 1.25,
                                                    bgcolor: item.status === "FAILED" ? "error.lighter" : "common.white",
                                                    borderColor: item.status === "FAILED" ? "error.light" : "divider",
                                                }}
                                            >
                                                <Stack direction="row" spacing={1} alignItems="center" flexWrap="wrap" useFlexGap>
                                                    <Typography variant="body2" fontWeight={700}>
                                                        {item.toolName || "未知工具"}
                                                    </Typography>
                                                    <Chip
                                                        size="small"
                                                        color={item.status === "FAILED" ? "error" : item.status === "COMPLETED" ? "success" : "info"}
                                                        label={item.status === "FAILED" ? "失败" : item.status === "COMPLETED" ? "完成" : "执行中"}
                                                    />
                                                    {item.durationMs != null && (
                                                        <Typography variant="caption" color="text.secondary">
                                                            {item.durationMs} ms
                                                        </Typography>
                                                    )}
                                                    {(item.inputSummary || item.outputSummary || item.errorMessage) && (
                                                        <Button
                                                            size="small"
                                                            onClick={() => toggleToolEventExpanded(item.toolCallId)}
                                                            endIcon={expandedToolIds.includes(item.toolCallId)
                                                                ? <ExpandLessOutlinedIcon fontSize="small" />
                                                                : <ExpandMoreOutlinedIcon fontSize="small" />}
                                                            sx={{ ml: "auto", minWidth: "auto", px: 0.5 }}
                                                        >
                                                            {expandedToolIds.includes(item.toolCallId) ? "收起详情" : "查看详情"}
                                                        </Button>
                                                    )}
                                                </Stack>
                                                <Collapse in={expandedToolIds.includes(item.toolCallId)} timeout="auto" unmountOnExit>
                                                    {!!item.inputSummary && (
                                                        <Typography variant="caption" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}>
                                                            入参: {item.inputSummary}
                                                        </Typography>
                                                    )}
                                                    {!!item.outputSummary && (
                                                        <Typography variant="caption" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}>
                                                            结果摘要: {item.outputSummary}
                                                        </Typography>
                                                    )}
                                                    {!!item.errorMessage && (
                                                        <Typography variant="caption" color="error.main" display="block" sx={{ mt: 0.75, whiteSpace: "pre-wrap" }}>
                                                            错误: {item.errorMessage}
                                                        </Typography>
                                                    )}
                                                </Collapse>
                                            </Paper>
                                        ))}
                                    </Stack>
                                )}
                            </Paper>
                        </Box>
                    </Box>
                    <Box flex={1} display="flex" flexDirection="column" minHeight={0}>
                        <Box px={2} py={1.5}>
                            <Stack direction="row" spacing={1} flexWrap="wrap" useFlexGap>
rsf-server/src/main/java/com/vincent/rsf/server/ai/dto/AiChatToolEventDto.java
New file
@@ -0,0 +1,29 @@
package com.vincent.rsf.server.ai.dto;
import lombok.Builder;
import lombok.Data;
/**
 * SSE payload describing one tool invocation made while answering an AI chat request.
 * <p>
 * The chat service emits it under the event names {@code tool_start}, {@code tool_result} and
 * {@code tool_error}. The admin drawer merges successive events that share the same {@link #toolCallId}
 * into a single trace entry.
 * <p>
 * NOTE(review): {@code @Builder} together with {@code @Data} generates only a package-private all-args
 * constructor and no no-args constructor. Jackson can serialize this type, but cannot deserialize it as-is.
 */
@Data
@Builder
public class AiChatToolEventDto {
    /** Id of the chat request during which the tool was invoked. */
    private String requestId;
    /** Id of the chat session the request belongs to. */
    private Long sessionId;
    /** Id shared by all events of one invocation, built as {@code <requestId>-tool-<sequence>}. */
    private String toolCallId;
    /** Tool name taken from the tool definition, or {@code "unknown"} when no definition is available. */
    private String toolName;
    /** Lifecycle state: {@code STARTED}, {@code COMPLETED} or {@code FAILED}. */
    private String status;
    /** Single-line, length-limited summary of the tool input; null when the input is blank. */
    private String inputSummary;
    /** Single-line, length-limited summary of the tool output; only set on {@code COMPLETED} events. */
    private String outputSummary;
    /** Description of the failure; only set on {@code FAILED} events. */
    private String errorMessage;
    /** Wall-clock duration of the invocation in milliseconds; not set on {@code STARTED} events. */
    private Long durationMs;
    /** Epoch milliseconds ({@code System.currentTimeMillis()}) at which the event was produced. */
    private Long timestamp;
}
rsf-server/src/main/java/com/vincent/rsf/server/ai/service/impl/AiChatServiceImpl.java
@@ -14,6 +14,7 @@
import com.vincent.rsf.server.ai.dto.AiChatSessionDto;
import com.vincent.rsf.server.ai.dto.AiChatSessionPinRequest;
import com.vincent.rsf.server.ai.dto.AiChatSessionRenameRequest;
import com.vincent.rsf.server.ai.dto.AiChatToolEventDto;
import com.vincent.rsf.server.ai.dto.AiResolvedConfig;
import com.vincent.rsf.server.ai.entity.AiParam;
import com.vincent.rsf.server.ai.entity.AiPrompt;
@@ -34,6 +35,7 @@
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
@@ -66,6 +68,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicLong;
@Slf4j
@Service
@@ -144,6 +147,7 @@
        String requestId = request.getRequestId();
        long startedAt = System.currentTimeMillis();
        AtomicReference<Long> firstTokenAtRef = new AtomicReference<>();
        AtomicLong toolCallSequence = new AtomicLong(0);
        Long sessionId = request.getSessionId();
        String model = null;
        try {
@@ -182,9 +186,13 @@
                log.info("AI chat started, requestId={}, userId={}, tenantId={}, sessionId={}, model={}",
                        requestId, userId, tenantId, session.getId(), resolvedModel);
                ToolCallback[] observableToolCallbacks = wrapToolCallbacks(
                        runtime.getToolCallbacks(), emitter, requestId, session.getId(), toolCallSequence
                );
                Prompt prompt = new Prompt(
                        buildPromptMessages(memory, mergedMessages, config.getPrompt(), request.getMetadata()),
                        buildChatOptions(config.getAiParam(), runtime.getToolCallbacks(), userId, request.getMetadata())
                        buildChatOptions(config.getAiParam(), observableToolCallbacks, userId, tenantId,
                                requestId, session.getId(), request.getMetadata())
                );
                OpenAiChatModel chatModel = createChatModel(config.getAiParam());
                if (Boolean.FALSE.equals(config.getAiParam().getStreamingEnabled())) {
@@ -392,7 +400,8 @@
                .build();
    }
    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Map<String, Object> metadata) {
    private OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
                                               String requestId, Long sessionId, Map<String, Object> metadata) {
        if (userId == null) {
            throw buildAiException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null);
        }
@@ -404,16 +413,41 @@
                .streamUsage(true)
                .user(String.valueOf(userId));
        if (!Cools.isEmpty(toolCallbacks)) {
            builder.toolCallbacks(Arrays.asList(toolCallbacks));
            builder.toolCallbacks(Arrays.stream(toolCallbacks).toList());
        }
        Map<String, Object> toolContext = new LinkedHashMap<>();
        toolContext.put("userId", userId);
        toolContext.put("tenantId", tenantId);
        toolContext.put("requestId", requestId);
        toolContext.put("sessionId", sessionId);
        Map<String, String> metadataMap = new LinkedHashMap<>();
        if (metadata != null) {
            metadata.forEach((key, value) -> metadataMap.put(key, value == null ? "" : String.valueOf(value)));
            metadata.forEach((key, value) -> {
                String normalized = value == null ? "" : String.valueOf(value);
                metadataMap.put(key, normalized);
                toolContext.put(key, normalized);
            });
        }
        builder.toolContext(toolContext);
        if (!metadataMap.isEmpty()) {
            builder.metadata(metadataMap);
        }
        return builder.build();
    }
    private ToolCallback[] wrapToolCallbacks(ToolCallback[] toolCallbacks, SseEmitter emitter, String requestId,
                                             Long sessionId, AtomicLong toolCallSequence) {
        if (Cools.isEmpty(toolCallbacks)) {
            return toolCallbacks;
        }
        List<ToolCallback> wrappedCallbacks = new ArrayList<>();
        for (ToolCallback callback : toolCallbacks) {
            if (callback == null) {
                continue;
            }
            wrappedCallbacks.add(new ObservableToolCallback(callback, emitter, requestId, sessionId, toolCallSequence));
        }
        return wrappedCallbacks.toArray(new ToolCallback[0]);
    }
    private List<Message> buildPromptMessages(AiChatMemoryDto memory, List<AiChatMessageDto> sourceMessages, AiPrompt aiPrompt, Map<String, Object> metadata) {
@@ -516,6 +550,17 @@
        return response.getResult().getOutput().getText();
    }
    private String summarizeToolPayload(String content, int maxLength) {
        if (!StringUtils.hasText(content)) {
            return null;
        }
        String normalized = content.trim()
                .replace("\r", " ")
                .replace("\n", " ")
                .replaceAll("\\s+", " ");
        return normalized.length() > maxLength ? normalized.substring(0, maxLength) : normalized;
    }
    private void emitDone(SseEmitter emitter, String requestId, ChatResponseMetadata metadata, String fallbackModel, Long sessionId, long startedAt, Long firstTokenAt) {
        Usage usage = metadata == null ? null : metadata.getUsage();
        emitStrict(emitter, "done", AiChatDoneDto.builder()
@@ -584,4 +629,81 @@
        }
        return false;
    }
    private class ObservableToolCallback implements ToolCallback {
        private final ToolCallback delegate;
        private final SseEmitter emitter;
        private final String requestId;
        private final Long sessionId;
        private final AtomicLong toolCallSequence;
        private ObservableToolCallback(ToolCallback delegate, SseEmitter emitter, String requestId,
                                       Long sessionId, AtomicLong toolCallSequence) {
            this.delegate = delegate;
            this.emitter = emitter;
            this.requestId = requestId;
            this.sessionId = sessionId;
            this.toolCallSequence = toolCallSequence;
        }
        @Override
        public org.springframework.ai.tool.definition.ToolDefinition getToolDefinition() {
            return delegate.getToolDefinition();
        }
        @Override
        public org.springframework.ai.tool.metadata.ToolMetadata getToolMetadata() {
            return delegate.getToolMetadata();
        }
        @Override
        public String call(String toolInput) {
            return call(toolInput, null);
        }
        @Override
        public String call(String toolInput, ToolContext toolContext) {
            String toolName = delegate.getToolDefinition() == null ? "unknown" : delegate.getToolDefinition().name();
            String toolCallId = requestId + "-tool-" + toolCallSequence.incrementAndGet();
            long startedAt = System.currentTimeMillis();
            emitSafely(emitter, "tool_start", AiChatToolEventDto.builder()
                    .requestId(requestId)
                    .sessionId(sessionId)
                    .toolCallId(toolCallId)
                    .toolName(toolName)
                    .status("STARTED")
                    .inputSummary(summarizeToolPayload(toolInput, 400))
                    .timestamp(startedAt)
                    .build());
            try {
                String output = toolContext == null ? delegate.call(toolInput) : delegate.call(toolInput, toolContext);
                emitSafely(emitter, "tool_result", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .status("COMPLETED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .outputSummary(summarizeToolPayload(output, 600))
                        .durationMs(System.currentTimeMillis() - startedAt)
                        .timestamp(System.currentTimeMillis())
                        .build());
                return output;
            } catch (RuntimeException e) {
                emitSafely(emitter, "tool_error", AiChatToolEventDto.builder()
                        .requestId(requestId)
                        .sessionId(sessionId)
                        .toolCallId(toolCallId)
                        .toolName(toolName)
                        .status("FAILED")
                        .inputSummary(summarizeToolPayload(toolInput, 400))
                        .errorMessage(e.getMessage())
                        .durationMs(System.currentTimeMillis() - startedAt)
                        .timestamp(System.currentTimeMillis())
                        .build());
                throw e;
            }
        }
    }
}