package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.vincent.rsf.framework.exception.CoolException;
|
import com.vincent.rsf.server.ai.entity.AiParam;
|
import com.vincent.rsf.server.ai.enums.AiErrorCategory;
|
import com.vincent.rsf.server.ai.exception.AiChatException;
|
import com.vincent.rsf.server.ai.service.impl.AiOpenAiApiSupport;
|
import io.micrometer.observation.ObservationRegistry;
|
import lombok.RequiredArgsConstructor;
|
import org.springframework.ai.model.tool.DefaultToolCallingManager;
|
import org.springframework.ai.model.tool.ToolCallingManager;
|
import org.springframework.ai.openai.OpenAiChatModel;
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
import org.springframework.ai.openai.api.OpenAiApi;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
|
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
|
import org.springframework.ai.util.json.schema.SchemaType;
|
import org.springframework.context.support.GenericApplicationContext;
|
import org.springframework.stereotype.Component;
|
|
import java.util.Arrays;
|
import java.util.LinkedHashMap;
|
import java.util.Map;
|
|
@Component
|
@RequiredArgsConstructor
|
public class AiOpenAiChatModelFactory {
|
|
private final GenericApplicationContext applicationContext;
|
private final ObservationRegistry observationRegistry;
|
|
public OpenAiChatModel createChatModel(AiParam aiParam) {
|
OpenAiApi openAiApi = AiOpenAiApiSupport.buildOpenAiApi(aiParam);
|
ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder()
|
.observationRegistry(observationRegistry)
|
.toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA))
|
.toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false))
|
.build();
|
return new OpenAiChatModel(
|
openAiApi,
|
OpenAiChatOptions.builder()
|
.model(aiParam.getModel())
|
.temperature(aiParam.getTemperature())
|
.topP(aiParam.getTopP())
|
.maxTokens(aiParam.getMaxTokens())
|
.streamUsage(true)
|
.build(),
|
toolCallingManager,
|
org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(),
|
observationRegistry
|
);
|
}
|
|
public OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId,
|
String requestId, Long sessionId, Map<String, Object> metadata) {
|
if (userId == null) {
|
throw new AiChatException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null);
|
}
|
OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder()
|
.model(aiParam.getModel())
|
.temperature(aiParam.getTemperature())
|
.topP(aiParam.getTopP())
|
.maxTokens(aiParam.getMaxTokens())
|
.streamUsage(true)
|
.user(String.valueOf(userId));
|
if (toolCallbacks != null && toolCallbacks.length > 0) {
|
builder.toolCallbacks(Arrays.stream(toolCallbacks).toList());
|
}
|
Map<String, Object> toolContext = new LinkedHashMap<>();
|
toolContext.put("userId", userId);
|
toolContext.put("tenantId", tenantId);
|
toolContext.put("requestId", requestId);
|
toolContext.put("sessionId", sessionId);
|
Map<String, String> metadataMap = new LinkedHashMap<>();
|
if (metadata != null) {
|
metadata.forEach((key, value) -> {
|
String normalized = value == null ? "" : String.valueOf(value);
|
metadataMap.put(key, normalized);
|
toolContext.put(key, normalized);
|
});
|
}
|
builder.toolContext(toolContext);
|
if (!metadataMap.isEmpty()) {
|
builder.metadata(metadataMap);
|
}
|
return builder.build();
|
}
|
}
|