package com.zy.ai.service.impl;
|
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.zy.ai.entity.AiPromptBlock;
|
import com.zy.ai.entity.AiPromptTemplate;
|
import com.zy.ai.enums.AiPromptBlockType;
|
import com.zy.ai.enums.AiPromptScene;
|
import com.zy.ai.mapper.AiPromptBlockMapper;
|
import com.zy.ai.mapper.AiPromptTemplateMapper;
|
import com.zy.ai.service.AiPromptComposerService;
|
import com.zy.ai.service.AiPromptTemplateService;
|
import com.zy.ai.utils.AiPromptUtils;
|
import lombok.RequiredArgsConstructor;
|
import lombok.extern.slf4j.Slf4j;
|
import org.springframework.stereotype.Service;
|
import org.springframework.transaction.annotation.Transactional;
|
|
import java.util.ArrayList;
|
import java.util.Collections;
|
import java.util.Comparator;
|
import java.util.HashMap;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
|
@Slf4j
|
@Service("aiPromptTemplateService")
|
@RequiredArgsConstructor
|
public class AiPromptTemplateServiceImpl extends ServiceImpl<AiPromptTemplateMapper, AiPromptTemplate> implements AiPromptTemplateService {
|
|
private static final Comparator<AiPromptBlock> BLOCK_SORT = (a, b) -> {
|
int sa = a != null && a.getSortNo() != null ? a.getSortNo() : Integer.MAX_VALUE;
|
int sb = b != null && b.getSortNo() != null ? b.getSortNo() : Integer.MAX_VALUE;
|
if (sa != sb) {
|
return sa - sb;
|
}
|
long ia = a != null && a.getId() != null ? a.getId() : Long.MAX_VALUE;
|
long ib = b != null && b.getId() != null ? b.getId() : Long.MAX_VALUE;
|
return Long.compare(ia, ib);
|
};
|
|
private final AiPromptUtils aiPromptUtils;
|
private final AiPromptBlockMapper aiPromptBlockMapper;
|
private final AiPromptComposerService aiPromptComposerService;
|
|
@Override
|
public AiPromptTemplate resolvePublished(String sceneCode) {
|
AiPromptScene scene = requireScene(sceneCode);
|
AiPromptTemplate prompt = findPublished(scene.getCode());
|
if (prompt == null) {
|
synchronized (("ai_prompt_scene_init_" + scene.getCode()).intern()) {
|
prompt = findPublished(scene.getCode());
|
if (prompt == null && findLatest(scene.getCode()) == null) {
|
prompt = ensurePublishedScene(scene);
|
}
|
}
|
}
|
if (prompt == null) {
|
throw new IllegalStateException("当前场景没有已发布 Prompt,sceneCode=" + scene.getCode());
|
}
|
return enrichTemplate(prompt);
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public AiPromptTemplate savePrompt(AiPromptTemplate template, Long operatorUserId) {
|
if (template == null) {
|
throw new IllegalArgumentException("Prompt 不能为空");
|
}
|
AiPromptScene scene = requireScene(template.getSceneCode());
|
String compiled = buildCompiledPrompt(template);
|
if (compiled.isEmpty()) {
|
throw new IllegalArgumentException("Prompt 分段内容不能为空");
|
}
|
|
if (template.getId() == null) {
|
AiPromptTemplate entity = new AiPromptTemplate();
|
int version = nextVersion(scene.getCode());
|
entity.setName(defaultName(scene, version, template.getName()));
|
entity.setSceneCode(scene.getCode());
|
entity.setVersion(version);
|
entity.setContent(compiled);
|
entity.setStatus(normalizeStatus(template.getStatus()));
|
entity.setPublished((short) 0);
|
entity.setCreatedBy(operatorUserId);
|
entity.setMemo(trim(template.getMemo()));
|
this.save(entity);
|
upsertBlocks(entity.getId(), extractBlockContentMap(template));
|
return entity;
|
}
|
|
AiPromptTemplate db = this.getById(template.getId());
|
if (db == null) {
|
throw new IllegalArgumentException("Prompt 不存在");
|
}
|
if (!scene.getCode().equals(db.getSceneCode())) {
|
throw new IllegalArgumentException("不允许修改 Prompt 所属场景");
|
}
|
if (Short.valueOf((short) 1).equals(db.getPublished())) {
|
throw new IllegalArgumentException("已发布 Prompt 不允许直接修改,请先取消发布后再保存");
|
}
|
|
db.setName(defaultName(scene, db.getVersion() == null ? 1 : db.getVersion(), template.getName()));
|
db.setContent(compiled);
|
db.setStatus(normalizeStatus(template.getStatus()));
|
db.setMemo(trim(template.getMemo()));
|
this.updateById(db);
|
upsertBlocks(db.getId(), extractBlockContentMap(template));
|
return db;
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public AiPromptTemplate publishPrompt(Long id, Long operatorUserId) {
|
if (id == null) {
|
throw new IllegalArgumentException("id 不能为空");
|
}
|
AiPromptTemplate db = this.getById(id);
|
if (db == null) {
|
throw new IllegalArgumentException("Prompt 不存在");
|
}
|
db = enrichTemplate(db);
|
String compiled = buildCompiledPrompt(db);
|
if (compiled.isEmpty()) {
|
throw new IllegalArgumentException("Prompt 内容不能为空");
|
}
|
|
UpdateWrapper<AiPromptTemplate> clearWrapper = new UpdateWrapper<>();
|
clearWrapper.eq("scene_code", db.getSceneCode()).set("published", 0);
|
this.update(clearWrapper);
|
|
db.setPublished((short) 1);
|
db.setStatus((short) 1);
|
db.setPublishedBy(operatorUserId);
|
db.setPublishedTime(new java.util.Date());
|
db.setContent(compiled);
|
if (db.getName() == null || db.getName().trim().isEmpty()) {
|
AiPromptScene scene = requireScene(db.getSceneCode());
|
db.setName(defaultName(scene, db.getVersion(), null));
|
}
|
this.updateById(db);
|
return db;
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public AiPromptTemplate cancelPublish(Long id, Long operatorUserId) {
|
if (id == null) {
|
throw new IllegalArgumentException("id 不能为空");
|
}
|
AiPromptTemplate db = this.getById(id);
|
if (db == null) {
|
throw new IllegalArgumentException("Prompt 不存在");
|
}
|
if (!Short.valueOf((short) 1).equals(db.getPublished())) {
|
throw new IllegalArgumentException("当前 Prompt 不是已发布状态");
|
}
|
db.setPublished((short) 0);
|
db.setPublishedBy(operatorUserId);
|
db.setPublishedTime(null);
|
this.updateById(db);
|
return db;
|
}
|
|
@Override
|
public AiPromptTemplate enrichTemplate(AiPromptTemplate template) {
|
if (template == null) {
|
return null;
|
}
|
if (template.getId() == null) {
|
template.setContent(buildCompiledPrompt(template));
|
return template;
|
}
|
|
List<AiPromptBlock> blocks = loadBlocks(template.getId());
|
if (blocks.isEmpty()) {
|
migrateLegacyTemplateBlocks(template);
|
blocks = loadBlocks(template.getId());
|
}
|
applyBlocks(template, blocks);
|
return template;
|
}
|
|
@Override
|
public List<AiPromptTemplate> enrichTemplates(List<AiPromptTemplate> templates) {
|
if (templates == null || templates.isEmpty()) {
|
return templates == null ? Collections.emptyList() : templates;
|
}
|
|
List<Long> templateIds = new ArrayList<>();
|
for (AiPromptTemplate template : templates) {
|
if (template != null && template.getId() != null) {
|
templateIds.add(template.getId());
|
}
|
}
|
Map<Long, List<AiPromptBlock>> blockMap = groupBlocks(loadBlocks(templateIds));
|
boolean migrated = false;
|
for (AiPromptTemplate template : templates) {
|
if (template == null || template.getId() == null) {
|
continue;
|
}
|
List<AiPromptBlock> blocks = blockMap.get(template.getId());
|
if (blocks == null || blocks.isEmpty()) {
|
migrateLegacyTemplateBlocks(template);
|
migrated = true;
|
}
|
}
|
if (migrated) {
|
blockMap = groupBlocks(loadBlocks(templateIds));
|
}
|
for (AiPromptTemplate template : templates) {
|
if (template == null) {
|
continue;
|
}
|
if (template.getId() == null) {
|
template.setContent(buildCompiledPrompt(template));
|
continue;
|
}
|
applyBlocks(template, blockMap.get(template.getId()));
|
}
|
return templates;
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public boolean deletePrompt(Long id) {
|
if (id == null) {
|
return false;
|
}
|
AiPromptTemplate db = this.getById(id);
|
if (db == null) {
|
return false;
|
}
|
if (Short.valueOf((short) 1).equals(db.getPublished())) {
|
throw new IllegalArgumentException("已发布 Prompt 不允许删除,请先取消发布");
|
}
|
aiPromptBlockMapper.delete(new QueryWrapper<AiPromptBlock>().eq("template_id", id));
|
return this.removeById(id);
|
}
|
|
@Override
|
@Transactional(rollbackFor = Exception.class)
|
public int initDefaultsIfMissing() {
|
int changed = 0;
|
for (AiPromptScene scene : AiPromptScene.values()) {
|
AiPromptTemplate latest = findLatest(scene.getCode());
|
if (latest == null) {
|
ensurePublishedScene(scene);
|
changed++;
|
continue;
|
}
|
List<AiPromptBlock> blocks = loadBlocks(latest.getId());
|
if (blocks.isEmpty()) {
|
migrateLegacyTemplateBlocks(latest);
|
changed++;
|
}
|
}
|
return changed;
|
}
|
|
@Override
|
public List<Map<String, Object>> listSupportedScenes() {
|
List<Map<String, Object>> result = new ArrayList<>();
|
for (AiPromptScene scene : AiPromptScene.values()) {
|
HashMap<String, Object> item = new HashMap<>();
|
item.put("code", scene.getCode());
|
item.put("label", scene.getLabel());
|
result.add(item);
|
}
|
return result;
|
}
|
|
private AiPromptTemplate ensurePublishedScene(AiPromptScene scene) {
|
LinkedHashMap<AiPromptBlockType, String> blocks = aiPromptUtils.getDefaultPromptBlocks(scene);
|
AiPromptTemplate seed = new AiPromptTemplate();
|
seed.setName(defaultName(scene, 1, null));
|
seed.setSceneCode(scene.getCode());
|
seed.setVersion(1);
|
seed.setStatus((short) 1);
|
seed.setPublished((short) 1);
|
seed.setPublishedTime(new java.util.Date());
|
seed.setMemo("系统初始化默认 Prompt");
|
applyBlockFields(seed, blocks);
|
seed.setContent(buildCompiledPrompt(seed));
|
this.save(seed);
|
upsertBlocks(seed.getId(), blocks);
|
log.info("Initialized default AI prompt blocks, sceneCode={}, version={}", scene.getCode(), seed.getVersion());
|
return seed;
|
}
|
|
private void migrateLegacyTemplateBlocks(AiPromptTemplate template) {
|
if (template == null || template.getId() == null) {
|
return;
|
}
|
AiPromptScene scene = requireScene(template.getSceneCode());
|
LinkedHashMap<AiPromptBlockType, String> blocks = aiPromptUtils.resolveStoredOrDefaultBlocks(scene, template.getContent());
|
upsertBlocks(template.getId(), blocks);
|
applyBlockFields(template, blocks);
|
template.setContent(buildCompiledPrompt(template));
|
this.updateById(template);
|
}
|
|
private void applyBlocks(AiPromptTemplate template, List<AiPromptBlock> blocks) {
|
List<AiPromptBlock> ordered = blocks == null ? new ArrayList<>() : new ArrayList<>(blocks);
|
ordered.sort(BLOCK_SORT);
|
template.setBlocks(ordered);
|
|
LinkedHashMap<AiPromptBlockType, String> blockContent = new LinkedHashMap<>();
|
for (AiPromptBlockType type : AiPromptBlockType.values()) {
|
blockContent.put(type, "");
|
}
|
for (AiPromptBlock block : ordered) {
|
AiPromptBlockType type = AiPromptBlockType.ofCode(block.getBlockType());
|
if (type == null) {
|
continue;
|
}
|
blockContent.put(type, block.getContent());
|
}
|
applyBlockFields(template, blockContent);
|
template.setContent(buildCompiledPrompt(template));
|
}
|
|
private void applyBlockFields(AiPromptTemplate template, LinkedHashMap<AiPromptBlockType, String> blockContent) {
|
template.setBasePolicy(valueOf(blockContent, AiPromptBlockType.BASE_POLICY));
|
template.setToolPolicy(valueOf(blockContent, AiPromptBlockType.TOOL_POLICY));
|
template.setOutputContract(valueOf(blockContent, AiPromptBlockType.OUTPUT_CONTRACT));
|
template.setScenePlaybook(valueOf(blockContent, AiPromptBlockType.SCENE_PLAYBOOK));
|
}
|
|
private String valueOf(LinkedHashMap<AiPromptBlockType, String> blockContent, AiPromptBlockType type) {
|
if (blockContent == null) {
|
return "";
|
}
|
String value = blockContent.get(type);
|
return value == null ? "" : value;
|
}
|
|
private String buildCompiledPrompt(AiPromptTemplate template) {
|
String compiled = aiPromptComposerService.compose(template);
|
return compiled == null ? "" : compiled.trim();
|
}
|
|
private LinkedHashMap<AiPromptBlockType, String> extractBlockContentMap(AiPromptTemplate template) {
|
LinkedHashMap<AiPromptBlockType, String> blocks = new LinkedHashMap<>();
|
blocks.put(AiPromptBlockType.BASE_POLICY, defaultString(template.getBasePolicy()));
|
blocks.put(AiPromptBlockType.TOOL_POLICY, defaultString(template.getToolPolicy()));
|
blocks.put(AiPromptBlockType.OUTPUT_CONTRACT, defaultString(template.getOutputContract()));
|
blocks.put(AiPromptBlockType.SCENE_PLAYBOOK, defaultString(template.getScenePlaybook()));
|
return blocks;
|
}
|
|
private void upsertBlocks(Long templateId, LinkedHashMap<AiPromptBlockType, String> blockContent) {
|
List<AiPromptBlock> existingBlocks = loadBlocks(templateId);
|
HashMap<String, AiPromptBlock> existingMap = new HashMap<>();
|
for (AiPromptBlock block : existingBlocks) {
|
existingMap.put(block.getBlockType(), block);
|
}
|
|
for (AiPromptBlockType type : AiPromptBlockType.values()) {
|
AiPromptBlock block = existingMap.get(type.getCode());
|
if (block == null) {
|
block = new AiPromptBlock();
|
block.setTemplateId(templateId);
|
block.setBlockType(type.getCode());
|
block.setSortNo(type.getSort());
|
block.setStatus((short) 1);
|
block.setContent(defaultString(blockContent.get(type)));
|
aiPromptBlockMapper.insert(block);
|
continue;
|
}
|
block.setSortNo(type.getSort());
|
block.setStatus((short) 1);
|
block.setContent(defaultString(blockContent.get(type)));
|
aiPromptBlockMapper.updateById(block);
|
}
|
}
|
|
private List<AiPromptBlock> loadBlocks(Long templateId) {
|
if (templateId == null) {
|
return Collections.emptyList();
|
}
|
return aiPromptBlockMapper.selectList(new QueryWrapper<AiPromptBlock>()
|
.eq("template_id", templateId)
|
.orderByAsc("sort_no")
|
.orderByAsc("id"));
|
}
|
|
private List<AiPromptBlock> loadBlocks(List<Long> templateIds) {
|
if (templateIds == null || templateIds.isEmpty()) {
|
return Collections.emptyList();
|
}
|
return aiPromptBlockMapper.selectList(new QueryWrapper<AiPromptBlock>()
|
.in("template_id", templateIds)
|
.orderByAsc("sort_no")
|
.orderByAsc("id"));
|
}
|
|
private Map<Long, List<AiPromptBlock>> groupBlocks(List<AiPromptBlock> blocks) {
|
HashMap<Long, List<AiPromptBlock>> result = new HashMap<>();
|
if (blocks == null) {
|
return result;
|
}
|
for (AiPromptBlock block : blocks) {
|
if (block == null || block.getTemplateId() == null) {
|
continue;
|
}
|
result.computeIfAbsent(block.getTemplateId(), k -> new ArrayList<>()).add(block);
|
}
|
for (List<AiPromptBlock> list : result.values()) {
|
list.sort(BLOCK_SORT);
|
}
|
return result;
|
}
|
|
private AiPromptTemplate findPublished(String sceneCode) {
|
QueryWrapper<AiPromptTemplate> wrapper = new QueryWrapper<>();
|
wrapper.eq("scene_code", sceneCode)
|
.eq("status", 1)
|
.eq("published", 1)
|
.orderByDesc("version")
|
.orderByDesc("id")
|
.last("limit 1");
|
return this.getOne(wrapper, false);
|
}
|
|
private AiPromptTemplate findLatest(String sceneCode) {
|
QueryWrapper<AiPromptTemplate> wrapper = new QueryWrapper<>();
|
wrapper.eq("scene_code", sceneCode)
|
.orderByDesc("version")
|
.orderByDesc("id")
|
.last("limit 1");
|
return this.getOne(wrapper, false);
|
}
|
|
private int nextVersion(String sceneCode) {
|
QueryWrapper<AiPromptTemplate> wrapper = new QueryWrapper<>();
|
wrapper.eq("scene_code", sceneCode)
|
.select("max(version) as version");
|
Map<String, Object> row = this.getMap(wrapper);
|
if (row == null || row.get("version") == null) {
|
return 1;
|
}
|
Object value = row.get("version");
|
if (value instanceof Number) {
|
return ((Number) value).intValue() + 1;
|
}
|
return Integer.parseInt(String.valueOf(value)) + 1;
|
}
|
|
private Short normalizeStatus(Short status) {
|
return status != null && status == 0 ? (short) 0 : (short) 1;
|
}
|
|
private String defaultName(AiPromptScene scene, Integer version, String name) {
|
String value = trim(name);
|
if (value != null && !value.isEmpty()) {
|
return value;
|
}
|
return scene.getLabel() + " v" + version;
|
}
|
|
private AiPromptScene requireScene(String sceneCode) {
|
String code = trim(sceneCode);
|
AiPromptScene scene = AiPromptScene.ofCode(code);
|
if (scene == null) {
|
throw new IllegalArgumentException("不支持的 Prompt 场景: " + sceneCode);
|
}
|
return scene;
|
}
|
|
private String trim(String value) {
|
if (value == null) {
|
return null;
|
}
|
String trimmed = value.trim();
|
return trimmed.isEmpty() ? null : trimmed;
|
}
|
|
private String defaultString(String value) {
|
return value == null ? "" : value;
|
}
|
}
|