473 lines
18 KiB
Python
473 lines
18 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
||
# All rights reserved.
|
||
"""Agent loop: ReAct pattern with native function calling, budget control, and error handling."""
|
||
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
import json
|
||
import re
|
||
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable
|
||
|
||
from modules.datamodels.datamodelAi import (
|
||
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
|
||
)
|
||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
||
AgentState, AgentStatusEnum, AgentConfig, AgentEvent, AgentEventTypeEnum,
|
||
ToolCallRequest, ToolResult, ToolCallLog, AgentRoundLog, AgentTrace
|
||
)
|
||
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
|
||
from modules.serviceCenter.services.serviceAgent.conversationManager import (
|
||
ConversationManager, buildSystemPrompt
|
||
)
|
||
from modules.shared.timeUtils import getUtcTimestamp
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
async def runAgentLoop(
    prompt: str,
    toolRegistry: ToolRegistry,
    config: AgentConfig,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    getWorkflowCostFn: Callable[[], Awaitable[float]],
    workflowId: str,
    userId: str = "",
    featureInstanceId: str = "",
    buildRagContextFn: Callable[..., Awaitable[str]] = None,
    mandateId: str = "",
    aiCallStreamFn: Callable = None,
    userLanguage: str = "",
    conversationHistory: List[Dict[str, Any]] = None,
) -> AsyncGenerator[AgentEvent, None]:
    """Run the agent loop. Yields AgentEvent for each step (SSE-ready).

    Each round: optional RAG injection, budget check, optional history
    summarization, one AI call (streamed or not), then tool execution.
    The loop ends on a final text answer, an error, budget exhaustion,
    or hitting maxRounds; an AGENT_SUMMARY event is always emitted last.

    Args:
        prompt: User prompt
        toolRegistry: Registry with available tools
        config: Agent configuration (maxRounds, maxCostCHF, etc.)
        aiCallFn: Function to call the AI (wraps serviceAi.callAi with billing)
        getWorkflowCostFn: Function to get current workflow cost
        workflowId: Workflow ID for tracking
        userId: User ID for tracing
        featureInstanceId: Feature instance ID for tracing
        buildRagContextFn: Optional async function to build RAG context before each round
        mandateId: Mandate ID for RAG scoping
        aiCallStreamFn: Optional streaming AI call; must yield str chunks and
            finish with a final AiCallResponse object (see loop below)
        userLanguage: ISO 639-1 language code for agent responses
        conversationHistory: Prior messages [{role, content/message}] for follow-up context
    """
    # Mutable loop state plus an audit trace persisted via the summary event.
    state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
    trace = AgentTrace(
        workflowId=workflowId, userId=userId,
        featureInstanceId=featureInstanceId
    )

    # Tool metadata in the three shapes consumers need: raw objects,
    # function-calling schema for the API, and a text listing for the prompt.
    tools = toolRegistry.getTools()
    toolDefinitions = toolRegistry.formatToolsForFunctionCalling()
    toolsText = toolRegistry.formatToolsForPrompt()

    systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
    conversation = ConversationManager(systemPrompt)
    if conversationHistory:
        # Follow-up turn: replay prior messages before the new user prompt.
        conversation.loadHistory(conversationHistory)
    conversation.addUserMessage(prompt)

    while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
        # Yield control to the event loop once per round (keeps SSE responsive).
        await asyncio.sleep(0)
        state.currentRound += 1
        roundStartTime = time.time()
        roundLog = AgentRoundLog(roundNumber=state.currentRound)

        # RAG context injection (before each round for fresh relevance)
        if buildRagContextFn:
            try:
                # Use the most recent user message as the retrieval query,
                # falling back to the original prompt.
                latestUserMsg = ""
                for msg in reversed(conversation.messages):
                    if msg.get("role") == "user":
                        latestUserMsg = msg.get("content", "")
                        break
                ragContext = await buildRagContextFn(
                    currentPrompt=latestUserMsg or prompt,
                    workflowId=workflowId,
                    userId=userId,
                    featureInstanceId=featureInstanceId,
                    mandateId=mandateId,
                )
                if ragContext:
                    conversation.injectRagContext(ragContext)
            except Exception as ragErr:
                # RAG is best-effort: a retrieval failure must not abort the agent.
                logger.warning(f"RAG context injection failed (non-blocking): {ragErr}")

        # Budget check
        budgetExceeded = await _checkBudget(config, getWorkflowCostFn)
        if budgetExceeded:
            # Terminate gracefully with a progress summary instead of raising.
            state.status = AgentStatusEnum.BUDGET_EXCEEDED
            state.abortReason = "Workflow cost budget exceeded"
            yield AgentEvent(
                type=AgentEventTypeEnum.FINAL,
                content=_buildProgressSummary(state, "Budget exceeded. Here is the progress so far.")
            )
            break

        logger.info(f"Agent round {state.currentRound}/{state.maxRounds} for workflow {workflowId} (tools={state.totalToolCalls}, cost={state.totalCostCHF:.4f})")
        yield AgentEvent(
            type=AgentEventTypeEnum.AGENT_PROGRESS,
            data={
                "round": state.currentRound,
                "maxRounds": state.maxRounds,
                "totalAiCalls": state.totalAiCalls,
                "totalToolCalls": state.totalToolCalls,
                "costCHF": state.totalCostCHF
            }
        )

        # Progressive summarization
        if conversation.needsSummarization(state.currentRound):
            # Closure over state so summarization cost/calls are billed to this run.
            async def _summarizeCall(summaryPrompt: str) -> str:
                req = AiCallRequest(
                    prompt=summaryPrompt,
                    options=AiCallOptions(operationType=OperationTypeEnum.DATA_ANALYSE)
                )
                resp = await aiCallFn(req)
                state.totalCostCHF += resp.priceCHF
                state.totalAiCalls += 1
                return resp.content

            await conversation.summarize(state.currentRound, _summarizeCall)

        # AI call
        # prompt is empty because the full context travels in `messages`.
        aiRequest = AiCallRequest(
            prompt="",
            options=AiCallOptions(
                operationType=OperationTypeEnum.AGENT,
                temperature=config.temperature
            ),
            messages=conversation.messages,
            tools=toolDefinitions
        )

        try:
            aiResponse = None
            streamedText = ""
            isFirstChunkOfRound = True

            if aiCallStreamFn:
                # The stream yields str chunks (forwarded as CHUNK events) and
                # must end with a non-str final AiCallResponse object.
                async for chunk in aiCallStreamFn(aiRequest):
                    if isinstance(chunk, str):
                        if isFirstChunkOfRound and state.currentRound > 1:
                            # Visual separator between consecutive rounds' output.
                            chunk = "\n\n" + chunk
                            isFirstChunkOfRound = False
                        elif isFirstChunkOfRound:
                            isFirstChunkOfRound = False
                        streamedText += chunk
                        yield AgentEvent(type=AgentEventTypeEnum.CHUNK, content=chunk)
                    else:
                        aiResponse = chunk

                if aiResponse is None:
                    raise RuntimeError("Stream ended without final AiCallResponse")
            else:
                aiResponse = await aiCallFn(aiRequest)

        except Exception as e:
            # An AI transport/model failure aborts the loop; summary still follows.
            logger.error(f"AI call failed in round {state.currentRound}: {e}", exc_info=True)
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI call error: {e}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=str(e))
            break

        # Accounting for this round's AI call.
        state.totalAiCalls += 1
        state.totalCostCHF += aiResponse.priceCHF
        state.totalProcessingTime += aiResponse.processingTime
        roundLog.aiModel = aiResponse.modelName
        roundLog.costCHF = aiResponse.priceCHF

        # Model-level error reported in the response payload (not an exception).
        if aiResponse.errorCount > 0:
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI returned error: {aiResponse.content}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=aiResponse.content)
            break

        # Parse response for tool calls
        toolCalls = _parseToolCalls(aiResponse)
        textContent = _extractTextContent(aiResponse)

        # Avoid duplicating text that was already delivered via CHUNK events.
        if textContent and not streamedText:
            yield AgentEvent(type=AgentEventTypeEnum.MESSAGE, content=textContent)

        if not toolCalls:
            # No tool requests -> the model's answer is final; end the loop.
            state.status = AgentStatusEnum.COMPLETED
            conversation.addAssistantMessage(aiResponse.content)
            roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
            trace.rounds.append(roundLog)
            yield AgentEvent(type=AgentEventTypeEnum.FINAL, content=textContent or aiResponse.content)
            break

        # Add assistant message with tool calls to conversation
        assistantToolCalls = _formatAssistantToolCalls(toolCalls)
        conversation.addAssistantMessage(textContent or "", assistantToolCalls)

        # Execute tool calls
        for tc in toolCalls:
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_CALL,
                data={"toolName": tc.name, "args": tc.args}
            )

        results = await _executeToolCalls(toolCalls, toolRegistry, {
            "workflowId": workflowId,
            "userId": userId,
            "featureInstanceId": featureInstanceId,
            "mandateId": mandateId,
        })
        state.totalToolCalls += len(results)

        for result in results:
            # Recover the original args for the trace by matching call ids.
            roundLog.toolCalls.append(ToolCallLog(
                toolName=result.toolName,
                args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
                success=result.success,
                durationMs=result.durationMs,
                error=result.error
            ))
            if not result.success:
                logger.warning(f"Tool '{result.toolName}' failed: {result.error}")
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_RESULT,
                data={
                    "toolName": result.toolName,
                    # Truncate payload for the event stream; full data goes to
                    # the conversation below.
                    "data": result.data[:500] if result.data else "",
                    "success": result.success,
                    "error": result.error
                }
            )
            # Tools may emit extra UI events; unknown event types are skipped.
            if result.sideEvents:
                for sideEvt in result.sideEvents:
                    evtType = sideEvt.get("type", "")
                    try:
                        evtEnum = AgentEventTypeEnum(evtType)
                    except (ValueError, KeyError):
                        continue
                    yield AgentEvent(
                        type=evtEnum,
                        data=sideEvt.get("data"),
                        content=sideEvt.get("content"),
                    )

        # Add tool results to conversation
        toolResultMessages = [
            {"toolCallId": r.toolCallId, "toolName": r.toolName,
             "content": r.data if r.success else f"Error: {r.error}"}
            for r in results
        ]
        conversation.addToolResults(toolResultMessages)

        roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
        trace.rounds.append(roundLog)

    # maxRounds reached
    # (only when the loop exited by exhausting rounds while still RUNNING)
    if state.currentRound >= state.maxRounds and state.status == AgentStatusEnum.RUNNING:
        state.status = AgentStatusEnum.MAX_ROUNDS_REACHED
        state.abortReason = f"Maximum rounds ({state.maxRounds}) reached"
        yield AgentEvent(
            type=AgentEventTypeEnum.FINAL,
            content=_buildProgressSummary(state, "Maximum rounds reached.")
        )

    # Agent summary
    trace.completedAt = getUtcTimestamp()
    trace.status = state.status
    trace.totalRounds = state.currentRound
    trace.totalToolCalls = state.totalToolCalls
    trace.totalCostCHF = state.totalCostCHF
    trace.abortReason = state.abortReason

    yield AgentEvent(
        type=AgentEventTypeEnum.AGENT_SUMMARY,
        data={
            "rounds": state.currentRound,
            "totalAiCalls": state.totalAiCalls,
            "totalToolCalls": state.totalToolCalls,
            "costCHF": round(state.totalCostCHF, 4),
            "processingTime": round(state.totalProcessingTime, 2),
            "status": state.status.value,
            "abortReason": state.abortReason
        }
    )
|
||
|
||
|
||
async def _checkBudget(config: AgentConfig,
                       getWorkflowCostFn: Callable[[], Awaitable[float]]) -> bool:
    """Check if workflow budget is exceeded. Returns True if exceeded."""
    limit = config.maxCostCHF
    if limit is None:
        # No budget configured -> nothing can be exceeded.
        return False
    try:
        spentSoFar = await getWorkflowCostFn()
    except Exception as e:
        # Cost lookup is best-effort; on failure assume we are within budget.
        logger.warning(f"Could not check workflow cost: {e}")
        return False
    return spentSoFar > limit
|
||
|
||
|
||
async def _executeToolCalls(toolCalls: List[ToolCallRequest],
                            toolRegistry: ToolRegistry,
                            context: Dict[str, Any]) -> List[ToolResult]:
    """Execute tool calls: readOnly tools in parallel, others sequentially.

    Tool calls with _parseError (truncated JSON from LLM) are short-circuited
    with an error result so the agent can retry.

    Args:
        toolCalls: Parsed tool call requests from the model.
        toolRegistry: Registry used to classify (isReadOnly) and dispatch calls.
        context: Extra call context (workflowId, userId, ...) passed to dispatch.

    Returns:
        One ToolResult per request, in the same order as ``toolCalls``.
    """
    # Keyed by call id so the final list can be rebuilt in request order.
    results: Dict[str, ToolResult] = {}

    # Short-circuit calls whose arguments could not be parsed: report the
    # parse error back to the model instead of dispatching garbage.
    for tc in toolCalls:
        if "_parseError" in tc.args:
            results[tc.id] = ToolResult(
                toolCallId=tc.id,
                toolName=tc.name,
                success=False,
                data="",
                error=tc.args["_parseError"],
                durationMs=0,
            )

    activeCalls = [tc for tc in toolCalls if tc.id not in results]
    activeReadOnly = [tc for tc in activeCalls if toolRegistry.isReadOnly(tc.name)]
    activeWrite = [tc for tc in activeCalls if not toolRegistry.isReadOnly(tc.name)]

    # Read-only tools have no side effects, so they may run concurrently.
    if activeReadOnly:
        readResults = await asyncio.gather(*[
            toolRegistry.dispatch(tc, context) for tc in activeReadOnly
        ])
        for tc, result in zip(activeReadOnly, readResults):
            results[tc.id] = result

    # Mutating tools run strictly one at a time, in request order.
    for tc in activeWrite:
        results[tc.id] = await toolRegistry.dispatch(tc, context)

    return [results[tc.id] for tc in toolCalls]
|
||
|
||
|
||
def _repairTruncatedJson(raw: str) -> Optional[Dict[str, Any]]:
|
||
"""Try to repair truncated JSON from LLM output by closing open brackets/braces.
|
||
|
||
Returns parsed dict on success, None if unrecoverable.
|
||
"""
|
||
if not raw or not raw.strip().startswith("{"):
|
||
return None
|
||
|
||
openBraces = raw.count("{") - raw.count("}")
|
||
openBrackets = raw.count("[") - raw.count("]")
|
||
|
||
inString = False
|
||
lastQuoteEscaped = False
|
||
quoteCount = 0
|
||
for ch in raw:
|
||
if ch == '"' and not lastQuoteEscaped:
|
||
quoteCount += 1
|
||
inString = not inString
|
||
lastQuoteEscaped = (ch == '\\')
|
||
|
||
candidate = raw
|
||
if quoteCount % 2 != 0:
|
||
candidate += '"'
|
||
|
||
candidate += "]" * max(0, openBrackets)
|
||
candidate += "}" * max(0, openBraces)
|
||
|
||
try:
|
||
return json.loads(candidate)
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
lastComma = candidate.rfind(",")
|
||
if lastComma > 0:
|
||
trimmed = candidate[:lastComma] + candidate[lastComma + 1:]
|
||
try:
|
||
return json.loads(trimmed)
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
return None
|
||
|
||
|
||
def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
    """Parse tool calls from AI response. Supports native function calling and text-based fallback."""
    requests: List[ToolCallRequest] = []

    # Path 1: native function-calling metadata attached to the response.
    nativeCalls = getattr(aiResponse, 'toolCalls', None)
    if nativeCalls:
        for tc in nativeCalls:
            fnName = tc["function"]["name"]
            rawArgs = tc["function"]["arguments"]
            if not isinstance(rawArgs, str):
                # Already a parsed object (or falsy -> empty args).
                parsedArgs = rawArgs if rawArgs else {}
            else:
                rawArgs = rawArgs.strip()
                try:
                    parsedArgs = json.loads(rawArgs) if rawArgs else {}
                except json.JSONDecodeError:
                    # Model output was likely cut off mid-JSON; attempt repair.
                    parsedArgs = _repairTruncatedJson(rawArgs)
                    if parsedArgs is None:
                        logger.warning(f"Unrecoverable truncated JSON for '{fnName}': {rawArgs[:200]}")
                        parsedArgs = {"_parseError": f"Truncated JSON arguments – model output was cut off. Raw start: {rawArgs[:120]}"}
                    else:
                        logger.info(f"Repaired truncated JSON for '{fnName}'")
            requests.append(ToolCallRequest(
                # Fall back to a positional id when the provider supplies none.
                id=tc.get("id", str(len(requests))),
                name=fnName,
                args=parsedArgs,
            ))
        return requests

    # Path 2: text-based fallback — scan for ```tool_call fenced blocks.
    responseText = aiResponse.content or ""
    blockPattern = r"```tool_call\s*\n\s*tool:\s*(\S+)\s*\n\s*args:\s*(\{.*?\})\s*\n\s*```"
    for match in re.finditer(blockPattern, responseText, re.DOTALL):
        toolName = match.group(1).strip()
        argsStr = match.group(2).strip()
        try:
            args = json.loads(argsStr)
        except json.JSONDecodeError:
            # Unlike the native path, text-mode parse failures degrade to {}.
            logger.warning(f"Failed to parse tool args for '{toolName}': {argsStr}")
            args = {}
        requests.append(ToolCallRequest(name=toolName, args=args))

    return requests
|
||
|
||
|
||
def _extractTextContent(aiResponse: AiCallResponse) -> str:
    """Extract text content from AI response, removing tool_call blocks."""
    rawText = aiResponse.content or ""
    # Strip every fenced ```tool_call block; only prose remains.
    toolBlockPattern = r"```tool_call\s*\n.*?\n\s*```"
    withoutToolBlocks = re.sub(toolBlockPattern, "", rawText, flags=re.DOTALL)
    return withoutToolBlocks.strip()
|
||
|
||
|
||
def _formatAssistantToolCalls(toolCalls: List[ToolCallRequest]) -> List[Dict[str, Any]]:
    """Format tool calls for the conversation history (OpenAI tool_calls format)."""
    formatted: List[Dict[str, Any]] = []
    for tc in toolCalls:
        # Arguments are serialized back to a JSON string, per the API schema.
        formatted.append({
            "id": tc.id,
            "type": "function",
            "function": {
                "name": tc.name,
                "arguments": json.dumps(tc.args)
            }
        })
    return formatted
|
||
|
||
|
||
def _buildProgressSummary(state: AgentState, reason: str) -> str:
    """Build a human-readable summary of agent progress for graceful termination."""
    summaryLines = [
        f"{reason}\n",
        f"Progress after {state.currentRound} rounds:",
        f"- AI calls: {state.totalAiCalls}",
        f"- Tool calls: {state.totalToolCalls}",
        f"- Cost: {state.totalCostCHF:.4f} CHF",
        f"- Processing time: {state.totalProcessingTime:.1f}s",
    ]
    return "\n".join(summaryLines)
|