diff --git a/internal/agent/act.go b/internal/agent/act.go index a0526fa65..60b0f10ba 100644 --- a/internal/agent/act.go +++ b/internal/agent/act.go @@ -11,6 +11,7 @@ import ( "github.com/Tencent/WeKnora/internal/common" "github.com/Tencent/WeKnora/internal/event" "github.com/Tencent/WeKnora/internal/logger" + "github.com/Tencent/WeKnora/internal/models/chat" "github.com/Tencent/WeKnora/internal/types" "golang.org/x/sync/errgroup" ) @@ -67,7 +68,7 @@ func formatToolHint(name string, args map[string]any) string { // When ParallelToolCalls is enabled and there are 2+ tool calls, they execute concurrently. func (e *AgentEngine) executeToolCalls( ctx context.Context, response *types.ChatResponse, - step *types.AgentStep, iteration int, sessionID string, + step *types.AgentStep, iteration int, sessionID string, messages []chat.Message, ) { if len(response.ToolCalls) == 0 { return @@ -79,12 +80,12 @@ func (e *AgentEngine) executeToolCalls( // Use parallel execution when enabled and there are multiple tool calls if e.config.ParallelToolCalls && n >= 2 { - e.executeToolCallsParallel(ctx, response, step, iteration, sessionID) + e.executeToolCallsParallel(ctx, response, step, iteration, sessionID, messages) return } for i, tc := range response.ToolCalls { - e.executeSingleToolCall(ctx, tc, i, step, iteration, round, sessionID) + e.executeSingleToolCall(ctx, response.ToolCalls, tc, i, step, iteration, round, sessionID, messages) } } @@ -92,7 +93,7 @@ func (e *AgentEngine) executeToolCalls( // collecting results in original order. func (e *AgentEngine) executeToolCallsParallel( ctx context.Context, response *types.ChatResponse, - step *types.AgentStep, iteration int, sessionID string, + step *types.AgentStep, iteration int, sessionID string, messages []chat.Message, ) { round := iteration + 1 n := len(response.ToolCalls) @@ -105,7 +106,7 @@ func (e *AgentEngine) executeToolCallsParallel( for i, tc := range response.ToolCalls { i, tc := i, tc // capture loop vars g.Go(func() error { - toolCall := e.runToolCall(gCtx, tc, i, iteration, round, sessionID) + toolCall := e.runToolCall(gCtx, response.ToolCalls, tc, i, iteration, round, sessionID, messages) mu.Lock() results[i] = toolCall mu.Unlock() @@ -159,10 +160,10 @@ func (e *AgentEngine) executeToolCallsParallel( // executeSingleToolCall runs one tool call sequentially (original behavior). func (e *AgentEngine) executeSingleToolCall( - ctx context.Context, tc types.LLMToolCall, i int, - step *types.AgentStep, iteration, round int, sessionID string, + ctx context.Context, toolCalls []types.LLMToolCall, tc types.LLMToolCall, i int, + step *types.AgentStep, iteration, round int, sessionID string, messages []chat.Message, ) { - toolCall := e.runToolCall(ctx, tc, i, iteration, round, sessionID) + toolCall := e.runToolCall(ctx, toolCalls, tc, i, iteration, round, sessionID, messages) step.ToolCalls = append(step.ToolCalls, toolCall) result := toolCall.Result @@ -205,8 +206,8 @@ func (e *AgentEngine) executeSingleToolCall( // runToolCall handles argument parsing, execution, logging, and pipeline events for a single tool call. // It returns the completed ToolCall struct. Safe to call from multiple goroutines. func (e *AgentEngine) runToolCall( - ctx context.Context, tc types.LLMToolCall, i int, - iteration, round int, sessionID string, + ctx context.Context, allToolCalls []types.LLMToolCall, tc types.LLMToolCall, i int, + iteration, round int, sessionID string, messages []chat.Message, ) types.ToolCall { tc.ID = agenttools.NormalizeToolCallID(tc.ID, tc.Function.Name, i) total := "?" // unknown in isolation; callers log the batch size @@ -239,6 +240,29 @@ func (e *AgentEngine) runToolCall( toolCallStartTime := time.Now() + if msgIdx, tcIdx, duplicated := FindDuplicateToolCallInMessages(messages, tc.Function.Name, tc.Function.Arguments); duplicated { + logger.Warnf(ctx, "%s Duplicate tool call found in history (messages[%d].tool_calls[%d])", toolTag, msgIdx, tcIdx) + duration := time.Since(toolCallStartTime).Milliseconds() + return types.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Args: args, + Result: &types.ToolResult{Success: true, Output: BuildDuplicateToolCallWarningCN(tc.Function.Name, msgIdx, tcIdx)}, + Duration: duration, + } + } + if prevIdx, duplicated := FindDuplicateToolCallInCurrentResponseBeforeIndex(allToolCalls, i); duplicated { + logger.Warnf(ctx, "%s Duplicate tool call found earlier in current response (tool_calls[%d])", toolTag, prevIdx) + duration := time.Since(toolCallStartTime).Milliseconds() + return types.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Args: args, + Result: &types.ToolResult{Success: true, Output: BuildDuplicateToolCallWarningCN(tc.Function.Name, -1, prevIdx)}, + Duration: duration, + } + } + // Emit tool hint for UI progress display toolHint := formatToolHint(tc.Function.Name, args) e.eventBus.Emit(ctx, event.Event{ diff --git a/internal/agent/engine.go b/internal/agent/engine.go index 22111fe19..21e9bfb34 100644 --- a/internal/agent/engine.go +++ b/internal/agent/engine.go @@ -366,7 +366,7 @@ func (e *AgentEngine) executeLoop( } // 3. Act: Execute tool calls - e.executeToolCalls(ctx, response, &step, state.CurrentRound, sessionID) + e.executeToolCalls(ctx, response, &step, state.CurrentRound, sessionID, messages) // 4. Observe: Add tool results to messages and write to context state.RoundSteps = append(state.RoundSteps, step) diff --git a/internal/agent/tool_call_repeat_guard.go b/internal/agent/tool_call_repeat_guard.go new file mode 100644 index 000000000..3542d45dc --- /dev/null +++ b/internal/agent/tool_call_repeat_guard.go @@ -0,0 +1,115 @@ +package agent + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/Tencent/WeKnora/internal/models/chat" + "github.com/Tencent/WeKnora/internal/types" +) + +// JSONToolArgumentsSemanticallyEqual compares two JSON argument strings after parsing. +// Whitespace, key order, and number formatting differences do not affect equality. +// Invalid JSON on either side returns false. +func JSONToolArgumentsSemanticallyEqual(a, b string) bool { + var va, vb any + if err := json.Unmarshal([]byte(a), &va); err != nil { + return false + } + if err := json.Unmarshal([]byte(b), &vb); err != nil { + return false + } + return reflect.DeepEqual(va, vb) +} + +// CanonicalJSONArgs returns a stable string form for logging and fingerprinting: +// parsed then re-marshaled so map key order is deterministic (Go json.Marshal sorts map keys). +func CanonicalJSONArgs(argsJSON string) (string, error) { + var v any + if err := json.Unmarshal([]byte(argsJSON), &v); err != nil { + return "", err + } + b, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(b), nil +} + +// BuildDuplicateToolCallWarningCN returns a warning with first seen indexes. +func BuildDuplicateToolCallWarningCN(toolName string, historyMessageIndex, toolCallIndex int) string { + if historyMessageIndex < 0 { + return fmt.Sprintf( + "[System Notice] Duplicate tool call detected: `%s` has arguments identical to an earlier call in the current response (ignoring JSON formatting differences), at current_response.tool_calls[%d]. Do not repeat the same call; adjust parameters or continue reasoning based on existing results.", + toolName, toolCallIndex, + ) + } + return fmt.Sprintf( + "[System Notice] Duplicate tool call detected: `%s` has arguments identical to a previous historical call (ignoring JSON formatting differences). First seen at messages[%d].tool_calls[%d]. Do not repeat the same call; adjust parameters or continue reasoning based on existing results.", + toolName, historyMessageIndex, toolCallIndex, + ) +} + +// FindDuplicateToolCallInMessages scans assistant messages to find whether +// (tool name + semantically equal JSON args) has appeared before. +// Returns first hit indexes and true when duplicated. +func FindDuplicateToolCallInMessages( + messages []chat.Message, + toolName string, + argsJSON string, +) (historyMessageIndex int, toolCallIndex int, duplicated bool) { + target, err := CanonicalJSONArgs(argsJSON) + if err != nil { + return -1, -1, false + } + for mi, m := range messages { + if m.Role != "assistant" || len(m.ToolCalls) == 0 { + continue + } + for ti, tc := range m.ToolCalls { + if !strings.EqualFold(tc.Function.Name, toolName) { + continue + } + canon, err := CanonicalJSONArgs(tc.Function.Arguments) + if err != nil { + continue + } + if canon == target { + return mi, ti, true + } + } + } + return -1, -1, false +} + +// FindDuplicateToolCallInCurrentResponseBeforeIndex checks tool calls that +// already appeared earlier in the same LLM response. +func FindDuplicateToolCallInCurrentResponseBeforeIndex( + toolCalls []types.LLMToolCall, + currentIndex int, +) (prevIndex int, duplicated bool) { + if currentIndex <= 0 || currentIndex >= len(toolCalls) { + return -1, false + } + curr := toolCalls[currentIndex] + currCanon, err := CanonicalJSONArgs(curr.Function.Arguments) + if err != nil { + return -1, false + } + for i := 0; i < currentIndex; i++ { + prev := toolCalls[i] + if !strings.EqualFold(prev.Function.Name, curr.Function.Name) { + continue + } + prevCanon, err := CanonicalJSONArgs(prev.Function.Arguments) + if err != nil { + continue + } + if prevCanon == currCanon { + return i, true + } + } + return -1, false +} diff --git a/internal/agent/tool_call_repeat_guard_test.go b/internal/agent/tool_call_repeat_guard_test.go new file mode 100644 index 000000000..deaf14cb9 --- /dev/null +++ b/internal/agent/tool_call_repeat_guard_test.go @@ -0,0 +1,151 @@ +package agent + +import ( + "testing" + + "github.com/Tencent/WeKnora/internal/models/chat" + "github.com/Tencent/WeKnora/internal/types" +) + +func TestJSONToolArgumentsSemanticallyEqual_userExample(t *testing.T) { + a := `{ + "req": [ + { + "knowledge_id": "fbd587fe-6249-448e-b174-d5818af5b42f", + "limit": 20, + "offset": 0 + }, + { + "knowledge_id": "a6e73511-c239-4c35-b2ca-a78b4354e5e5", + "limit": 20, + "offset": 0 + }, + { + "knowledge_id": "06bc4b39-3118-48e9-b4cc-ff8cc1e618d2", + "limit": 20, + "offset": 0 + }, + { + "knowledge_id": "65fb8383-99e6-4b2f-b711-130c5a6dd4aa", + "limit": 20, + "offset": 0 + }, + { + "knowledge_id": "cdb87e7f-2a47-4394-8a6f-ae0bc75e9969", + "limit": 20, + "offset": 0 + } + ] +}` + b := `{ + "req" : [ { + "knowledge_id" : "a6e73511-c239-4c35-b2ca-a78b4354e5e5", + "limit" : 20, + "offset" : 0 + }, { + "knowledge_id" : "fbd587fe-6249-448e-b174-d5818af5b42f", + "limit" : 20, + "offset" : 0 + }, { + "knowledge_id" : "06bc4b39-3118-48e9-b4cc-ff8cc1e618d2", + "limit" : 20, + "offset" : 0 + }, { + "knowledge_id" : "65fb8383-99e6-4b2f-b711-130c5a6dd4aa", + "limit" : 20, + "offset" : 0 + }, { + "knowledge_id" : "cdb87e7f-2a47-4394-8a6f-ae0bc75e9969", + "limit" : 20, + "offset" : 0 + } ] +}` + + if !JSONToolArgumentsSemanticallyEqual(a, b) { + t.Fatal("expected semantically equal JSON (format differs)") + } +} + +func TestFindDuplicateToolCallInMessages(t *testing.T) { + messages := []chat.Message{ + {Role: "user", Content: "hi"}, + { + Role: "assistant", + ToolCalls: []chat.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: chat.FunctionCall{ + Name: "list_knowledge_chunks", + Arguments: `{"req":[{"knowledge_id":"x","limit":20,"offset":0}]}`, + }, + }, + }, + }, + } + + msgIdx, tcIdx, ok := FindDuplicateToolCallInMessages( + messages, + "list_knowledge_chunks", + `{ + "req" : [ { + "knowledge_id" : "x", + "limit" : 20, + "offset" : 0 + } ] + }`, + ) + if !ok { + t.Fatal("expected duplicate in assistant history") + } + if msgIdx != 1 || tcIdx != 0 { + t.Fatalf("unexpected hit index: msg=%d tc=%d", msgIdx, tcIdx) + } +} + +func TestFindDuplicateToolCallInMessages_notDuplicate(t *testing.T) { + messages := []chat.Message{ + { + Role: "assistant", + ToolCalls: []chat.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: chat.FunctionCall{ + Name: "knowledge_search", + Arguments: `{"query":"a"}`, + }, + }, + }, + }, + } + _, _, ok := FindDuplicateToolCallInMessages(messages, "knowledge_search", `{"query":"b"}`) + if ok { + t.Fatal("expected non-duplicate when args differ") + } +} + +func TestFindDuplicateToolCallInCurrentResponseBeforeIndex(t *testing.T) { + toolCalls := []types.LLMToolCall{ + {Function: types.FunctionCall{Name: "knowledge_search", Arguments: `{"query":"a"}`}}, + {Function: types.FunctionCall{Name: "knowledge_search", Arguments: `{"query":"a"}`}}, + {Function: types.FunctionCall{Name: "knowledge_search", Arguments: `{"query":"b"}`}}, + } + prev, ok := FindDuplicateToolCallInCurrentResponseBeforeIndex(toolCalls, 1) + if !ok || prev != 0 { + t.Fatalf("expected duplicate at index 0, got prev=%d ok=%v", prev, ok) + } + if _, ok := FindDuplicateToolCallInCurrentResponseBeforeIndex(toolCalls, 2); ok { + t.Fatal("index 2 should not be duplicate") + } +} + +func TestBuildDuplicateToolCallWarningCN(t *testing.T) { + msg := BuildDuplicateToolCallWarningCN("knowledge_search", 3, 1) + if msg == "" { + t.Fatal("warning should not be empty") + } + if msg[0] != '[' { + t.Fatal("warning format seems broken") + } +}