package com.zy.ai.service.impl;
|
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
import com.zy.ai.entity.AiChatMessage;
|
import com.zy.ai.entity.AiChatSession;
|
import com.zy.ai.entity.AiPromptTemplate;
|
import com.zy.ai.entity.ChatCompletionRequest;
|
import com.zy.ai.mapper.AiChatMessageMapper;
|
import com.zy.ai.mapper.AiChatSessionMapper;
|
import com.zy.ai.service.AiChatStoreService;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.stereotype.Service;
|
import org.springframework.transaction.annotation.Transactional;
|
|
import java.util.ArrayList;
|
import java.util.Date;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Service
|
@RequiredArgsConstructor
|
public class AiChatStoreServiceImpl implements AiChatStoreService {
|
|
private final AiChatSessionMapper aiChatSessionMapper;
|
private final AiChatMessageMapper aiChatMessageMapper;
|
|
@Override
|
public List<Map<String, Object>> listChats() {
|
List<AiChatSession> sessions = aiChatSessionMapper.selectList(new QueryWrapper<AiChatSession>()
|
.orderByDesc("update_time")
|
.orderByDesc("id"));
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (AiChatSession session : sessions) {
|
if (session == null) {
|
continue;
|
}
|
LinkedHashMap<String, Object> item = new LinkedHashMap<>();
|
item.put("chatId", session.getChatId());
|
item.put("title", session.getTitle());
|
item.put("size", session.getMessageCount());
|
item.put("promptTemplateId", session.getPromptTemplateId());
|
item.put("promptSceneCode", session.getPromptSceneCode());
|
item.put("promptVersion", session.getPromptVersion());
|
item.put("promptName", session.getPromptName());
|
item.put("lastPromptTokens", session.getLastPromptTokens());
|
item.put("lastCompletionTokens", session.getLastCompletionTokens());
|
item.put("lastTotalTokens", session.getLastTotalTokens());
|
item.put("lastLlmCallCount", session.getLastLlmCallCount());
|
item.put("sumPromptTokens", session.getSumPromptTokens());
|
item.put("sumCompletionTokens", session.getSumCompletionTokens());
|
item.put("sumTotalTokens", session.getSumTotalTokens());
|
item.put("askCount", session.getAskCount());
|
item.put("createdAt", toEpochMilli(session.getCreateTime()));
|
item.put("updatedAt", toEpochMilli(session.getUpdateTime()));
|
item.put("lastTokenUpdatedAt", toEpochMilli(session.getLastTokenUpdatedAt()));
|
result.add(item);
|
}
|
return result;
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public boolean deleteChat(String chatId) {
|
if (isBlank(chatId)) {
|
return false;
|
}
|
aiChatMessageMapper.delete(new QueryWrapper<AiChatMessage>().eq("chat_id", chatId));
|
aiChatSessionMapper.delete(new QueryWrapper<AiChatSession>().eq("chat_id", chatId));
|
return true;
|
}
|
|
@Override
|
public List<ChatCompletionRequest.Message> getChatHistory(String chatId) {
|
if (isBlank(chatId)) {
|
return java.util.Collections.emptyList();
|
}
|
List<AiChatMessage> rows = aiChatMessageMapper.selectList(new QueryWrapper<AiChatMessage>()
|
.eq("chat_id", chatId)
|
.orderByAsc("seq_no")
|
.orderByAsc("id"));
|
List<ChatCompletionRequest.Message> result = new ArrayList<>(rows.size());
|
for (AiChatMessage row : rows) {
|
if (row == null) {
|
continue;
|
}
|
ChatCompletionRequest.Message message = new ChatCompletionRequest.Message();
|
message.setRole(row.getRole());
|
message.setContent(row.getContent());
|
message.setReasoningContent(row.getReasoningContent());
|
result.add(message);
|
}
|
return result;
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public void saveConversation(String chatId,
|
String title,
|
ChatCompletionRequest.Message userMessage,
|
ChatCompletionRequest.Message assistantMessage,
|
AiPromptTemplate promptTemplate,
|
long promptTokens,
|
long completionTokens,
|
long totalTokens,
|
int llmCallCount) {
|
if (isBlank(chatId) || userMessage == null || assistantMessage == null) {
|
return;
|
}
|
synchronized (("ai_chat_store_" + chatId).intern()) {
|
AiChatSession session = aiChatSessionMapper.selectOne(new QueryWrapper<AiChatSession>()
|
.eq("chat_id", chatId)
|
.last("limit 1"));
|
Date now = new Date();
|
int nextSeq = 1;
|
if (session == null) {
|
session = new AiChatSession();
|
session.setChatId(chatId);
|
session.setCreateTime(now);
|
session.setMessageCount(0);
|
session.setSumPromptTokens(0L);
|
session.setSumCompletionTokens(0L);
|
session.setSumTotalTokens(0L);
|
session.setAskCount(0L);
|
} else {
|
Integer maxSeq = maxSeqNo(chatId);
|
nextSeq = maxSeq == null ? 1 : (maxSeq + 1);
|
}
|
|
session.setTitle(cut(title, 255));
|
if (promptTemplate != null) {
|
session.setPromptTemplateId(promptTemplate.getId());
|
session.setPromptSceneCode(cut(promptTemplate.getSceneCode(), 64));
|
session.setPromptVersion(promptTemplate.getVersion());
|
session.setPromptName(cut(promptTemplate.getName(), 255));
|
} else {
|
session.setPromptTemplateId(null);
|
session.setPromptSceneCode(null);
|
session.setPromptVersion(null);
|
session.setPromptName(null);
|
}
|
session.setLastPromptTokens(promptTokens);
|
session.setLastCompletionTokens(completionTokens);
|
session.setLastTotalTokens(totalTokens);
|
session.setLastLlmCallCount(llmCallCount);
|
session.setLastTokenUpdatedAt(now);
|
session.setMessageCount((session.getMessageCount() == null ? 0 : session.getMessageCount()) + 2);
|
session.setSumPromptTokens((session.getSumPromptTokens() == null ? 0L : session.getSumPromptTokens()) + promptTokens);
|
session.setSumCompletionTokens((session.getSumCompletionTokens() == null ? 0L : session.getSumCompletionTokens()) + completionTokens);
|
session.setSumTotalTokens((session.getSumTotalTokens() == null ? 0L : session.getSumTotalTokens()) + totalTokens);
|
session.setAskCount((session.getAskCount() == null ? 0L : session.getAskCount()) + 1);
|
|
if (session.getId() == null) {
|
aiChatSessionMapper.insert(session);
|
} else {
|
aiChatSessionMapper.updateById(session);
|
}
|
|
insertMessage(chatId, nextSeq, userMessage, now);
|
insertMessage(chatId, nextSeq + 1, assistantMessage, now);
|
}
|
}
|
|
private void insertMessage(String chatId, int seqNo, ChatCompletionRequest.Message source, Date now) {
|
AiChatMessage row = new AiChatMessage();
|
row.setChatId(chatId);
|
row.setSeqNo(seqNo);
|
row.setRole(cut(source.getRole(), 32));
|
row.setContent(source.getContent());
|
row.setReasoningContent(source.getReasoningContent());
|
row.setCreateTime(now);
|
aiChatMessageMapper.insert(row);
|
}
|
|
private Integer maxSeqNo(String chatId) {
|
AiChatMessage last = aiChatMessageMapper.selectOne(new QueryWrapper<AiChatMessage>()
|
.eq("chat_id", chatId)
|
.orderByDesc("seq_no")
|
.orderByDesc("id")
|
.last("limit 1"));
|
return last == null ? null : last.getSeqNo();
|
}
|
|
private long toEpochMilli(Date date) {
|
return date == null ? 0L : date.getTime();
|
}
|
|
private boolean isBlank(String text) {
|
return text == null || text.trim().isEmpty();
|
}
|
|
private String cut(String text, int maxLen) {
|
if (text == null) {
|
return null;
|
}
|
return text.length() > maxLen ? text.substring(0, maxLen) : text;
|
}
|
}
|