Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions internal/agent/act.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/Tencent/WeKnora/internal/common"
"github.com/Tencent/WeKnora/internal/event"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/types"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -67,7 +68,7 @@ func formatToolHint(name string, args map[string]any) string {
// When ParallelToolCalls is enabled and there are 2+ tool calls, they execute concurrently.
func (e *AgentEngine) executeToolCalls(
ctx context.Context, response *types.ChatResponse,
step *types.AgentStep, iteration int, sessionID string,
step *types.AgentStep, iteration int, sessionID string, messages []chat.Message,
) {
if len(response.ToolCalls) == 0 {
return
Expand All @@ -79,20 +80,20 @@ func (e *AgentEngine) executeToolCalls(

// Use parallel execution when enabled and there are multiple tool calls
if e.config.ParallelToolCalls && n >= 2 {
e.executeToolCallsParallel(ctx, response, step, iteration, sessionID)
e.executeToolCallsParallel(ctx, response, step, iteration, sessionID, messages)
return
}

for i, tc := range response.ToolCalls {
e.executeSingleToolCall(ctx, tc, i, step, iteration, round, sessionID)
e.executeSingleToolCall(ctx, response.ToolCalls, tc, i, step, iteration, round, sessionID, messages)
}
}

// executeToolCallsParallel runs all tool calls concurrently using errgroup,
// collecting results in original order.
func (e *AgentEngine) executeToolCallsParallel(
ctx context.Context, response *types.ChatResponse,
step *types.AgentStep, iteration int, sessionID string,
step *types.AgentStep, iteration int, sessionID string, messages []chat.Message,
) {
round := iteration + 1
n := len(response.ToolCalls)
Expand All @@ -105,7 +106,7 @@ func (e *AgentEngine) executeToolCallsParallel(
for i, tc := range response.ToolCalls {
i, tc := i, tc // capture loop vars
g.Go(func() error {
toolCall := e.runToolCall(gCtx, tc, i, iteration, round, sessionID)
toolCall := e.runToolCall(gCtx, response.ToolCalls, tc, i, iteration, round, sessionID, messages)
mu.Lock()
results[i] = toolCall
mu.Unlock()
Expand Down Expand Up @@ -159,10 +160,10 @@ func (e *AgentEngine) executeToolCallsParallel(

// executeSingleToolCall runs one tool call sequentially (original behavior).
func (e *AgentEngine) executeSingleToolCall(
ctx context.Context, tc types.LLMToolCall, i int,
step *types.AgentStep, iteration, round int, sessionID string,
ctx context.Context, toolCalls []types.LLMToolCall, tc types.LLMToolCall, i int,
step *types.AgentStep, iteration, round int, sessionID string, messages []chat.Message,
) {
toolCall := e.runToolCall(ctx, tc, i, iteration, round, sessionID)
toolCall := e.runToolCall(ctx, toolCalls, tc, i, iteration, round, sessionID, messages)
step.ToolCalls = append(step.ToolCalls, toolCall)

result := toolCall.Result
Expand Down Expand Up @@ -205,8 +206,8 @@ func (e *AgentEngine) executeSingleToolCall(
// runToolCall handles argument parsing, execution, logging, and pipeline events for a single tool call.
// It returns the completed ToolCall struct. Safe to call from multiple goroutines.
func (e *AgentEngine) runToolCall(
ctx context.Context, tc types.LLMToolCall, i int,
iteration, round int, sessionID string,
ctx context.Context, allToolCalls []types.LLMToolCall, tc types.LLMToolCall, i int,
iteration, round int, sessionID string, messages []chat.Message,
) types.ToolCall {
tc.ID = agenttools.NormalizeToolCallID(tc.ID, tc.Function.Name, i)
total := "?" // unknown in isolation; callers log the batch size
Expand Down Expand Up @@ -239,6 +240,29 @@ func (e *AgentEngine) runToolCall(

toolCallStartTime := time.Now()

if msgIdx, tcIdx, duplicated := FindDuplicateToolCallInMessages(messages, tc.Function.Name, tc.Function.Arguments); duplicated {
logger.Warnf(ctx, "%s Duplicate tool call found in history (messages[%d].tool_calls[%d])", toolTag, msgIdx, tcIdx)
duration := time.Since(toolCallStartTime).Milliseconds()
return types.ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Args: args,
Result: &types.ToolResult{Success: true, Output: BuildDuplicateToolCallWarningCN(tc.Function.Name, msgIdx, tcIdx)},
Duration: duration,
}
}
if prevIdx, duplicated := FindDuplicateToolCallInCurrentResponseBeforeIndex(allToolCalls, i); duplicated {
logger.Warnf(ctx, "%s Duplicate tool call found earlier in current response (tool_calls[%d])", toolTag, prevIdx)
duration := time.Since(toolCallStartTime).Milliseconds()
return types.ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Args: args,
Result: &types.ToolResult{Success: true, Output: BuildDuplicateToolCallWarningCN(tc.Function.Name, -1, prevIdx)},
Duration: duration,
}
}

// Emit tool hint for UI progress display
toolHint := formatToolHint(tc.Function.Name, args)
e.eventBus.Emit(ctx, event.Event{
Expand Down
2 changes: 1 addition & 1 deletion internal/agent/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func (e *AgentEngine) executeLoop(
}

// 3. Act: Execute tool calls
e.executeToolCalls(ctx, response, &step, state.CurrentRound, sessionID)
e.executeToolCalls(ctx, response, &step, state.CurrentRound, sessionID, messages)

// 4. Observe: Add tool results to messages and write to context
state.RoundSteps = append(state.RoundSteps, step)
Expand Down
115 changes: 115 additions & 0 deletions internal/agent/tool_call_repeat_guard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package agent

import (
"encoding/json"
"fmt"
"reflect"
"strings"

"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/types"
)

// JSONToolArgumentsSemanticallyEqual compares two JSON argument strings after parsing.
// Whitespace, key order, and number formatting differences do not affect equality.
// Invalid JSON on either side returns false.
func JSONToolArgumentsSemanticallyEqual(a, b string) bool {
var va, vb any
if err := json.Unmarshal([]byte(a), &va); err != nil {
return false
}
if err := json.Unmarshal([]byte(b), &vb); err != nil {
return false
}
return reflect.DeepEqual(va, vb)
}

// CanonicalJSONArgs returns a stable string form for logging and fingerprinting:
// parsed then re-marshaled so map key order is deterministic (Go json.Marshal sorts map keys).
func CanonicalJSONArgs(argsJSON string) (string, error) {
var v any
if err := json.Unmarshal([]byte(argsJSON), &v); err != nil {
return "", err
}
b, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(b), nil
}

// BuildDuplicateToolCallWarningCN returns a warning with first seen indexes.
func BuildDuplicateToolCallWarningCN(toolName string, historyMessageIndex, toolCallIndex int) string {
if historyMessageIndex < 0 {
return fmt.Sprintf(
"[System Notice] Duplicate tool call detected: `%s` has arguments identical to an earlier call in the current response (ignoring JSON formatting differences), at current_response.tool_calls[%d]. Do not repeat the same call; adjust parameters or continue reasoning based on existing results.",
toolName, toolCallIndex,
)
}
return fmt.Sprintf(
"[System Notice] Duplicate tool call detected: `%s` has arguments identical to a previous historical call (ignoring JSON formatting differences). First seen at messages[%d].tool_calls[%d]. Do not repeat the same call; adjust parameters or continue reasoning based on existing results.",
toolName, historyMessageIndex, toolCallIndex,
)
}

// FindDuplicateToolCallInMessages scans assistant messages to find whether
// (tool name + semantically equal JSON args) has appeared before.
// Returns first hit indexes and true when duplicated.
func FindDuplicateToolCallInMessages(
messages []chat.Message,
toolName string,
argsJSON string,
) (historyMessageIndex int, toolCallIndex int, duplicated bool) {
target, err := CanonicalJSONArgs(argsJSON)
if err != nil {
return -1, -1, false
}
for mi, m := range messages {
if m.Role != "assistant" || len(m.ToolCalls) == 0 {
continue
}
for ti, tc := range m.ToolCalls {
if !strings.EqualFold(tc.Function.Name, toolName) {
continue
}
canon, err := CanonicalJSONArgs(tc.Function.Arguments)
if err != nil {
continue
}
if canon == target {
return mi, ti, true
}
}
}
return -1, -1, false
}

// FindDuplicateToolCallInCurrentResponseBeforeIndex checks tool calls that
// already appeared earlier in the same LLM response.
func FindDuplicateToolCallInCurrentResponseBeforeIndex(
toolCalls []types.LLMToolCall,
currentIndex int,
) (prevIndex int, duplicated bool) {
if currentIndex <= 0 || currentIndex >= len(toolCalls) {
return -1, false
}
curr := toolCalls[currentIndex]
currCanon, err := CanonicalJSONArgs(curr.Function.Arguments)
if err != nil {
return -1, false
}
for i := 0; i < currentIndex; i++ {
prev := toolCalls[i]
if !strings.EqualFold(prev.Function.Name, curr.Function.Name) {
continue
}
prevCanon, err := CanonicalJSONArgs(prev.Function.Arguments)
if err != nil {
continue
}
if prevCanon == currCanon {
return i, true
}
}
return -1, false
}
151 changes: 151 additions & 0 deletions internal/agent/tool_call_repeat_guard_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package agent

import (
"testing"

"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/types"
)

func TestJSONToolArgumentsSemanticallyEqual_userExample(t *testing.T) {
a := `{
"req": [
{
"knowledge_id": "fbd587fe-6249-448e-b174-d5818af5b42f",
"limit": 20,
"offset": 0
},
{
"knowledge_id": "a6e73511-c239-4c35-b2ca-a78b4354e5e5",
"limit": 20,
"offset": 0
},
{
"knowledge_id": "06bc4b39-3118-48e9-b4cc-ff8cc1e618d2",
"limit": 20,
"offset": 0
},
{
"knowledge_id": "65fb8383-99e6-4b2f-b711-130c5a6dd4aa",
"limit": 20,
"offset": 0
},
{
"knowledge_id": "cdb87e7f-2a47-4394-8a6f-ae0bc75e9969",
"limit": 20,
"offset": 0
}
]
}`
b := `{
"req" : [ {
"knowledge_id" : "a6e73511-c239-4c35-b2ca-a78b4354e5e5",
"limit" : 20,
"offset" : 0
}, {
"knowledge_id" : "fbd587fe-6249-448e-b174-d5818af5b42f",
"limit" : 20,
"offset" : 0
}, {
"knowledge_id" : "06bc4b39-3118-48e9-b4cc-ff8cc1e618d2",
"limit" : 20,
"offset" : 0
}, {
"knowledge_id" : "65fb8383-99e6-4b2f-b711-130c5a6dd4aa",
"limit" : 20,
"offset" : 0
}, {
"knowledge_id" : "cdb87e7f-2a47-4394-8a6f-ae0bc75e9969",
"limit" : 20,
"offset" : 0
} ]
}`

if !JSONToolArgumentsSemanticallyEqual(a, b) {
t.Fatal("expected semantically equal JSON (format differs)")
}
}

func TestFindDuplicateToolCallInMessages(t *testing.T) {
messages := []chat.Message{
{Role: "user", Content: "hi"},
{
Role: "assistant",
ToolCalls: []chat.ToolCall{
{
ID: "call_1",
Type: "function",
Function: chat.FunctionCall{
Name: "list_knowledge_chunks",
Arguments: `{"req":[{"knowledge_id":"x","limit":20,"offset":0}]}`,
},
},
},
},
}

msgIdx, tcIdx, ok := FindDuplicateToolCallInMessages(
messages,
"list_knowledge_chunks",
`{
"req" : [ {
"knowledge_id" : "x",
"limit" : 20,
"offset" : 0
} ]
}`,
)
if !ok {
t.Fatal("expected duplicate in assistant history")
}
if msgIdx != 1 || tcIdx != 0 {
t.Fatalf("unexpected hit index: msg=%d tc=%d", msgIdx, tcIdx)
}
}

func TestFindDuplicateToolCallInMessages_notDuplicate(t *testing.T) {
messages := []chat.Message{
{
Role: "assistant",
ToolCalls: []chat.ToolCall{
{
ID: "call_1",
Type: "function",
Function: chat.FunctionCall{
Name: "knowledge_search",
Arguments: `{"query":"a"}`,
},
},
},
},
}
_, _, ok := FindDuplicateToolCallInMessages(messages, "knowledge_search", `{"query":"b"}`)
if ok {
t.Fatal("expected non-duplicate when args differ")
}
}

func TestFindDuplicateToolCallInCurrentResponseBeforeIndex(t *testing.T) {
toolCalls := []types.LLMToolCall{
{Function: types.FunctionCall{Name: "knowledge_search", Arguments: `{"query":"a"}`}},
{Function: types.FunctionCall{Name: "knowledge_search", Arguments: `{"query":"a"}`}},
{Function: types.FunctionCall{Name: "knowledge_search", Arguments: `{"query":"b"}`}},
}
prev, ok := FindDuplicateToolCallInCurrentResponseBeforeIndex(toolCalls, 1)
if !ok || prev != 0 {
t.Fatalf("expected duplicate at index 0, got prev=%d ok=%v", prev, ok)
}
if _, ok := FindDuplicateToolCallInCurrentResponseBeforeIndex(toolCalls, 2); ok {
t.Fatal("index 2 should not be duplicate")
}
}

func TestBuildDuplicateToolCallWarningCN(t *testing.T) {
msg := BuildDuplicateToolCallWarningCN("knowledge_search", 3, 1)
if msg == "" {
t.Fatal("warning should not be empty")
}
if msg[0] != '[' {
t.Fatal("warning format seems broken")
}
}