package com.zy.ai.gateway.adapter.openai;
|
|
import com.zy.ai.entity.ChatCompletionRequest;
import com.zy.ai.entity.LlmRouteConfig;
import com.zy.ai.gateway.model.AiMessage;
import com.zy.ai.gateway.model.AiRequest;
import com.zy.ai.gateway.model.AiResponse;
import com.zy.ai.service.OpenAiResponsesService;
import org.junit.jupiter.api.Test;
import org.springframework.web.client.RestClient;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
class OpenAiResponsesAdapterTest {
|
|
@Test
|
void supportsOnlyOpenAiResponsesProtocol() {
|
OpenAiResponsesAdapter adapter = new OpenAiResponsesAdapter(RestClient.builder());
|
LlmRouteConfig route = new LlmRouteConfig();
|
route.setProviderType("OPENAI_COMPATIBLE");
|
route.setProtocolType("OPENAI_RESPONSES");
|
|
assertTrue(adapter.supports(route, new AiRequest()));
|
|
route.setProtocolType("OPENAI_CHAT_COMPLETIONS");
|
assertFalse(adapter.supports(route, new AiRequest()));
|
}
|
|
@Test
|
@SuppressWarnings("unchecked")
|
void requestBodyPreservesResponsesInputToolsAndRouteOptions() throws Exception {
|
OpenAiResponsesAdapter adapter = new OpenAiResponsesAdapter(RestClient.builder());
|
LlmRouteConfig route = new LlmRouteConfig();
|
route.setModel("gpt-test");
|
route.setRequestOptions("{\"reasoning\":{\"effort\":\"medium\"},\"model\":\"ignored\"}");
|
|
List<Map<String, Object>> input = List.of(Map.of(
|
"role", "user",
|
"content", List.of(Map.of("type", "input_image", "image_url", "https://example.com/a.png"))
|
));
|
List<Map<String, Object>> tools = List.of(Map.of("type", "function"));
|
|
AiRequest request = new AiRequest();
|
request.setRawOptions(new LinkedHashMap<>());
|
request.getRawOptions().put(OpenAiResponsesService.RAW_RESPONSES_INPUT, input);
|
request.getRawOptions().put(OpenAiResponsesService.RAW_RESPONSES_INSTRUCTIONS, "be concise");
|
request.setTools(tools);
|
request.setToolChoice("auto");
|
request.setMaxTokens(64);
|
|
Method requestBody = OpenAiResponsesAdapter.class.getDeclaredMethod("requestBody", LlmRouteConfig.class, AiRequest.class);
|
requestBody.setAccessible(true);
|
Map<String, Object> body = (Map<String, Object>) requestBody.invoke(adapter, route, request);
|
|
assertEquals("gpt-test", body.get("model"));
|
assertEquals(input, body.get("input"));
|
assertEquals("be concise", body.get("instructions"));
|
assertEquals(tools, body.get("tools"));
|
assertEquals("auto", body.get("tool_choice"));
|
assertEquals(64, body.get("max_output_tokens"));
|
assertEquals("medium", ((Map<?, ?>) body.get("reasoning")).get("effort"));
|
}
|
|
@Test
|
@SuppressWarnings("unchecked")
|
void requestBodyConvertsChatToolsAndToolResultMessagesToResponsesShape() throws Exception {
|
OpenAiResponsesAdapter adapter = new OpenAiResponsesAdapter(RestClient.builder());
|
LlmRouteConfig route = new LlmRouteConfig();
|
route.setModel("gpt-test");
|
|
ChatCompletionRequest.ToolCall toolCall = toolCall("call_1", "get_status", "{\"device\":\"crn1\"}");
|
|
AiMessage assistant = new AiMessage();
|
assistant.setRole("assistant");
|
assistant.setToolCalls(List.of(toolCall));
|
|
AiMessage tool = new AiMessage();
|
tool.setRole("tool");
|
tool.setToolCallId("call_1");
|
tool.setContent("{\"ok\":true}");
|
|
AiRequest request = new AiRequest();
|
request.setMessages(List.of(assistant, tool));
|
request.setTools(List.of(Map.of(
|
"type", "function",
|
"function", Map.of(
|
"name", "get_status",
|
"description", "get device status",
|
"parameters", Map.of("type", "object")
|
)
|
)));
|
|
Method requestBody = OpenAiResponsesAdapter.class.getDeclaredMethod("requestBody", LlmRouteConfig.class, AiRequest.class);
|
requestBody.setAccessible(true);
|
Map<String, Object> body = (Map<String, Object>) requestBody.invoke(adapter, route, request);
|
|
List<Map<String, Object>> input = (List<Map<String, Object>>) body.get("input");
|
assertEquals("function_call", input.get(0).get("type"));
|
assertEquals("call_1", input.get(0).get("call_id"));
|
assertEquals("get_status", input.get(0).get("name"));
|
assertEquals("function_call_output", input.get(1).get("type"));
|
assertEquals("call_1", input.get(1).get("call_id"));
|
|
List<Map<String, Object>> tools = (List<Map<String, Object>>) body.get("tools");
|
assertEquals("function", tools.get(0).get("type"));
|
assertEquals("get_status", tools.get(0).get("name"));
|
assertFalse(tools.get(0).containsKey("function"));
|
}
|
|
@Test
|
@SuppressWarnings("unchecked")
|
void responseFunctionCallMapsToChatToolCall() throws Exception {
|
OpenAiResponsesAdapter adapter = new OpenAiResponsesAdapter(RestClient.builder());
|
LlmRouteConfig route = new LlmRouteConfig();
|
route.setModel("gpt-test");
|
|
Map<String, Object> source = Map.of(
|
"id", "resp_1",
|
"output", List.of(Map.of(
|
"type", "function_call",
|
"call_id", "call_1",
|
"name", "get_status",
|
"arguments", "{\"device\":\"crn1\"}"
|
))
|
);
|
|
Method toAiResponse = OpenAiResponsesAdapter.class.getDeclaredMethod("toAiResponse", Map.class, LlmRouteConfig.class);
|
toAiResponse.setAccessible(true);
|
AiResponse response = (AiResponse) toAiResponse.invoke(adapter, source, route);
|
|
assertEquals("resp_1", response.getId());
|
assertEquals("assistant", response.getMessage().getRole());
|
List<Object> toolCalls = response.getMessage().getToolCalls();
|
ChatCompletionRequest.ToolCall toolCall = (ChatCompletionRequest.ToolCall) toolCalls.get(0);
|
assertEquals("call_1", toolCall.getId());
|
assertEquals("function", toolCall.getType());
|
assertEquals("get_status", toolCall.getFunction().getName());
|
assertEquals("{\"device\":\"crn1\"}", toolCall.getFunction().getArguments());
|
}
|
|
private ChatCompletionRequest.ToolCall toolCall(String id, String name, String arguments) {
|
ChatCompletionRequest.Function function = new ChatCompletionRequest.Function();
|
function.setName(name);
|
function.setArguments(arguments);
|
|
ChatCompletionRequest.ToolCall toolCall = new ChatCompletionRequest.ToolCall();
|
toolCall.setId(id);
|
toolCall.setType("function");
|
toolCall.setFunction(function);
|
return toolCall;
|
}
|
}
|