#
Junjie
3 天以前 0c1110daa59bf77ddcff2704641280f417158c10
src/main/java/com/zy/ai/service/impl/AiPromptTemplateServiceImpl.java
@@ -3,9 +3,13 @@
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.zy.ai.entity.AiPromptBlock;
import com.zy.ai.entity.AiPromptTemplate;
import com.zy.ai.enums.AiPromptBlockType;
import com.zy.ai.enums.AiPromptScene;
import com.zy.ai.mapper.AiPromptBlockMapper;
import com.zy.ai.mapper.AiPromptTemplateMapper;
import com.zy.ai.service.AiPromptComposerService;
import com.zy.ai.service.AiPromptTemplateService;
import com.zy.ai.utils.AiPromptUtils;
import lombok.RequiredArgsConstructor;
@@ -14,8 +18,10 @@
import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList;
import java.util.Date;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -24,7 +30,20 @@
@RequiredArgsConstructor
public class AiPromptTemplateServiceImpl extends ServiceImpl<AiPromptTemplateMapper, AiPromptTemplate> implements AiPromptTemplateService {
    private static final Comparator<AiPromptBlock> BLOCK_SORT = (a, b) -> {
        int sa = a != null && a.getSortNo() != null ? a.getSortNo() : Integer.MAX_VALUE;
        int sb = b != null && b.getSortNo() != null ? b.getSortNo() : Integer.MAX_VALUE;
        if (sa != sb) {
            return sa - sb;
        }
        long ia = a != null && a.getId() != null ? a.getId() : Long.MAX_VALUE;
        long ib = b != null && b.getId() != null ? b.getId() : Long.MAX_VALUE;
        return Long.compare(ia, ib);
    };
    private final AiPromptUtils aiPromptUtils;
    private final AiPromptBlockMapper aiPromptBlockMapper;
    private final AiPromptComposerService aiPromptComposerService;
    @Override
    public AiPromptTemplate resolvePublished(String sceneCode) {
@@ -33,16 +52,15 @@
        if (prompt == null) {
            synchronized (("ai_prompt_scene_init_" + scene.getCode()).intern()) {
                prompt = findPublished(scene.getCode());
                if (prompt == null) {
                if (prompt == null && findLatest(scene.getCode()) == null) {
                    prompt = ensurePublishedScene(scene);
                }
            }
        }
        if (prompt == null) {
            throw new IllegalStateException("未找到已发布的 Prompt,sceneCode=" + scene.getCode());
            throw new IllegalStateException("当前场景没有已发布 Prompt,sceneCode=" + scene.getCode());
        }
        return prompt;
        return enrichTemplate(prompt);
    }
    @Override
@@ -52,9 +70,9 @@
            throw new IllegalArgumentException("Prompt 不能为空");
        }
        AiPromptScene scene = requireScene(template.getSceneCode());
        String content = template.getContent();
        if (content == null || content.trim().isEmpty()) {
            throw new IllegalArgumentException("Prompt 内容不能为空");
        String compiled = buildCompiledPrompt(template);
        if (compiled.isEmpty()) {
            throw new IllegalArgumentException("Prompt 分段内容不能为空");
        }
        if (template.getId() == null) {
@@ -63,12 +81,13 @@
            entity.setName(defaultName(scene, version, template.getName()));
            entity.setSceneCode(scene.getCode());
            entity.setVersion(version);
            entity.setContent(content);
            entity.setContent(compiled);
            entity.setStatus(normalizeStatus(template.getStatus()));
            entity.setPublished((short) 0);
            entity.setCreatedBy(operatorUserId);
            entity.setMemo(trim(template.getMemo()));
            this.save(entity);
            upsertBlocks(entity.getId(), extractBlockContentMap(template));
            return entity;
        }
@@ -80,14 +99,15 @@
            throw new IllegalArgumentException("不允许修改 Prompt 所属场景");
        }
        if (Short.valueOf((short) 1).equals(db.getPublished())) {
            throw new IllegalArgumentException("已发布 Prompt 不允许直接修改,请新建版本后再发布");
            throw new IllegalArgumentException("已发布 Prompt 不允许直接修改,请先取消发布后再保存");
        }
        db.setName(defaultName(scene, db.getVersion() == null ? 1 : db.getVersion(), template.getName()));
        db.setContent(content);
        db.setContent(compiled);
        db.setStatus(normalizeStatus(template.getStatus()));
        db.setMemo(trim(template.getMemo()));
        this.updateById(db);
        upsertBlocks(db.getId(), extractBlockContentMap(template));
        return db;
    }
@@ -101,7 +121,9 @@
        if (db == null) {
            throw new IllegalArgumentException("Prompt 不存在");
        }
        if (db.getContent() == null || db.getContent().trim().isEmpty()) {
        db = enrichTemplate(db);
        String compiled = buildCompiledPrompt(db);
        if (compiled.isEmpty()) {
            throw new IllegalArgumentException("Prompt 内容不能为空");
        }
@@ -112,16 +134,93 @@
        db.setPublished((short) 1);
        db.setStatus((short) 1);
        db.setPublishedBy(operatorUserId);
        db.setPublishedTime(new Date());
        if (db.getVersion() == null || db.getVersion() <= 0) {
            db.setVersion(nextVersion(db.getSceneCode()));
        }
        db.setPublishedTime(new java.util.Date());
        db.setContent(compiled);
        if (db.getName() == null || db.getName().trim().isEmpty()) {
            AiPromptScene scene = requireScene(db.getSceneCode());
            db.setName(defaultName(scene, db.getVersion(), null));
        }
        this.updateById(db);
        return db;
    }
    @Override
    @Transactional(rollbackFor = Exception.class)
    public AiPromptTemplate cancelPublish(Long id, Long operatorUserId) {
        if (id == null) {
            throw new IllegalArgumentException("id 不能为空");
        }
        AiPromptTemplate db = this.getById(id);
        if (db == null) {
            throw new IllegalArgumentException("Prompt 不存在");
        }
        if (!Short.valueOf((short) 1).equals(db.getPublished())) {
            throw new IllegalArgumentException("当前 Prompt 不是已发布状态");
        }
        db.setPublished((short) 0);
        db.setPublishedBy(operatorUserId);
        db.setPublishedTime(null);
        this.updateById(db);
        return db;
    }
    @Override
    public AiPromptTemplate enrichTemplate(AiPromptTemplate template) {
        if (template == null) {
            return null;
        }
        if (template.getId() == null) {
            template.setContent(buildCompiledPrompt(template));
            return template;
        }
        List<AiPromptBlock> blocks = loadBlocks(template.getId());
        if (blocks.isEmpty()) {
            migrateLegacyTemplateBlocks(template);
            blocks = loadBlocks(template.getId());
        }
        applyBlocks(template, blocks);
        return template;
    }
    @Override
    public List<AiPromptTemplate> enrichTemplates(List<AiPromptTemplate> templates) {
        if (templates == null || templates.isEmpty()) {
            return templates == null ? Collections.emptyList() : templates;
        }
        List<Long> templateIds = new ArrayList<>();
        for (AiPromptTemplate template : templates) {
            if (template != null && template.getId() != null) {
                templateIds.add(template.getId());
            }
        }
        Map<Long, List<AiPromptBlock>> blockMap = groupBlocks(loadBlocks(templateIds));
        boolean migrated = false;
        for (AiPromptTemplate template : templates) {
            if (template == null || template.getId() == null) {
                continue;
            }
            List<AiPromptBlock> blocks = blockMap.get(template.getId());
            if (blocks == null || blocks.isEmpty()) {
                migrateLegacyTemplateBlocks(template);
                migrated = true;
            }
        }
        if (migrated) {
            blockMap = groupBlocks(loadBlocks(templateIds));
        }
        for (AiPromptTemplate template : templates) {
            if (template == null) {
                continue;
            }
            if (template.getId() == null) {
                template.setContent(buildCompiledPrompt(template));
                continue;
            }
            applyBlocks(template, blockMap.get(template.getId()));
        }
        return templates;
    }
    @Override
@@ -135,8 +234,9 @@
            return false;
        }
        if (Short.valueOf((short) 1).equals(db.getPublished())) {
            throw new IllegalArgumentException("已发布 Prompt 不允许删除,请先发布其他版本");
            throw new IllegalArgumentException("已发布 Prompt 不允许删除,请先取消发布");
        }
        aiPromptBlockMapper.delete(new QueryWrapper<AiPromptBlock>().eq("template_id", id));
        return this.removeById(id);
    }
@@ -145,9 +245,15 @@
    public int initDefaultsIfMissing() {
        int changed = 0;
        for (AiPromptScene scene : AiPromptScene.values()) {
            AiPromptTemplate prompt = findPublished(scene.getCode());
            if (prompt == null) {
            AiPromptTemplate latest = findLatest(scene.getCode());
            if (latest == null) {
                ensurePublishedScene(scene);
                changed++;
                continue;
            }
            List<AiPromptBlock> blocks = loadBlocks(latest.getId());
            if (blocks.isEmpty()) {
                migrateLegacyTemplateBlocks(latest);
                changed++;
            }
        }
@@ -167,33 +273,145 @@
    }
    private AiPromptTemplate ensurePublishedScene(AiPromptScene scene) {
        AiPromptTemplate latest = findLatest(scene.getCode());
        if (latest == null) {
            AiPromptTemplate seed = new AiPromptTemplate();
            seed.setName(defaultName(scene, 1, null));
            seed.setSceneCode(scene.getCode());
            seed.setVersion(1);
            seed.setContent(aiPromptUtils.getDefaultPrompt(scene.getCode()));
            seed.setStatus((short) 1);
            seed.setPublished((short) 1);
            seed.setPublishedTime(new Date());
            seed.setMemo("系统初始化默认 Prompt");
            this.save(seed);
            log.info("Initialized default AI prompt, sceneCode={}, version={}", scene.getCode(), seed.getVersion());
            return seed;
        LinkedHashMap<AiPromptBlockType, String> blocks = aiPromptUtils.getDefaultPromptBlocks(scene);
        AiPromptTemplate seed = new AiPromptTemplate();
        seed.setName(defaultName(scene, 1, null));
        seed.setSceneCode(scene.getCode());
        seed.setVersion(1);
        seed.setStatus((short) 1);
        seed.setPublished((short) 1);
        seed.setPublishedTime(new java.util.Date());
        seed.setMemo("系统初始化默认 Prompt");
        applyBlockFields(seed, blocks);
        seed.setContent(buildCompiledPrompt(seed));
        this.save(seed);
        upsertBlocks(seed.getId(), blocks);
        log.info("Initialized default AI prompt blocks, sceneCode={}, version={}", scene.getCode(), seed.getVersion());
        return seed;
    }
    private void migrateLegacyTemplateBlocks(AiPromptTemplate template) {
        if (template == null || template.getId() == null) {
            return;
        }
        AiPromptScene scene = requireScene(template.getSceneCode());
        LinkedHashMap<AiPromptBlockType, String> blocks = aiPromptUtils.resolveStoredOrDefaultBlocks(scene, template.getContent());
        upsertBlocks(template.getId(), blocks);
        applyBlockFields(template, blocks);
        template.setContent(buildCompiledPrompt(template));
        this.updateById(template);
    }
    private void applyBlocks(AiPromptTemplate template, List<AiPromptBlock> blocks) {
        List<AiPromptBlock> ordered = blocks == null ? new ArrayList<>() : new ArrayList<>(blocks);
        ordered.sort(BLOCK_SORT);
        template.setBlocks(ordered);
        LinkedHashMap<AiPromptBlockType, String> blockContent = new LinkedHashMap<>();
        for (AiPromptBlockType type : AiPromptBlockType.values()) {
            blockContent.put(type, "");
        }
        for (AiPromptBlock block : ordered) {
            AiPromptBlockType type = AiPromptBlockType.ofCode(block.getBlockType());
            if (type == null) {
                continue;
            }
            blockContent.put(type, block.getContent());
        }
        applyBlockFields(template, blockContent);
        template.setContent(buildCompiledPrompt(template));
    }
    private void applyBlockFields(AiPromptTemplate template, LinkedHashMap<AiPromptBlockType, String> blockContent) {
        template.setBasePolicy(valueOf(blockContent, AiPromptBlockType.BASE_POLICY));
        template.setToolPolicy(valueOf(blockContent, AiPromptBlockType.TOOL_POLICY));
        template.setOutputContract(valueOf(blockContent, AiPromptBlockType.OUTPUT_CONTRACT));
        template.setScenePlaybook(valueOf(blockContent, AiPromptBlockType.SCENE_PLAYBOOK));
    }
    private String valueOf(LinkedHashMap<AiPromptBlockType, String> blockContent, AiPromptBlockType type) {
        if (blockContent == null) {
            return "";
        }
        String value = blockContent.get(type);
        return value == null ? "" : value;
    }
    private String buildCompiledPrompt(AiPromptTemplate template) {
        String compiled = aiPromptComposerService.compose(template);
        return compiled == null ? "" : compiled.trim();
    }
    private LinkedHashMap<AiPromptBlockType, String> extractBlockContentMap(AiPromptTemplate template) {
        LinkedHashMap<AiPromptBlockType, String> blocks = new LinkedHashMap<>();
        blocks.put(AiPromptBlockType.BASE_POLICY, defaultString(template.getBasePolicy()));
        blocks.put(AiPromptBlockType.TOOL_POLICY, defaultString(template.getToolPolicy()));
        blocks.put(AiPromptBlockType.OUTPUT_CONTRACT, defaultString(template.getOutputContract()));
        blocks.put(AiPromptBlockType.SCENE_PLAYBOOK, defaultString(template.getScenePlaybook()));
        return blocks;
    }
    private void upsertBlocks(Long templateId, LinkedHashMap<AiPromptBlockType, String> blockContent) {
        List<AiPromptBlock> existingBlocks = loadBlocks(templateId);
        HashMap<String, AiPromptBlock> existingMap = new HashMap<>();
        for (AiPromptBlock block : existingBlocks) {
            existingMap.put(block.getBlockType(), block);
        }
        UpdateWrapper<AiPromptTemplate> clearWrapper = new UpdateWrapper<>();
        clearWrapper.eq("scene_code", scene.getCode()).set("published", 0);
        this.update(clearWrapper);
        latest.setStatus((short) 1);
        latest.setPublished((short) 1);
        if (latest.getPublishedTime() == null) {
            latest.setPublishedTime(new Date());
        for (AiPromptBlockType type : AiPromptBlockType.values()) {
            AiPromptBlock block = existingMap.get(type.getCode());
            if (block == null) {
                block = new AiPromptBlock();
                block.setTemplateId(templateId);
                block.setBlockType(type.getCode());
                block.setSortNo(type.getSort());
                block.setStatus((short) 1);
                block.setContent(defaultString(blockContent.get(type)));
                aiPromptBlockMapper.insert(block);
                continue;
            }
            block.setSortNo(type.getSort());
            block.setStatus((short) 1);
            block.setContent(defaultString(blockContent.get(type)));
            aiPromptBlockMapper.updateById(block);
        }
        this.updateById(latest);
        return latest;
    }
    private List<AiPromptBlock> loadBlocks(Long templateId) {
        if (templateId == null) {
            return Collections.emptyList();
        }
        return aiPromptBlockMapper.selectList(new QueryWrapper<AiPromptBlock>()
                .eq("template_id", templateId)
                .orderByAsc("sort_no")
                .orderByAsc("id"));
    }
    private List<AiPromptBlock> loadBlocks(List<Long> templateIds) {
        if (templateIds == null || templateIds.isEmpty()) {
            return Collections.emptyList();
        }
        return aiPromptBlockMapper.selectList(new QueryWrapper<AiPromptBlock>()
                .in("template_id", templateIds)
                .orderByAsc("sort_no")
                .orderByAsc("id"));
    }
    private Map<Long, List<AiPromptBlock>> groupBlocks(List<AiPromptBlock> blocks) {
        HashMap<Long, List<AiPromptBlock>> result = new HashMap<>();
        if (blocks == null) {
            return result;
        }
        for (AiPromptBlock block : blocks) {
            if (block == null || block.getTemplateId() == null) {
                continue;
            }
            result.computeIfAbsent(block.getTemplateId(), k -> new ArrayList<>()).add(block);
        }
        for (List<AiPromptBlock> list : result.values()) {
            list.sort(BLOCK_SORT);
        }
        return result;
    }
    private AiPromptTemplate findPublished(String sceneCode) {
@@ -259,4 +477,8 @@
        String trimmed = value.trim();
        return trimmed.isEmpty() ? null : trimmed;
    }
    private String defaultString(String value) {
        return value == null ? "" : value;
    }
}