# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""Agent loop: ReAct pattern with native function calling, budget control, and error handling."""

import asyncio
import logging
import time
import json
import re
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable

from modules.datamodels.datamodelAi import (
    AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
)
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    AgentState, AgentStatusEnum, AgentConfig, AgentEvent, AgentEventTypeEnum,
    ToolCallRequest, ToolResult, ToolCallLog, AgentRoundLog, AgentTrace
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.conversationManager import (
    ConversationManager, buildSystemPrompt
)
from modules.shared.timeUtils import getUtcTimestamp

logger = logging.getLogger(__name__)


async def runAgentLoop(
    prompt: str,
    toolRegistry: ToolRegistry,
    config: AgentConfig,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    getWorkflowCostFn: Callable[[], Awaitable[float]],
    workflowId: str,
    userId: str = "",
    featureInstanceId: str = "",
    buildRagContextFn: Optional[Callable[..., Awaitable[str]]] = None,
    mandateId: str = "",
    aiCallStreamFn: Optional[Callable] = None,
    userLanguage: str = "",
    conversationHistory: Optional[List[Dict[str, Any]]] = None,
) -> AsyncGenerator[AgentEvent, None]:
    """Run the agent loop. Yields AgentEvent for each step (SSE-ready).

    ReAct-style loop: each round optionally injects fresh RAG context, checks
    the cost budget, calls the AI (streaming when aiCallStreamFn is given),
    then either finishes (no tool calls) or executes the requested tools and
    feeds their results back into the conversation for the next round.

    Args:
        prompt: User prompt
        toolRegistry: Registry with available tools
        config: Agent configuration (maxRounds, maxCostCHF, etc.)
        aiCallFn: Function to call the AI (wraps serviceAi.callAi with billing)
        getWorkflowCostFn: Function to get current workflow cost
        workflowId: Workflow ID for tracking
        userId: User ID for tracing
        featureInstanceId: Feature instance ID for tracing
        buildRagContextFn: Optional async function to build RAG context before each round
        mandateId: Mandate ID for RAG scoping
        aiCallStreamFn: Optional streaming AI call; expected to yield str text
            chunks followed by a final non-str AiCallResponse object
        userLanguage: ISO 639-1 language code for agent responses
        conversationHistory: Prior messages [{role, content/message}] for follow-up context

    Yields:
        AgentEvent items: CHUNK/MESSAGE text, TOOL_CALL/TOOL_RESULT progress,
        AGENT_PROGRESS once per round, a terminal FINAL or ERROR event, and a
        closing AGENT_SUMMARY with totals.
    """
    state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
    trace = AgentTrace(
        workflowId=workflowId, userId=userId,
        featureInstanceId=featureInstanceId
    )

    # Tool metadata is computed once; it does not change between rounds.
    tools = toolRegistry.getTools()
    toolDefinitions = toolRegistry.formatToolsForFunctionCalling()
    toolsText = toolRegistry.formatToolsForPrompt()

    systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
    conversation = ConversationManager(systemPrompt)
    if conversationHistory:
        conversation.loadHistory(conversationHistory)
    conversation.addUserMessage(prompt)

    while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
        # Yield control to the event loop so other tasks (e.g. SSE flushing)
        # can run between rounds.
        await asyncio.sleep(0)
        state.currentRound += 1
        roundStartTime = time.time()
        roundLog = AgentRoundLog(roundNumber=state.currentRound)

        # RAG context injection (before each round for fresh relevance)
        if buildRagContextFn:
            try:
                # Use the most recent user message as the retrieval query;
                # fall back to the original prompt if none is found.
                latestUserMsg = ""
                for msg in reversed(conversation.messages):
                    if msg.get("role") == "user":
                        latestUserMsg = msg.get("content", "")
                        break
                ragContext = await buildRagContextFn(
                    currentPrompt=latestUserMsg or prompt,
                    workflowId=workflowId,
                    userId=userId,
                    featureInstanceId=featureInstanceId,
                    mandateId=mandateId,
                )
                if ragContext:
                    conversation.injectRagContext(ragContext)
            except Exception as ragErr:
                # RAG is best-effort: a retrieval failure must not stop the agent.
                logger.warning(f"RAG context injection failed (non-blocking): {ragErr}")

        # Budget check
        budgetExceeded = await _checkBudget(config, getWorkflowCostFn)
        if budgetExceeded:
            state.status = AgentStatusEnum.BUDGET_EXCEEDED
            state.abortReason = "Workflow cost budget exceeded"
            # Graceful termination: emit a FINAL event with progress so far.
            yield AgentEvent(
                type=AgentEventTypeEnum.FINAL,
                content=_buildProgressSummary(state, "Budget exceeded. Here is the progress so far.")
            )
            break

        logger.info(f"Agent round {state.currentRound}/{state.maxRounds} for workflow {workflowId} (tools={state.totalToolCalls}, cost={state.totalCostCHF:.4f})")
        yield AgentEvent(
            type=AgentEventTypeEnum.AGENT_PROGRESS,
            data={
                "round": state.currentRound,
                "maxRounds": state.maxRounds,
                "totalAiCalls": state.totalAiCalls,
                "totalToolCalls": state.totalToolCalls,
                "costCHF": state.totalCostCHF
            }
        )

        # Progressive summarization
        if conversation.needsSummarization(state.currentRound):
            async def _summarizeCall(summaryPrompt: str) -> str:
                # Summarization is billed like any other AI call,
                # so fold its cost and call count into the agent state.
                req = AiCallRequest(
                    prompt=summaryPrompt,
                    options=AiCallOptions(operationType=OperationTypeEnum.DATA_ANALYSE)
                )
                resp = await aiCallFn(req)
                state.totalCostCHF += resp.priceCHF
                state.totalAiCalls += 1
                return resp.content

            await conversation.summarize(state.currentRound, _summarizeCall)

        # AI call
        aiRequest = AiCallRequest(
            prompt="",  # the full context travels in `messages`, not `prompt`
            options=AiCallOptions(
                operationType=OperationTypeEnum.AGENT,
                temperature=config.temperature
            ),
            messages=conversation.messages,
            tools=toolDefinitions
        )

        try:
            aiResponse = None
            streamedText = ""
            isFirstChunkOfRound = True

            if aiCallStreamFn:
                # Streaming path: str chunks are forwarded as CHUNK events;
                # the final non-str item is the complete AiCallResponse.
                async for chunk in aiCallStreamFn(aiRequest):
                    if isinstance(chunk, str):
                        # Visually separate this round's text from the
                        # previous round's streamed output.
                        if isFirstChunkOfRound and state.currentRound > 1:
                            chunk = "\n\n" + chunk
                            isFirstChunkOfRound = False
                        elif isFirstChunkOfRound:
                            isFirstChunkOfRound = False
                        streamedText += chunk
                        yield AgentEvent(type=AgentEventTypeEnum.CHUNK, content=chunk)
                    else:
                        aiResponse = chunk

                if aiResponse is None:
                    raise RuntimeError("Stream ended without final AiCallResponse")
            else:
                aiResponse = await aiCallFn(aiRequest)

        except Exception as e:
            logger.error(f"AI call failed in round {state.currentRound}: {e}", exc_info=True)
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI call error: {e}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=str(e))
            break

        # Bookkeeping for cost/trace reporting.
        state.totalAiCalls += 1
        state.totalCostCHF += aiResponse.priceCHF
        state.totalProcessingTime += aiResponse.processingTime
        roundLog.aiModel = aiResponse.modelName
        roundLog.costCHF = aiResponse.priceCHF

        if aiResponse.errorCount > 0:
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI returned error: {aiResponse.content}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=aiResponse.content)
            break

        # Parse response for tool calls
        toolCalls = _parseToolCalls(aiResponse)
        textContent = _extractTextContent(aiResponse)

        # Avoid duplicating text that was already streamed as CHUNK events.
        if textContent and not streamedText:
            yield AgentEvent(type=AgentEventTypeEnum.MESSAGE, content=textContent)

        if not toolCalls:
            # No tool calls: the model considers the task complete.
            state.status = AgentStatusEnum.COMPLETED
            conversation.addAssistantMessage(aiResponse.content)
            roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
            trace.rounds.append(roundLog)
            yield AgentEvent(type=AgentEventTypeEnum.FINAL, content=textContent or aiResponse.content)
            break

        # Add assistant message with tool calls to conversation
        assistantToolCalls = _formatAssistantToolCalls(toolCalls)
        conversation.addAssistantMessage(textContent or "", assistantToolCalls)

        # Execute tool calls
        for tc in toolCalls:
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_CALL,
                data={"toolName": tc.name, "args": tc.args}
            )

        results = await _executeToolCalls(toolCalls, toolRegistry, {
            "workflowId": workflowId,
            "userId": userId,
            "featureInstanceId": featureInstanceId,
            "mandateId": mandateId,
        })
        state.totalToolCalls += len(results)

        for result in results:
            roundLog.toolCalls.append(ToolCallLog(
                toolName=result.toolName,
                # Recover the original args by matching the tool call id.
                args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
                success=result.success,
                durationMs=result.durationMs,
                error=result.error
            ))
            if not result.success:
                logger.warning(f"Tool '{result.toolName}' failed: {result.error}")
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_RESULT,
                data={
                    "toolName": result.toolName,
                    "success": result.success,
                    # Truncate payloads so the progress event stays small.
                    "data": result.data[:500] if result.data else "",
                    "error": result.error
                }
            )
            if result.sideEvents:
                # Tools may emit extra events (e.g. for the UI);
                # entries with an unknown type are skipped silently.
                for sideEvt in result.sideEvents:
                    evtType = sideEvt.get("type", "")
                    try:
                        evtEnum = AgentEventTypeEnum(evtType)
                    except (ValueError, KeyError):
                        continue
                    yield AgentEvent(
                        type=evtEnum,
                        data=sideEvt.get("data"),
                        content=sideEvt.get("content"),
                    )

        # Add tool results to conversation
        toolResultMessages = [
            {"toolCallId": r.toolCallId, "toolName": r.toolName,
             "content": r.data if r.success else f"Error: {r.error}"}
            for r in results
        ]
        conversation.addToolResults(toolResultMessages)

        roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
        trace.rounds.append(roundLog)

    # maxRounds reached
    if state.currentRound >= state.maxRounds and state.status == AgentStatusEnum.RUNNING:
        state.status = AgentStatusEnum.MAX_ROUNDS_REACHED
        state.abortReason = f"Maximum rounds ({state.maxRounds}) reached"
        yield AgentEvent(
            type=AgentEventTypeEnum.FINAL,
            content=_buildProgressSummary(state, "Maximum rounds reached.")
        )

    # Agent summary
    trace.completedAt = getUtcTimestamp()
    trace.status = state.status
    trace.totalRounds = state.currentRound
    trace.totalToolCalls = state.totalToolCalls
    trace.totalCostCHF = state.totalCostCHF
    trace.abortReason = state.abortReason

    yield AgentEvent(
        type=AgentEventTypeEnum.AGENT_SUMMARY,
        data={
            "rounds": state.currentRound,
            "totalAiCalls": state.totalAiCalls,
            "totalToolCalls": state.totalToolCalls,
            "costCHF": round(state.totalCostCHF, 4),
            "processingTime": round(state.totalProcessingTime, 2),
            "status": state.status.value,
            "abortReason": state.abortReason
        }
    )


async def _checkBudget(config: AgentConfig,
|
|
getWorkflowCostFn: Callable[[], Awaitable[float]]) -> bool:
|
|
"""Check if workflow budget is exceeded. Returns True if exceeded."""
|
|
if config.maxCostCHF is None:
|
|
return False
|
|
try:
|
|
currentCost = await getWorkflowCostFn()
|
|
return currentCost > config.maxCostCHF
|
|
except Exception as e:
|
|
logger.warning(f"Could not check workflow cost: {e}")
|
|
return False


async def _executeToolCalls(toolCalls: List[ToolCallRequest],
|
|
toolRegistry: ToolRegistry,
|
|
context: Dict[str, Any]) -> List[ToolResult]:
|
|
"""Execute tool calls: readOnly tools in parallel, others sequentially."""
|
|
readOnlyCalls = [tc for tc in toolCalls if toolRegistry.isReadOnly(tc.name)]
|
|
writeCalls = [tc for tc in toolCalls if not toolRegistry.isReadOnly(tc.name)]
|
|
|
|
results: Dict[str, ToolResult] = {}
|
|
|
|
if readOnlyCalls:
|
|
readResults = await asyncio.gather(*[
|
|
toolRegistry.dispatch(tc, context) for tc in readOnlyCalls
|
|
])
|
|
for tc, result in zip(readOnlyCalls, readResults):
|
|
results[tc.id] = result
|
|
|
|
for tc in writeCalls:
|
|
results[tc.id] = await toolRegistry.dispatch(tc, context)
|
|
|
|
return [results[tc.id] for tc in toolCalls]


def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
    """Parse tool calls from AI response. Supports native function calling and text-based fallback.

    Prefers the provider's native function-calling payload (aiResponse.toolCalls);
    otherwise scans the text content for ```tool_call blocks. Argument JSON that
    fails to parse degrades to an empty args dict rather than failing the round.
    """
    calls: List[ToolCallRequest] = []

    # Native function calling: check response metadata
    nativeCalls = getattr(aiResponse, "toolCalls", None)
    if nativeCalls:
        for entry in nativeCalls:
            fn = entry["function"]
            rawArgs = fn["arguments"]
            if isinstance(rawArgs, str):
                rawArgs = rawArgs.strip()
                try:
                    parsedArgs = json.loads(rawArgs) if rawArgs else {}
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse tool args for '{fn['name']}': {rawArgs[:200]}")
                    parsedArgs = {}
            else:
                # Provider already delivered structured args (or something falsy).
                parsedArgs = rawArgs if rawArgs else {}
            calls.append(ToolCallRequest(
                # Fall back to a positional id if the provider omits one.
                id=entry.get("id", str(len(calls))),
                name=fn["name"],
                args=parsedArgs,
            ))
        return calls

    # Text-based fallback: parse ```tool_call blocks
    text = aiResponse.content or ""
    blockPattern = r"```tool_call\s*\n\s*tool:\s*(\S+)\s*\n\s*args:\s*(\{.*?\})\s*\n\s*```"
    for found in re.finditer(blockPattern, text, re.DOTALL):
        toolName = found.group(1).strip()
        argsStr = found.group(2).strip()
        try:
            parsedArgs = json.loads(argsStr)
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse tool args for '{toolName}': {argsStr}")
            parsedArgs = {}
        calls.append(ToolCallRequest(name=toolName, args=parsedArgs))

    return calls


def _extractTextContent(aiResponse: AiCallResponse) -> str:
|
|
"""Extract text content from AI response, removing tool_call blocks."""
|
|
content = aiResponse.content or ""
|
|
cleaned = re.sub(r"```tool_call\s*\n.*?\n\s*```", "", content, flags=re.DOTALL)
|
|
return cleaned.strip()


def _formatAssistantToolCalls(toolCalls: List[ToolCallRequest]) -> List[Dict[str, Any]]:
|
|
"""Format tool calls for the conversation history (OpenAI tool_calls format)."""
|
|
return [
|
|
{
|
|
"id": tc.id,
|
|
"type": "function",
|
|
"function": {
|
|
"name": tc.name,
|
|
"arguments": json.dumps(tc.args)
|
|
}
|
|
}
|
|
for tc in toolCalls
|
|
]


def _buildProgressSummary(state: AgentState, reason: str) -> str:
|
|
"""Build a human-readable summary of agent progress for graceful termination."""
|
|
return (
|
|
f"{reason}\n\n"
|
|
f"Progress after {state.currentRound} rounds:\n"
|
|
f"- AI calls: {state.totalAiCalls}\n"
|
|
f"- Tool calls: {state.totalToolCalls}\n"
|
|
f"- Cost: {state.totalCostCHF:.4f} CHF\n"
|
|
f"- Processing time: {state.totalProcessingTime:.1f}s"
|
|
)