package com.vincent.rsf.server.ai.service.impl;
|
|
import com.vincent.rsf.server.ai.dto.AiPromptPreviewDto;
|
import org.springframework.stereotype.Component;
|
import org.springframework.util.StringUtils;
|
|
import java.util.ArrayList;
|
import java.util.LinkedHashSet;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.Objects;
|
import java.util.regex.Matcher;
|
import java.util.regex.Pattern;
|
|
@Component
|
public class AiPromptRenderSupport {
|
|
private static final Pattern VARIABLE_PATTERN = Pattern.compile("\\{\\{?([a-zA-Z0-9_.-]+)}}?");
|
|
public AiPromptPreviewDto render(String systemPrompt, String userPromptTemplate, String input, Map<String, Object> metadata) {
|
String finalInput = input == null ? "" : input;
|
return AiPromptPreviewDto.builder()
|
.renderedSystemPrompt(renderTemplate(systemPrompt, finalInput, metadata))
|
.renderedUserPrompt(renderUserPrompt(userPromptTemplate, finalInput, metadata))
|
.resolvedVariables(resolveVariables(systemPrompt, userPromptTemplate, metadata))
|
.build();
|
}
|
|
public String renderUserPrompt(String userPromptTemplate, String input, Map<String, Object> metadata) {
|
if (!StringUtils.hasText(userPromptTemplate)) {
|
return input;
|
}
|
String rendered = replaceTemplateVariables(userPromptTemplate, input, metadata);
|
if (Objects.equals(rendered, userPromptTemplate)) {
|
return userPromptTemplate + "\n\n" + input;
|
}
|
return rendered;
|
}
|
|
private String renderTemplate(String template, String input, Map<String, Object> metadata) {
|
if (!StringUtils.hasText(template)) {
|
return template;
|
}
|
return replaceTemplateVariables(template, input, metadata);
|
}
|
|
private String replaceTemplateVariables(String template, String input, Map<String, Object> metadata) {
|
String rendered = template
|
.replace("{{input}}", input)
|
.replace("{input}", input);
|
if (metadata == null || metadata.isEmpty()) {
|
return rendered;
|
}
|
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
|
String value = entry.getValue() == null ? "" : String.valueOf(entry.getValue());
|
rendered = rendered.replace("{{" + entry.getKey() + "}}", value);
|
rendered = rendered.replace("{" + entry.getKey() + "}", value);
|
}
|
return rendered;
|
}
|
|
private List<String> resolveVariables(String systemPrompt, String userPromptTemplate, Map<String, Object> metadata) {
|
LinkedHashSet<String> variables = new LinkedHashSet<>();
|
collectVariables(variables, systemPrompt);
|
collectVariables(variables, userPromptTemplate);
|
if (metadata != null && !metadata.isEmpty()) {
|
variables.addAll(metadata.keySet());
|
}
|
return new ArrayList<>(variables);
|
}
|
|
private void collectVariables(LinkedHashSet<String> variables, String template) {
|
if (!StringUtils.hasText(template)) {
|
return;
|
}
|
Matcher matcher = VARIABLE_PATTERN.matcher(template);
|
while (matcher.find()) {
|
variables.add(matcher.group(1));
|
}
|
}
|
}
|