package com.vincent.rsf.server.ai.service.impl;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.vincent.rsf.server.ai.dto.AiObserveStatsDto;
|
import com.vincent.rsf.server.ai.entity.AiCallLog;
|
import com.vincent.rsf.server.ai.entity.AiMcpCallLog;
|
import com.vincent.rsf.server.ai.mapper.AiCallLogMapper;
|
import com.vincent.rsf.server.ai.mapper.AiMcpCallLogMapper;
|
import com.vincent.rsf.server.ai.service.AiCallLogService;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.stereotype.Service;
|
import org.springframework.util.StringUtils;
|
|
import java.util.Date;
|
import java.util.List;
|
import java.util.Objects;
|
import java.util.regex.Pattern;
|
|
@Service("aiCallLogService")
|
@RequiredArgsConstructor
|
public class AiCallLogServiceImpl extends ServiceImpl<AiCallLogMapper, AiCallLog> implements AiCallLogService {
|
|
private static final Pattern SECRET_JSON_PATTERN = Pattern.compile("(?i)(\"(?:apiKey|token|accessToken|refreshToken|password|authorization)\"\\s*:\\s*\")([^\"]+)(\")");
|
private static final Pattern BEARER_PATTERN = Pattern.compile("(?i)(bearer\\s+)([a-z0-9._-]+)");
|
|
private final AiMcpCallLogMapper aiMcpCallLogMapper;
|
|
@Override
|
public AiCallLog startCallLog(String requestId, Long sessionId, Long userId, Long tenantId, String promptCode,
|
String promptName, String model, Integer configuredMcpCount,
|
Integer mountedMcpCount, List<String> mountedMcpNames) {
|
Date now = new Date();
|
AiCallLog callLog = new AiCallLog()
|
.setRequestId(requestId)
|
.setSessionId(sessionId)
|
.setUserId(userId)
|
.setTenantId(tenantId)
|
.setPromptCode(promptCode)
|
.setPromptName(promptName)
|
.setModel(model)
|
.setStatus("RUNNING")
|
.setConfiguredMcpCount(configuredMcpCount)
|
.setMountedMcpCount(mountedMcpCount)
|
.setMountedMcpNames(joinNames(mountedMcpNames))
|
.setToolCallCount(0)
|
.setToolSuccessCount(0)
|
.setToolFailureCount(0)
|
.setDeleted(0)
|
.setCreateTime(now)
|
.setUpdateTime(now);
|
this.save(callLog);
|
return callLog;
|
}
|
|
@Override
|
public void completeCallLog(Long callLogId, String status, Long elapsedMs, Long firstTokenLatencyMs,
|
Integer promptTokens, Integer completionTokens, Integer totalTokens,
|
long toolSuccessCount, long toolFailureCount) {
|
if (callLogId == null) {
|
return;
|
}
|
this.update(new LambdaUpdateWrapper<AiCallLog>()
|
.eq(AiCallLog::getId, callLogId)
|
.set(AiCallLog::getStatus, status)
|
.set(AiCallLog::getElapsedMs, elapsedMs)
|
.set(AiCallLog::getFirstTokenLatencyMs, firstTokenLatencyMs)
|
.set(AiCallLog::getPromptTokens, promptTokens)
|
.set(AiCallLog::getCompletionTokens, completionTokens)
|
.set(AiCallLog::getTotalTokens, totalTokens)
|
.set(AiCallLog::getToolSuccessCount, (int) toolSuccessCount)
|
.set(AiCallLog::getToolFailureCount, (int) toolFailureCount)
|
.set(AiCallLog::getToolCallCount, (int) (toolSuccessCount + toolFailureCount))
|
.set(AiCallLog::getUpdateTime, new Date()));
|
}
|
|
@Override
|
public void failCallLog(Long callLogId, String status, String errorCategory, String errorStage, String errorMessage,
|
Long elapsedMs, Long firstTokenLatencyMs, long toolSuccessCount, long toolFailureCount) {
|
if (callLogId == null) {
|
return;
|
}
|
this.update(new LambdaUpdateWrapper<AiCallLog>()
|
.eq(AiCallLog::getId, callLogId)
|
.set(AiCallLog::getStatus, status)
|
.set(AiCallLog::getErrorCategory, errorCategory)
|
.set(AiCallLog::getErrorStage, errorStage)
|
.set(AiCallLog::getErrorMessage, maskSensitive(errorMessage))
|
.set(AiCallLog::getElapsedMs, elapsedMs)
|
.set(AiCallLog::getFirstTokenLatencyMs, firstTokenLatencyMs)
|
.set(AiCallLog::getToolSuccessCount, (int) toolSuccessCount)
|
.set(AiCallLog::getToolFailureCount, (int) toolFailureCount)
|
.set(AiCallLog::getToolCallCount, (int) (toolSuccessCount + toolFailureCount))
|
.set(AiCallLog::getUpdateTime, new Date()));
|
}
|
|
@Override
|
public void saveMcpCallLog(Long callLogId, String requestId, Long sessionId, String toolCallId, String mountName,
|
String toolName, String status, String inputSummary, String outputSummary,
|
String errorMessage, Long durationMs, Long userId, Long tenantId) {
|
if (callLogId == null) {
|
return;
|
}
|
aiMcpCallLogMapper.insert(new AiMcpCallLog()
|
.setCallLogId(callLogId)
|
.setRequestId(requestId)
|
.setSessionId(sessionId)
|
.setToolCallId(toolCallId)
|
.setMountName(mountName)
|
.setToolName(toolName)
|
.setStatus(status)
|
.setInputSummary(maskSensitive(inputSummary))
|
.setOutputSummary(maskSensitive(outputSummary))
|
.setErrorMessage(maskSensitive(errorMessage))
|
.setDurationMs(durationMs)
|
.setUserId(userId)
|
.setTenantId(tenantId)
|
.setCreateTime(new Date()));
|
}
|
|
@Override
|
public AiObserveStatsDto getObserveStats(Long tenantId) {
|
List<AiCallLog> callLogs = this.list(new LambdaQueryWrapper<AiCallLog>()
|
.eq(AiCallLog::getTenantId, tenantId)
|
.eq(AiCallLog::getDeleted, 0)
|
.orderByDesc(AiCallLog::getId));
|
List<AiMcpCallLog> mcpCallLogs = aiMcpCallLogMapper.selectList(new LambdaQueryWrapper<AiMcpCallLog>()
|
.eq(AiMcpCallLog::getTenantId, tenantId)
|
.orderByDesc(AiMcpCallLog::getId));
|
|
long callCount = callLogs.size();
|
long successCount = callLogs.stream().filter(item -> "COMPLETED".equals(item.getStatus())).count();
|
long failureCount = callLogs.stream().filter(item -> "FAILED".equals(item.getStatus())).count();
|
long totalElapsed = callLogs.stream().map(AiCallLog::getElapsedMs).filter(Objects::nonNull).mapToLong(Long::longValue).sum();
|
long elapsedCount = callLogs.stream().map(AiCallLog::getElapsedMs).filter(Objects::nonNull).count();
|
long totalFirstToken = callLogs.stream().map(AiCallLog::getFirstTokenLatencyMs).filter(Objects::nonNull).mapToLong(Long::longValue).sum();
|
long firstTokenCount = callLogs.stream().map(AiCallLog::getFirstTokenLatencyMs).filter(Objects::nonNull).count();
|
long totalTokens = callLogs.stream().map(AiCallLog::getTotalTokens).filter(Objects::nonNull).mapToLong(Integer::longValue).sum();
|
long tokenCount = callLogs.stream().map(AiCallLog::getTotalTokens).filter(Objects::nonNull).count();
|
|
long toolCallCount = mcpCallLogs.size();
|
long toolSuccessCount = mcpCallLogs.stream().filter(item -> "COMPLETED".equals(item.getStatus())).count();
|
long toolFailureCount = mcpCallLogs.stream().filter(item -> "FAILED".equals(item.getStatus())).count();
|
double toolSuccessRate = toolCallCount == 0 ? 0D : (toolSuccessCount * 100D) / toolCallCount;
|
|
return AiObserveStatsDto.builder()
|
.callCount(callCount)
|
.successCount(successCount)
|
.failureCount(failureCount)
|
.avgElapsedMs(elapsedCount == 0 ? 0L : totalElapsed / elapsedCount)
|
.avgFirstTokenLatencyMs(firstTokenCount == 0 ? 0L : totalFirstToken / firstTokenCount)
|
.totalTokens(totalTokens)
|
.avgTotalTokens(tokenCount == 0 ? 0L : totalTokens / tokenCount)
|
.toolCallCount(toolCallCount)
|
.toolSuccessCount(toolSuccessCount)
|
.toolFailureCount(toolFailureCount)
|
.toolSuccessRate(toolSuccessRate)
|
.build();
|
}
|
|
@Override
|
public List<AiMcpCallLog> listMcpLogs(Long callLogId, Long tenantId) {
|
return aiMcpCallLogMapper.selectList(new LambdaQueryWrapper<AiMcpCallLog>()
|
.eq(AiMcpCallLog::getCallLogId, callLogId)
|
.eq(AiMcpCallLog::getTenantId, tenantId)
|
.orderByDesc(AiMcpCallLog::getId));
|
}
|
|
private String joinNames(List<String> names) {
|
if (names == null || names.isEmpty()) {
|
return "";
|
}
|
return String.join("、", names);
|
}
|
|
private String maskSensitive(String source) {
|
if (!StringUtils.hasText(source)) {
|
return source;
|
}
|
String masked = SECRET_JSON_PATTERN.matcher(source).replaceAll("$1***$3");
|
return BEARER_PATTERN.matcher(masked).replaceAll("$1***");
|
}
|
}
|