package com.zy.ai.service;
|
|
import com.zy.ai.entity.ResponsesApiRequest;
|
import com.zy.ai.entity.ResponsesApiResponse;
|
import com.zy.ai.gateway.AiGatewayService;
|
import com.zy.ai.gateway.model.AiMessage;
|
import com.zy.ai.gateway.model.AiRequest;
|
import com.zy.ai.gateway.model.AiResponse;
|
import com.zy.ai.gateway.model.AiUsage;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.stereotype.Service;
|
|
import java.util.ArrayList;
|
import java.util.LinkedHashMap;
|
import java.util.List;
|
import java.util.Map;
|
import java.util.UUID;
|
|
@Service
|
@RequiredArgsConstructor
|
public class OpenAiResponsesService {
|
|
public static final String RAW_RESPONSES_INPUT = "openAiResponsesInput";
|
public static final String RAW_RESPONSES_INSTRUCTIONS = "openAiResponsesInstructions";
|
|
private final AiGatewayService aiGatewayService;
|
|
public ResponsesApiResponse create(ResponsesApiRequest request) {
|
AiRequest aiRequest = toAiRequest(request);
|
AiResponse aiResponse = aiGatewayService.generate(aiRequest);
|
return toResponsesApiResponse(request, aiResponse);
|
}
|
|
private AiRequest toAiRequest(ResponsesApiRequest request) {
|
List<AiMessage> messages = buildMessages(request);
|
if (messages.isEmpty() && request.getInput() == null) {
|
throw new IllegalArgumentException("input 不能为空");
|
}
|
|
AiRequest aiRequest = new AiRequest();
|
aiRequest.setModel(request.getModel());
|
aiRequest.setScene("responses");
|
aiRequest.setMessages(messages);
|
aiRequest.setTemperature(request.getTemperature());
|
aiRequest.setMaxTokens(request.getMaxOutputTokens());
|
aiRequest.setStream(false);
|
aiRequest.setTools(request.getTools());
|
aiRequest.setToolChoice(request.getToolChoice());
|
aiRequest.setRawOptions(buildRawOptions(request));
|
return aiRequest;
|
}
|
|
private Map<String, Object> buildRawOptions(ResponsesApiRequest request) {
|
LinkedHashMap<String, Object> rawOptions = new LinkedHashMap<>();
|
if (request.getInput() != null) {
|
rawOptions.put(RAW_RESPONSES_INPUT, request.getInput());
|
}
|
if (request.getInstructions() != null) {
|
rawOptions.put(RAW_RESPONSES_INSTRUCTIONS, request.getInstructions());
|
}
|
return rawOptions.isEmpty() ? null : rawOptions;
|
}
|
|
private List<AiMessage> buildMessages(ResponsesApiRequest request) {
|
List<AiMessage> messages = new ArrayList<>();
|
String instructions = extractText(request.getInstructions());
|
if (!isBlank(instructions)) {
|
messages.add(message("system", instructions));
|
}
|
|
Object input = request.getInput();
|
if (input instanceof List) {
|
appendInputItems(messages, (List<?>) input);
|
return messages;
|
}
|
|
String text = extractText(input);
|
if (!isBlank(text)) {
|
messages.add(message("user", text));
|
}
|
return messages;
|
}
|
|
private void appendInputItems(List<AiMessage> messages, List<?> items) {
|
for (Object item : items) {
|
AiMessage message = toMessage(item);
|
if (message != null) {
|
messages.add(message);
|
}
|
}
|
}
|
|
private AiMessage toMessage(Object item) {
|
if (item instanceof Map) {
|
Map<?, ?> map = (Map<?, ?>) item;
|
String role = normalizeRole(valueAsString(map.get("role")));
|
Object content = map.containsKey("content") ? map.get("content") : map.get("text");
|
String text = extractText(content);
|
if (!isBlank(text)) {
|
return message(role, text);
|
}
|
}
|
|
String text = extractText(item);
|
if (isBlank(text)) {
|
return null;
|
}
|
return message("user", text);
|
}
|
|
private AiMessage message(String role, String content) {
|
AiMessage message = new AiMessage();
|
message.setRole(role);
|
message.setContent(content);
|
return message;
|
}
|
|
private String normalizeRole(String role) {
|
if ("assistant".equals(role) || "tool".equals(role)) {
|
return role;
|
}
|
if ("system".equals(role) || "developer".equals(role)) {
|
return "system";
|
}
|
return "user";
|
}
|
|
private String extractText(Object value) {
|
if (value == null) {
|
return null;
|
}
|
if (value instanceof String) {
|
return (String) value;
|
}
|
if (value instanceof List) {
|
return extractTextList((List<?>) value);
|
}
|
if (value instanceof Map) {
|
return extractTextMap((Map<?, ?>) value);
|
}
|
return String.valueOf(value);
|
}
|
|
private String extractTextList(List<?> values) {
|
StringBuilder text = new StringBuilder();
|
for (Object value : values) {
|
String itemText = extractText(value);
|
if (isBlank(itemText)) {
|
continue;
|
}
|
if (text.length() > 0) {
|
text.append('\n');
|
}
|
text.append(itemText);
|
}
|
return text.toString();
|
}
|
|
private String extractTextMap(Map<?, ?> map) {
|
Object textValue = map.containsKey("text") ? map.get("text") : map.get("input_text");
|
if (textValue == null && map.containsKey("content")) {
|
textValue = map.get("content");
|
}
|
if (textValue == null) {
|
return null;
|
}
|
return extractText(textValue);
|
}
|
|
private ResponsesApiResponse toResponsesApiResponse(ResponsesApiRequest request, AiResponse aiResponse) {
|
if (aiResponse == null) {
|
throw new IllegalStateException("LLM 响应为空");
|
}
|
|
String responseId = isBlank(aiResponse.getId()) ? "resp_" + compactUuid() : aiResponse.getId();
|
String outputText = aiResponse.getText() == null ? "" : aiResponse.getText();
|
|
ResponsesApiResponse response = new ResponsesApiResponse();
|
response.setId(responseId);
|
response.setModel(responseModel(request, aiResponse));
|
response.setCreatedAt(System.currentTimeMillis() / 1000);
|
response.setOutputText(outputText);
|
response.setOutput(buildOutput(responseId, outputText));
|
response.setUsage(buildUsage(aiResponse.getUsage()));
|
return response;
|
}
|
|
private List<ResponsesApiResponse.OutputItem> buildOutput(String responseId, String outputText) {
|
ResponsesApiResponse.ContentItem contentItem = new ResponsesApiResponse.ContentItem();
|
contentItem.setText(outputText);
|
|
ResponsesApiResponse.OutputItem outputItem = new ResponsesApiResponse.OutputItem();
|
outputItem.setId("msg_" + responseId.replace("resp_", ""));
|
outputItem.setContent(List.of(contentItem));
|
return List.of(outputItem);
|
}
|
|
private ResponsesApiResponse.Usage buildUsage(AiUsage aiUsage) {
|
if (aiUsage == null) {
|
return null;
|
}
|
ResponsesApiResponse.Usage usage = new ResponsesApiResponse.Usage();
|
usage.setInputTokens(aiUsage.getInputTokens());
|
usage.setOutputTokens(aiUsage.getOutputTokens());
|
usage.setTotalTokens(aiUsage.getTotalTokens());
|
return usage;
|
}
|
|
private String responseModel(ResponsesApiRequest request, AiResponse response) {
|
if (request != null && !isBlank(request.getModel())) {
|
return request.getModel();
|
}
|
if (response != null && !isBlank(response.getModel())) {
|
return response.getModel();
|
}
|
return "wcs-routed-model";
|
}
|
|
private String compactUuid() {
|
return UUID.randomUUID().toString().replace("-", "");
|
}
|
|
private String valueAsString(Object value) {
|
return value == null ? null : String.valueOf(value);
|
}
|
|
private boolean isBlank(String text) {
|
return text == null || text.trim().isEmpty();
|
}
|
}
|