# gateway/modules/serviceCenter/services/serviceAgent/agentLoop.py
# 2026-03-17 23:23:40 +01:00
#
# 473 lines
# 18 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Agent loop: ReAct pattern with native function calling, budget control, and error handling."""
import asyncio
import logging
import time
import json
import re
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable
from modules.datamodels.datamodelAi import (
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
)
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
AgentState, AgentStatusEnum, AgentConfig, AgentEvent, AgentEventTypeEnum,
ToolCallRequest, ToolResult, ToolCallLog, AgentRoundLog, AgentTrace
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.conversationManager import (
ConversationManager, buildSystemPrompt
)
from modules.shared.timeUtils import getUtcTimestamp
logger = logging.getLogger(__name__)
async def runAgentLoop(
    prompt: str,
    toolRegistry: ToolRegistry,
    config: AgentConfig,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    getWorkflowCostFn: Callable[[], Awaitable[float]],
    workflowId: str,
    userId: str = "",
    featureInstanceId: str = "",
    buildRagContextFn: Callable[..., Awaitable[str]] = None,
    mandateId: str = "",
    aiCallStreamFn: Callable = None,
    userLanguage: str = "",
    conversationHistory: List[Dict[str, Any]] = None,
) -> AsyncGenerator[AgentEvent, None]:
    """Run the agent loop. Yields AgentEvent for each step (SSE-ready).

    Each round optionally injects RAG context, checks the cost budget, optionally
    summarizes the conversation, and calls the AI. If the response contains no
    tool calls the loop finishes with a FINAL event; otherwise the requested
    tools are executed and their results are appended to the conversation for
    the next round. The loop also ends on budget overrun, AI error, or when
    ``config.maxRounds`` rounds have been used.

    Args:
        prompt: User prompt
        toolRegistry: Registry with available tools
        config: Agent configuration (maxRounds, maxCostCHF, etc.)
        aiCallFn: Function to call the AI (wraps serviceAi.callAi with billing)
        getWorkflowCostFn: Function to get current workflow cost
        workflowId: Workflow ID for tracking
        userId: User ID for tracing
        featureInstanceId: Feature instance ID for tracing
        buildRagContextFn: Optional async function to build RAG context before each round
        mandateId: Mandate ID for RAG scoping
        aiCallStreamFn: Optional streaming variant of the AI call. When given it is
            used instead of aiCallFn for the main call of each round and is consumed
            with ``async for``: str items are forwarded as CHUNK events, any non-str
            item is taken as the final AiCallResponse.
        userLanguage: ISO 639-1 language code for agent responses
        conversationHistory: Prior messages [{role, content/message}] for follow-up context

    Yields:
        AgentEvent: AGENT_PROGRESS for every round that passes the budget check,
        CHUNK or MESSAGE for model text, TOOL_CALL / TOOL_RESULT (plus any side
        events attached to a tool result) around tool execution, FINAL or ERROR
        when the loop terminates, and an AGENT_SUMMARY once the loop has ended.
    """
    state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
    # NOTE(review): trace is filled in below but is neither yielded nor returned
    # by this function -- confirm whether it is meant to be persisted/emitted.
    trace = AgentTrace(
        workflowId=workflowId, userId=userId,
        featureInstanceId=featureInstanceId
    )
    tools = toolRegistry.getTools()
    toolDefinitions = toolRegistry.formatToolsForFunctionCalling()
    toolsText = toolRegistry.formatToolsForPrompt()
    systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
    conversation = ConversationManager(systemPrompt)
    if conversationHistory:
        conversation.loadHistory(conversationHistory)
    conversation.addUserMessage(prompt)

    while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
        # Give other tasks on the event loop a chance to run once per round.
        await asyncio.sleep(0)
        state.currentRound += 1
        roundStartTime = time.time()
        roundLog = AgentRoundLog(roundNumber=state.currentRound)

        # RAG context injection (before each round for fresh relevance).
        # Best-effort: any failure is logged and the round continues without it.
        if buildRagContextFn:
            try:
                # Use the most recent user message in the conversation as the
                # query; fall back to the original prompt if none is found.
                latestUserMsg = ""
                for msg in reversed(conversation.messages):
                    if msg.get("role") == "user":
                        latestUserMsg = msg.get("content", "")
                        break
                ragContext = await buildRagContextFn(
                    currentPrompt=latestUserMsg or prompt,
                    workflowId=workflowId,
                    userId=userId,
                    featureInstanceId=featureInstanceId,
                    mandateId=mandateId,
                )
                if ragContext:
                    conversation.injectRagContext(ragContext)
            except Exception as ragErr:
                logger.warning(f"RAG context injection failed (non-blocking): {ragErr}")

        # Budget check: stop before spending on another AI call.
        # Note: this exit does not append roundLog to trace.rounds.
        budgetExceeded = await _checkBudget(config, getWorkflowCostFn)
        if budgetExceeded:
            state.status = AgentStatusEnum.BUDGET_EXCEEDED
            state.abortReason = "Workflow cost budget exceeded"
            yield AgentEvent(
                type=AgentEventTypeEnum.FINAL,
                content=_buildProgressSummary(state, "Budget exceeded. Here is the progress so far.")
            )
            break

        logger.info(f"Agent round {state.currentRound}/{state.maxRounds} for workflow {workflowId} (tools={state.totalToolCalls}, cost={state.totalCostCHF:.4f})")
        yield AgentEvent(
            type=AgentEventTypeEnum.AGENT_PROGRESS,
            data={
                "round": state.currentRound,
                "maxRounds": state.maxRounds,
                "totalAiCalls": state.totalAiCalls,
                "totalToolCalls": state.totalToolCalls,
                "costCHF": state.totalCostCHF
            }
        )

        # Progressive summarization: the conversation manager decides when it is
        # needed; the summarization AI call is counted in the state totals.
        if conversation.needsSummarization(state.currentRound):
            async def _summarizeCall(summaryPrompt: str) -> str:
                req = AiCallRequest(
                    prompt=summaryPrompt,
                    options=AiCallOptions(operationType=OperationTypeEnum.DATA_ANALYSE)
                )
                resp = await aiCallFn(req)
                state.totalCostCHF += resp.priceCHF
                state.totalAiCalls += 1
                return resp.content
            await conversation.summarize(state.currentRound, _summarizeCall)

        # AI call: the prompt is empty because the full message list is passed.
        aiRequest = AiCallRequest(
            prompt="",
            options=AiCallOptions(
                operationType=OperationTypeEnum.AGENT,
                temperature=config.temperature
            ),
            messages=conversation.messages,
            tools=toolDefinitions
        )
        try:
            aiResponse = None
            streamedText = ""
            isFirstChunkOfRound = True
            if aiCallStreamFn:
                async for chunk in aiCallStreamFn(aiRequest):
                    if isinstance(chunk, str):
                        # From round 2 on, the first text chunk of a round is
                        # prefixed with a blank line (presumably to separate it
                        # from the text streamed in the previous round).
                        if isFirstChunkOfRound and state.currentRound > 1:
                            chunk = "\n\n" + chunk
                            isFirstChunkOfRound = False
                        elif isFirstChunkOfRound:
                            isFirstChunkOfRound = False
                        streamedText += chunk
                        yield AgentEvent(type=AgentEventTypeEnum.CHUNK, content=chunk)
                    else:
                        # Any non-str item is treated as the final response object.
                        aiResponse = chunk
                if aiResponse is None:
                    raise RuntimeError("Stream ended without final AiCallResponse")
            else:
                aiResponse = await aiCallFn(aiRequest)
        except Exception as e:
            # Any failure of the AI call ends the loop with an ERROR event.
            logger.error(f"AI call failed in round {state.currentRound}: {e}", exc_info=True)
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI call error: {e}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=str(e))
            break

        state.totalAiCalls += 1
        state.totalCostCHF += aiResponse.priceCHF
        state.totalProcessingTime += aiResponse.processingTime
        roundLog.aiModel = aiResponse.modelName
        roundLog.costCHF = aiResponse.priceCHF
        if aiResponse.errorCount > 0:
            # The call returned, but the response itself reports an error.
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI returned error: {aiResponse.content}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=aiResponse.content)
            break

        # Parse response for tool calls
        toolCalls = _parseToolCalls(aiResponse)
        textContent = _extractTextContent(aiResponse)
        # Emit the text as one MESSAGE only when nothing was streamed as CHUNKs,
        # so the same text is not sent twice.
        if textContent and not streamedText:
            yield AgentEvent(type=AgentEventTypeEnum.MESSAGE, content=textContent)

        if not toolCalls:
            # No tool requested: the response is the final answer.
            state.status = AgentStatusEnum.COMPLETED
            conversation.addAssistantMessage(aiResponse.content)
            roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
            trace.rounds.append(roundLog)
            yield AgentEvent(type=AgentEventTypeEnum.FINAL, content=textContent or aiResponse.content)
            break

        # Add assistant message with tool calls to conversation
        assistantToolCalls = _formatAssistantToolCalls(toolCalls)
        conversation.addAssistantMessage(textContent or "", assistantToolCalls)

        # Execute tool calls: announce all of them first, then run them.
        for tc in toolCalls:
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_CALL,
                data={"toolName": tc.name, "args": tc.args}
            )
        results = await _executeToolCalls(toolCalls, toolRegistry, {
            "workflowId": workflowId,
            "userId": userId,
            "featureInstanceId": featureInstanceId,
            "mandateId": mandateId,
        })
        state.totalToolCalls += len(results)
        for result in results:
            roundLog.toolCalls.append(ToolCallLog(
                toolName=result.toolName,
                # Look up the args of the originating call by its id.
                args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
                success=result.success,
                durationMs=result.durationMs,
                error=result.error
            ))
            if not result.success:
                logger.warning(f"Tool '{result.toolName}' failed: {result.error}")
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_RESULT,
                data={
                    "toolName": result.toolName,
                    "success": result.success,
                    # Only the first 500 items of data go into the event; the
                    # full data is added to the conversation below.
                    "data": result.data[:500] if result.data else "",
                    "error": result.error
                }
            )
            if result.sideEvents:
                for sideEvt in result.sideEvents:
                    evtType = sideEvt.get("type", "")
                    try:
                        evtEnum = AgentEventTypeEnum(evtType)
                    except (ValueError, KeyError):
                        # Side events with an unknown type are skipped.
                        continue
                    yield AgentEvent(
                        type=evtEnum,
                        data=sideEvt.get("data"),
                        content=sideEvt.get("content"),
                    )

        # Add tool results to conversation (failed tools contribute their error text)
        toolResultMessages = [
            {"toolCallId": r.toolCallId, "toolName": r.toolName,
             "content": r.data if r.success else f"Error: {r.error}"}
            for r in results
        ]
        conversation.addToolResults(toolResultMessages)
        roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
        trace.rounds.append(roundLog)

    # maxRounds reached: the loop ran out of rounds while still RUNNING.
    if state.currentRound >= state.maxRounds and state.status == AgentStatusEnum.RUNNING:
        state.status = AgentStatusEnum.MAX_ROUNDS_REACHED
        state.abortReason = f"Maximum rounds ({state.maxRounds}) reached"
        yield AgentEvent(
            type=AgentEventTypeEnum.FINAL,
            content=_buildProgressSummary(state, "Maximum rounds reached.")
        )

    # Agent summary: finalize the trace and emit the closing totals.
    trace.completedAt = getUtcTimestamp()
    trace.status = state.status
    trace.totalRounds = state.currentRound
    trace.totalToolCalls = state.totalToolCalls
    trace.totalCostCHF = state.totalCostCHF
    trace.abortReason = state.abortReason
    yield AgentEvent(
        type=AgentEventTypeEnum.AGENT_SUMMARY,
        data={
            "rounds": state.currentRound,
            "totalAiCalls": state.totalAiCalls,
            "totalToolCalls": state.totalToolCalls,
            "costCHF": round(state.totalCostCHF, 4),
            "processingTime": round(state.totalProcessingTime, 2),
            "status": state.status.value,
            "abortReason": state.abortReason
        }
    )
async def _checkBudget(config: AgentConfig,
                       getWorkflowCostFn: Callable[[], Awaitable[float]]) -> bool:
    """Report whether the workflow has spent more than its configured budget.

    Args:
        config: Agent configuration; ``maxCostCHF`` of None disables the check.
        getWorkflowCostFn: Async callable returning the current workflow cost.

    Returns:
        True only when the current cost is strictly greater than
        ``config.maxCostCHF``. False when no limit is configured or when the
        cost cannot be determined (the failure is logged, not raised).
    """
    costLimit = config.maxCostCHF
    if costLimit is None:
        return False

    try:
        spentSoFar = await getWorkflowCostFn()
        # Compared inside the try so that a non-comparable cost value is
        # handled like any other lookup failure.
        isOverBudget = spentSoFar > costLimit
    except Exception as e:
        logger.warning(f"Could not check workflow cost: {e}")
        isOverBudget = False
    return isOverBudget
async def _executeToolCalls(toolCalls: List[ToolCallRequest],
                            toolRegistry: ToolRegistry,
                            context: Dict[str, Any]) -> List[ToolResult]:
    """Execute tool calls: readOnly tools in parallel, others sequentially.

    Tool calls with _parseError (truncated JSON from LLM) are short-circuited
    with an error result so the agent can retry.

    Args:
        toolCalls: Tool calls requested by the model, in request order.
        toolRegistry: Registry used to classify (``isReadOnly``) and run
            (``dispatch``) each call.
        context: Context dict handed unchanged to every ``dispatch`` call.

    Returns:
        One ToolResult per entry of ``toolCalls``, in the same order.
    """
    results: Dict[str, ToolResult] = {}

    # Pass 1: calls whose arguments could not be parsed never reach a tool.
    # The isinstance guard keeps a non-dict args value from raising here.
    for tc in toolCalls:
        if isinstance(tc.args, dict) and "_parseError" in tc.args:
            results[tc.id] = ToolResult(
                toolCallId=tc.id,
                toolName=tc.name,
                success=False,
                data="",
                error=tc.args["_parseError"],
                durationMs=0,
            )

    # Pass 2: classify every remaining call exactly once.
    readOnlyCalls: List[ToolCallRequest] = []
    writeCalls: List[ToolCallRequest] = []
    for tc in toolCalls:
        if tc.id in results:
            continue
        if toolRegistry.isReadOnly(tc.name):
            readOnlyCalls.append(tc)
        else:
            writeCalls.append(tc)

    # Read-only tools are awaited concurrently.
    if readOnlyCalls:
        readResults = await asyncio.gather(*[
            toolRegistry.dispatch(tc, context) for tc in readOnlyCalls
        ])
        for tc, result in zip(readOnlyCalls, readResults):
            results[tc.id] = result

    # All other tools run one after another, in request order.
    for tc in writeCalls:
        results[tc.id] = await toolRegistry.dispatch(tc, context)

    return [results[tc.id] for tc in toolCalls]
def _repairTruncatedJson(raw: str) -> Optional[Dict[str, Any]]:
"""Try to repair truncated JSON from LLM output by closing open brackets/braces.
Returns parsed dict on success, None if unrecoverable.
"""
if not raw or not raw.strip().startswith("{"):
return None
openBraces = raw.count("{") - raw.count("}")
openBrackets = raw.count("[") - raw.count("]")
inString = False
lastQuoteEscaped = False
quoteCount = 0
for ch in raw:
if ch == '"' and not lastQuoteEscaped:
quoteCount += 1
inString = not inString
lastQuoteEscaped = (ch == '\\')
candidate = raw
if quoteCount % 2 != 0:
candidate += '"'
candidate += "]" * max(0, openBrackets)
candidate += "}" * max(0, openBraces)
try:
return json.loads(candidate)
except json.JSONDecodeError:
pass
lastComma = candidate.rfind(",")
if lastComma > 0:
trimmed = candidate[:lastComma] + candidate[lastComma + 1:]
try:
return json.loads(trimmed)
except json.JSONDecodeError:
pass
return None
def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
    """Parse tool calls from AI response. Supports native function calling and text-based fallback.

    Native path: used when ``aiResponse.toolCalls`` is present and non-empty.
    String arguments are JSON-decoded; on a decode error a truncation repair is
    attempted. Arguments that cannot be turned into a JSON object are replaced
    by ``{"_parseError": <message>}`` so the call is reported back as failed
    instead of being executed with bad input.

    Text fallback: used otherwise. Scans ``aiResponse.content`` for fenced
    ``tool_call`` blocks of the form ``tool: <name>`` / ``args: {...}``.
    Unparseable args in this path fall back to an empty dict.

    Args:
        aiResponse: The AI response to inspect.

    Returns:
        The parsed tool calls in the order they appear; empty if none.
    """
    toolCalls = []

    # Native function calling: check response metadata
    nativeCalls = getattr(aiResponse, "toolCalls", None)
    if nativeCalls:
        for tc in nativeCalls:
            rawArgs = tc["function"]["arguments"]
            if isinstance(rawArgs, str):
                rawArgs = rawArgs.strip()
                try:
                    parsedArgs = json.loads(rawArgs) if rawArgs else {}
                except json.JSONDecodeError:
                    parsedArgs = _repairTruncatedJson(rawArgs)
                    if parsedArgs is None:
                        logger.warning(f"Unrecoverable truncated JSON for '{tc['function']['name']}': {rawArgs[:200]}")
                        parsedArgs = {"_parseError": f"Truncated JSON arguments model output was cut off. Raw start: {rawArgs[:120]}"}
                    else:
                        logger.info(f"Repaired truncated JSON for '{tc['function']['name']}'")
            else:
                parsedArgs = rawArgs if rawArgs else {}

            # Valid JSON is not necessarily an object: "null" means "no
            # arguments"; anything else (list, string, number) is unusable.
            if parsedArgs is None:
                parsedArgs = {}
            elif not isinstance(parsedArgs, dict):
                logger.warning(
                    f"Non-object arguments for '{tc['function']['name']}': {type(parsedArgs).__name__}"
                )
                parsedArgs = {
                    "_parseError": f"Tool arguments must be a JSON object, got {type(parsedArgs).__name__}."
                }

            toolCalls.append(ToolCallRequest(
                # Fall back to the call's position when the provider gives no id.
                id=tc.get("id", str(len(toolCalls))),
                name=tc["function"]["name"],
                args=parsedArgs,
            ))
        return toolCalls

    # Text-based fallback: parse ```tool_call blocks
    content = aiResponse.content or ""
    pattern = r"```tool_call\s*\n\s*tool:\s*(\S+)\s*\n\s*args:\s*(\{.*?\})\s*\n\s*```"
    for match in re.finditer(pattern, content, re.DOTALL):
        toolName = match.group(1).strip()
        argsStr = match.group(2).strip()
        try:
            args = json.loads(argsStr)
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse tool args for '{toolName}': {argsStr}")
            args = {}
        toolCalls.append(ToolCallRequest(name=toolName, args=args))
    return toolCalls
def _extractTextContent(aiResponse: AiCallResponse) -> str:
    """Return the response text with every fenced ``tool_call`` block removed.

    A missing or empty ``content`` is treated as an empty string. Surrounding
    whitespace is stripped from the result.
    """
    responseText = aiResponse.content
    if not responseText:
        responseText = ""
    toolBlockPattern = re.compile(r"```tool_call\s*\n.*?\n\s*```", flags=re.DOTALL)
    withoutToolBlocks = toolBlockPattern.sub("", responseText)
    return withoutToolBlocks.strip()
def _formatAssistantToolCalls(toolCalls: List[ToolCallRequest]) -> List[Dict[str, Any]]:
    """Convert tool calls into conversation-history entries (OpenAI tool_calls format).

    Each entry carries the call id, the fixed type "function", and a function
    object holding the tool name and its arguments serialized as a JSON string.
    """
    formatted: List[Dict[str, Any]] = []
    for call in toolCalls:
        functionPart = {
            "name": call.name,
            "arguments": json.dumps(call.args),
        }
        formatted.append({
            "id": call.id,
            "type": "function",
            "function": functionPart,
        })
    return formatted
def _buildProgressSummary(state: AgentState, reason: str) -> str:
    """Compose the human-readable progress report used when the agent stops early.

    The text starts with ``reason``, followed by a blank line and the round
    count, AI call count, tool call count, cost in CHF and processing time
    taken from ``state``.
    """
    segments = (
        f"{reason}\n\n",
        f"Progress after {state.currentRound} rounds:\n",
        f"- AI calls: {state.totalAiCalls}\n",
        f"- Tool calls: {state.totalToolCalls}\n",
        f"- Cost: {state.totalCostCHF:.4f} CHF\n",
        f"- Processing time: {state.totalProcessingTime:.1f}s",
    )
    return "".join(segments)