package com.vincent.rsf.server.ai.service.impl.chat; import com.vincent.rsf.framework.exception.CoolException; import com.vincent.rsf.server.ai.entity.AiParam; import com.vincent.rsf.server.ai.enums.AiErrorCategory; import com.vincent.rsf.server.ai.exception.AiChatException; import com.vincent.rsf.server.ai.service.impl.AiOpenAiApiSupport; import io.micrometer.observation.ObservationRegistry; import lombok.RequiredArgsConstructor; import org.springframework.ai.model.tool.DefaultToolCallingManager; import org.springframework.ai.model.tool.ToolCallingManager; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor; import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; import org.springframework.ai.util.json.schema.SchemaType; import org.springframework.context.support.GenericApplicationContext; import org.springframework.stereotype.Component; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.Map; @Component @RequiredArgsConstructor public class AiOpenAiChatModelFactory { private final GenericApplicationContext applicationContext; private final ObservationRegistry observationRegistry; public OpenAiChatModel createChatModel(AiParam aiParam) { OpenAiApi openAiApi = AiOpenAiApiSupport.buildOpenAiApi(aiParam); ToolCallingManager toolCallingManager = DefaultToolCallingManager.builder() .observationRegistry(observationRegistry) .toolCallbackResolver(new SpringBeanToolCallbackResolver(applicationContext, SchemaType.OPEN_API_SCHEMA)) .toolExecutionExceptionProcessor(new DefaultToolExecutionExceptionProcessor(false)) .build(); return new OpenAiChatModel( openAiApi, OpenAiChatOptions.builder() .model(aiParam.getModel()) .temperature(aiParam.getTemperature()) .topP(aiParam.getTopP()) .maxTokens(aiParam.getMaxTokens()) .streamUsage(true) .build(), toolCallingManager, org.springframework.retry.support.RetryTemplate.builder().maxAttempts(1).build(), observationRegistry ); } public OpenAiChatOptions buildChatOptions(AiParam aiParam, ToolCallback[] toolCallbacks, Long userId, Long tenantId, String requestId, Long sessionId, Map metadata) { if (userId == null) { throw new AiChatException("AI_AUTH_USER_MISSING", AiErrorCategory.AUTH, "OPTIONS_BUILD", "当前登录用户不存在", null); } OpenAiChatOptions.Builder builder = OpenAiChatOptions.builder() .model(aiParam.getModel()) .temperature(aiParam.getTemperature()) .topP(aiParam.getTopP()) .maxTokens(aiParam.getMaxTokens()) .streamUsage(true) .user(String.valueOf(userId)); if (toolCallbacks != null && toolCallbacks.length > 0) { builder.toolCallbacks(Arrays.stream(toolCallbacks).toList()); } Map toolContext = new LinkedHashMap<>(); toolContext.put("userId", userId); toolContext.put("tenantId", tenantId); toolContext.put("requestId", requestId); toolContext.put("sessionId", sessionId); Map metadataMap = new LinkedHashMap<>(); if (metadata != null) { metadata.forEach((key, value) -> { String normalized = value == null ? "" : String.valueOf(value); metadataMap.put(key, normalized); toolContext.put(key, normalized); }); } builder.toolContext(toolContext); if (!metadataMap.isEmpty()) { builder.metadata(metadataMap); } return builder.build(); } }