package com.zy.ai.service;
|
|
import com.zy.ai.entity.ResponsesApiRequest;
import com.zy.ai.entity.ResponsesApiResponse;
import com.zy.ai.gateway.AiGatewayService;
import com.zy.ai.gateway.model.AiRequest;
import com.zy.ai.gateway.model.AiResponse;
import com.zy.ai.gateway.model.AiUsage;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
|
|
class OpenAiResponsesServiceTest {
|
|
@Test
|
void createConvertsResponsesRequestToAiRequestAndBack() {
|
AiGatewayService aiGatewayService = mock(AiGatewayService.class);
|
OpenAiResponsesService service = new OpenAiResponsesService(aiGatewayService);
|
|
ResponsesApiRequest request = new ResponsesApiRequest();
|
request.setModel("gpt-test");
|
request.setInstructions("system prompt");
|
request.setInput(List.of(Map.of(
|
"role", "user",
|
"content", List.of(Map.of("type", "input_text", "text", "hello"))
|
)));
|
request.setTemperature(0.2);
|
request.setMaxOutputTokens(128);
|
request.setTools(List.of(Map.of("type", "function")));
|
request.setToolChoice("auto");
|
|
when(aiGatewayService.generate(org.mockito.ArgumentMatchers.any(AiRequest.class))).thenReturn(aiResponse());
|
|
ResponsesApiResponse response = service.create(request);
|
|
ArgumentCaptor<AiRequest> captor = ArgumentCaptor.forClass(AiRequest.class);
|
verify(aiGatewayService).generate(captor.capture());
|
AiRequest aiRequest = captor.getValue();
|
|
assertEquals("gpt-test", aiRequest.getModel());
|
assertEquals(0.2, aiRequest.getTemperature());
|
assertEquals(128, aiRequest.getMaxTokens());
|
assertEquals("system", aiRequest.getMessages().get(0).getRole());
|
assertEquals("system prompt", aiRequest.getMessages().get(0).getContent());
|
assertEquals("user", aiRequest.getMessages().get(1).getRole());
|
assertEquals("hello", aiRequest.getMessages().get(1).getContent());
|
assertEquals(request.getInput(), aiRequest.getRawOptions().get(OpenAiResponsesService.RAW_RESPONSES_INPUT));
|
assertEquals(request.getInstructions(), aiRequest.getRawOptions().get(OpenAiResponsesService.RAW_RESPONSES_INSTRUCTIONS));
|
assertEquals(request.getTools(), aiRequest.getTools());
|
assertEquals("auto", aiRequest.getToolChoice());
|
|
assertEquals("resp_test", response.getId());
|
assertEquals("response", response.getObject());
|
assertEquals("completed", response.getStatus());
|
assertEquals("answer", response.getOutputText());
|
assertEquals("answer", response.getOutput().get(0).getContent().get(0).getText());
|
assertEquals(3, response.getUsage().getInputTokens());
|
assertEquals(4, response.getUsage().getOutputTokens());
|
assertEquals(7, response.getUsage().getTotalTokens());
|
}
|
|
@Test
|
void createRejectsEmptyInput() {
|
AiGatewayService aiGatewayService = mock(AiGatewayService.class);
|
OpenAiResponsesService service = new OpenAiResponsesService(aiGatewayService);
|
|
ResponsesApiRequest request = new ResponsesApiRequest();
|
|
assertThrows(IllegalArgumentException.class, () -> service.create(request));
|
}
|
|
private AiResponse aiResponse() {
|
AiUsage usage = new AiUsage();
|
usage.setInputTokens(3);
|
usage.setOutputTokens(4);
|
usage.setTotalTokens(7);
|
|
AiResponse response = new AiResponse();
|
response.setId("resp_test");
|
response.setModel("gpt-test");
|
response.setText("answer");
|
response.setUsage(usage);
|
return response;
|
}
|
}
|