package com.vincent.rsf.server.ai.service.impl.chat;
|
|
import com.vincent.rsf.server.ai.service.AiCallLogService;
|
import com.vincent.rsf.server.ai.service.MountedToolCallback;
|
import com.vincent.rsf.server.ai.store.AiCachedToolResult;
|
import com.vincent.rsf.server.ai.store.AiToolResultStore;
|
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.mockito.Mock;
|
import org.mockito.junit.jupiter.MockitoExtension;
|
import org.springframework.ai.tool.ToolCallback;
|
import org.springframework.ai.tool.definition.ToolDefinition;
|
|
import java.util.concurrent.atomic.AtomicLong;
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
import static org.mockito.ArgumentMatchers.anyLong;
|
import static org.mockito.ArgumentMatchers.eq;
|
import static org.mockito.Mockito.never;
|
import static org.mockito.Mockito.verify;
|
import static org.mockito.Mockito.when;
|
|
@ExtendWith(MockitoExtension.class)
|
class AiToolObservationServiceTest {
|
|
@Mock
|
private AiToolResultStore aiToolResultStore;
|
@Mock
|
private AiCallLogService aiCallLogService;
|
@Mock
|
private AiChatTraceEmitter traceEmitter;
|
@Mock
|
private MountedToolCallback mountedToolCallback;
|
@Mock
|
private ToolDefinition toolDefinition;
|
|
@Test
|
void shouldSkipStartTraceWhenToolResultComesFromCache() {
|
AiToolObservationService aiToolObservationService = new AiToolObservationService(aiToolResultStore, aiCallLogService);
|
when(mountedToolCallback.getToolDefinition()).thenReturn(toolDefinition);
|
when(toolDefinition.name()).thenReturn("inventory.lookup");
|
when(mountedToolCallback.getMountName()).thenReturn("builtin-stock");
|
when(aiToolResultStore.getToolResult(1L, "req-1", "inventory.lookup", "{\"code\":\"A01\"}"))
|
.thenReturn(AiCachedToolResult.builder()
|
.success(true)
|
.output("{\"stock\":12}")
|
.build());
|
|
ToolCallback[] callbacks = aiToolObservationService.wrapToolCallbacks(
|
new ToolCallback[]{mountedToolCallback},
|
"req-1",
|
11L,
|
new AtomicLong(0),
|
new AtomicLong(0),
|
new AtomicLong(0),
|
21L,
|
31L,
|
1L,
|
traceEmitter
|
);
|
|
String output = callbacks[0].call("{\"code\":\"A01\"}");
|
|
assertThat(output).isEqualTo("{\"stock\":12}");
|
verify(traceEmitter, never()).onToolStart(eq("inventory.lookup"), eq("builtin-stock"), eq("req-1-tool-1"), eq("{\"code\":\"A01\"}"), anyLong());
|
verify(traceEmitter).onToolResult(eq("inventory.lookup"), eq("builtin-stock"), eq("req-1-tool-1"),
|
eq("{\"code\":\"A01\"}"), eq("{\"stock\":12}"), eq(null), eq(0L), anyLong(), eq(false));
|
verify(mountedToolCallback, never()).call("{\"code\":\"A01\"}");
|
}
|
}
|