# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Agent loop: ReAct pattern with native function calling, budget control, and error handling."""
import asyncio
import logging
import time
import json
import re
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable

from modules.datamodels.datamodelAi import (
    AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
)
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
    AgentState, AgentStatusEnum, AgentConfig, AgentEvent, AgentEventTypeEnum,
    ToolCallRequest, ToolResult, ToolCallLog, AgentRoundLog, AgentTrace
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.conversationManager import (
    ConversationManager, buildSystemPrompt
)
from modules.shared.timeUtils import getUtcTimestamp
from modules.shared.jsonUtils import closeJsonStructures
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
    InsufficientBalanceException,
)
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
    SubscriptionInactiveException,
)

logger = logging.getLogger(__name__)


async def runAgentLoop(
    prompt: str,
    toolRegistry: ToolRegistry,
    config: AgentConfig,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    getWorkflowCostFn: Callable[[], Awaitable[float]],
    workflowId: str,
    userId: str = "",
    featureInstanceId: str = "",
    buildRagContextFn: Optional[Callable[..., Awaitable[str]]] = None,
    mandateId: str = "",
    aiCallStreamFn: Optional[Callable] = None,
    userLanguage: str = "",
    conversationHistory: Optional[List[Dict[str, Any]]] = None,
    persistRoundMemoryFn: Optional[Callable[..., Awaitable[None]]] = None,
    getExternalMemoryKeysFn: Optional[Callable[[], List[str]]] = None,
    systemPromptOverride: Optional[str] = None,
) -> AsyncGenerator[AgentEvent, None]:
    """Run the agent loop. Yields AgentEvent for each step (SSE-ready).

    Args:
        prompt: User prompt
        toolRegistry: Registry with available tools
        config: Agent configuration (maxRounds, maxCostCHF, etc.)
        aiCallFn: Function to call the AI (wraps serviceAi.callAi with billing)
        getWorkflowCostFn: Function to get current workflow cost
        workflowId: Workflow ID for tracking
        userId: User ID for tracing
        featureInstanceId: Feature instance ID for tracing
        buildRagContextFn: Optional async function to build RAG context before each round
        mandateId: Mandate ID for RAG scoping
        aiCallStreamFn: Optional streaming AI call; yields str chunks then a final AiCallResponse
        userLanguage: ISO 639-1 language code for agent responses
        conversationHistory: Prior messages [{role, content/message}] for follow-up context
        persistRoundMemoryFn: Optional callback to persist round memories after tool execution
        getExternalMemoryKeysFn: Optional callback that returns RoundMemory keys for this
            workflow, used by summarization to de-duplicate persisted facts
        systemPromptOverride: Optional full replacement for the built system prompt
    """
    state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
    trace = AgentTrace(
        workflowId=workflowId, userId=userId, featureInstanceId=featureInstanceId
    )
    activeToolSet = config.toolSet if config else None
    tools = toolRegistry.getTools(toolSet=activeToolSet)
    toolDefinitions = toolRegistry.formatToolsForFunctionCalling(toolSet=activeToolSet)
    # Text-based tool descriptions are ONLY used as fallback when native function
    # calling is unavailable. Including both creates conflicting instructions
    # (text ```tool_call format vs native tool_use blocks) and can cause the model
    # to respond with plain text instead of actual tool calls.
    toolsText = "" if toolDefinitions else toolRegistry.formatToolsForPrompt(toolSet=activeToolSet)
    if systemPromptOverride:
        systemPrompt = systemPromptOverride
    else:
        systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
    conversation = ConversationManager(systemPrompt)
    if conversationHistory:
        conversation.loadHistory(conversationHistory)
    conversation.addUserMessage(prompt, isCurrentPrompt=True)

    while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
        # Yield control to the event loop once per round so long agent runs
        # don't starve other coroutines.
        await asyncio.sleep(0)
        state.currentRound += 1
        roundStartTime = time.time()
        roundLog = AgentRoundLog(roundNumber=state.currentRound)

        # RAG context injection (before each round for fresh relevance)
        if buildRagContextFn:
            try:
                # Use the latest user message (not the original prompt) so RAG
                # relevance tracks the current conversational focus.
                latestUserMsg = ""
                for msg in reversed(conversation.messages):
                    if msg.get("role") == "user":
                        latestUserMsg = msg.get("content", "")
                        break
                ragContext = await buildRagContextFn(
                    currentPrompt=latestUserMsg or prompt,
                    workflowId=workflowId,
                    userId=userId,
                    featureInstanceId=featureInstanceId,
                    mandateId=mandateId,
                )
                if ragContext:
                    conversation.injectRagContext(ragContext)
            except Exception as ragErr:
                # RAG failure must not abort the agent run.
                logger.warning(f"RAG context injection failed (non-blocking): {ragErr}")

        # Budget check
        budgetExceeded = await _checkBudget(config, getWorkflowCostFn)
        if budgetExceeded:
            state.status = AgentStatusEnum.BUDGET_EXCEEDED
            state.abortReason = "Workflow cost budget exceeded"
            yield AgentEvent(
                type=AgentEventTypeEnum.FINAL,
                content=_buildProgressSummary(state, "Budget exceeded. Here is the progress so far.")
            )
            break

        logger.info(f"Agent round {state.currentRound}/{state.maxRounds} for workflow {workflowId} (tools={state.totalToolCalls}, cost={state.totalCostCHF:.4f})")
        yield AgentEvent(
            type=AgentEventTypeEnum.AGENT_PROGRESS,
            data={
                "round": state.currentRound,
                "maxRounds": state.maxRounds,
                "totalAiCalls": state.totalAiCalls,
                "totalToolCalls": state.totalToolCalls,
                "costCHF": state.totalCostCHF
            }
        )

        # Progressive summarization: compress old conversation turns so the
        # context window stays bounded across many rounds.
        if conversation.needsSummarization(state.currentRound):
            async def _summarizeCall(summaryPrompt: str) -> str:
                # Summarization runs as a billed AI call; account for it in state.
                req = AiCallRequest(
                    prompt=summaryPrompt,
                    options=AiCallOptions(operationType=OperationTypeEnum.DATA_ANALYSE)
                )
                resp = await aiCallFn(req)
                state.totalCostCHF += resp.priceCHF
                state.totalAiCalls += 1
                return resp.content

            memKeys: List[str] = []
            if getExternalMemoryKeysFn:
                try:
                    memKeys = getExternalMemoryKeysFn()
                except Exception as e:
                    logger.warning(f"getExternalMemoryKeysFn failed: {e}")
            await conversation.summarize(
                state.currentRound, _summarizeCall, externalMemoryKeys=memKeys or None
            )

        # AI call (messages carry the conversation; prompt field stays empty)
        aiRequest = AiCallRequest(
            prompt="",
            options=AiCallOptions(
                operationType=config.operationType or OperationTypeEnum.AGENT,
                temperature=config.temperature
            ),
            messages=conversation.messages,
            tools=toolDefinitions if toolDefinitions else None,
        )
        try:
            aiResponse = None
            streamedText = ""
            isFirstChunkOfRound = True
            if aiCallStreamFn:
                # Streaming path: str chunks are forwarded as CHUNK events; the
                # terminal non-str item is the final AiCallResponse.
                async for chunk in aiCallStreamFn(aiRequest):
                    if isinstance(chunk, str):
                        if isFirstChunkOfRound:
                            # Visually separate output of consecutive rounds.
                            if state.currentRound > 1:
                                chunk = "\n\n" + chunk
                            isFirstChunkOfRound = False
                        streamedText += chunk
                        yield AgentEvent(type=AgentEventTypeEnum.CHUNK, content=chunk)
                    else:
                        aiResponse = chunk
                if aiResponse is None:
                    raise RuntimeError("Stream ended without final AiCallResponse")
            else:
                aiResponse = await aiCallFn(aiRequest)
        except SubscriptionInactiveException as e:
            logger.warning(
                f"Subscription inactive in round {state.currentRound} (mandate={mandateId}): {e.message}"
            )
            state.status = AgentStatusEnum.ERROR
            state.abortReason = e.message
            yield AgentEvent(
                type=AgentEventTypeEnum.ERROR,
                content=e.message,
                data=e.toClientDict(),
            )
            break
        except InsufficientBalanceException as e:
            logger.warning(
                f"Insufficient balance in round {state.currentRound} (mandate={mandateId}): {e.message}"
            )
            state.status = AgentStatusEnum.ERROR
            state.abortReason = e.message
            yield AgentEvent(
                type=AgentEventTypeEnum.ERROR,
                content=e.message,
                data=e.toClientDict(),
            )
            break
        except Exception as e:
            logger.error(f"AI call failed in round {state.currentRound}: {e}", exc_info=True)
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI call error: {e}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=str(e))
            break

        state.totalAiCalls += 1
        state.totalCostCHF += aiResponse.priceCHF
        state.totalProcessingTime += aiResponse.processingTime
        roundLog.aiModel = aiResponse.modelName
        roundLog.costCHF = aiResponse.priceCHF
        if aiResponse.errorCount > 0:
            state.status = AgentStatusEnum.ERROR
            state.abortReason = f"AI returned error: {aiResponse.content}"
            yield AgentEvent(type=AgentEventTypeEnum.ERROR, content=aiResponse.content)
            break

        # Parse response for tool calls
        toolCalls = _parseToolCalls(aiResponse)
        textContent = _extractTextContent(aiResponse)
        logger.debug(
            f"Round {state.currentRound} AI response: model={aiResponse.modelName}, "
            f"toolCalls={len(toolCalls)}, nativeToolCalls={'yes' if aiResponse.toolCalls else 'no'}, "
            f"contentLen={len(aiResponse.content)}, streamedLen={len(streamedText)}"
        )

        # Empty response (no content, no tool calls) = model returned nothing useful.
        # Burn the round but let the loop continue so the next iteration can retry
        # (the failover mechanism in the AI layer will try alternative models).
        if not toolCalls and not textContent and not streamedText:
            logger.warning(
                f"Round {state.currentRound}: AI returned empty response "
                f"(model={aiResponse.modelName}). Retrying next round."
            )
            conversation.addUserMessage(
                "Your previous response was empty. Please use the available tools "
                "to accomplish the task. Start by planning the steps, then call the "
                "appropriate tools."
            )
            roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
            trace.rounds.append(roundLog)
            continue

        # Non-streamed text is emitted as one MESSAGE event (streamed text was
        # already delivered chunk by chunk).
        if textContent and not streamedText:
            yield AgentEvent(type=AgentEventTypeEnum.MESSAGE, content=textContent)

        # No tool calls means the agent considers the task done.
        if not toolCalls:
            state.status = AgentStatusEnum.COMPLETED
            conversation.addAssistantMessage(aiResponse.content)
            roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
            trace.rounds.append(roundLog)
            yield AgentEvent(type=AgentEventTypeEnum.FINAL, content=textContent or aiResponse.content)
            break

        # Add assistant message with tool calls to conversation
        assistantToolCalls = _formatAssistantToolCalls(toolCalls)
        conversation.addAssistantMessage(textContent or "", assistantToolCalls)

        # Execute tool calls
        for tc in toolCalls:
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_CALL,
                data={"toolName": tc.name, "args": tc.args}
            )
        results = await _executeToolCalls(toolCalls, toolRegistry, {
            "workflowId": workflowId,
            "userId": userId,
            "featureInstanceId": featureInstanceId,
            "mandateId": mandateId,
            "modelMaxOutputTokens": getattr(aiResponse, "_modelMaxTokens", None) or 0,
            # Propagate the parent agent's budget to sub-agent tools (e.g.
            # queryFeatureInstance) so they don't cap themselves at a smaller
            # hardcoded round count than the user configured for the workspace.
            "parentMaxRounds": state.maxRounds,
            "parentMaxCostCHF": config.maxCostCHF,
        })
        state.totalToolCalls += len(results)

        for result in results:
            roundLog.toolCalls.append(ToolCallLog(
                toolName=result.toolName,
                args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
                success=result.success,
                durationMs=result.durationMs,
                error=result.error,
                resultData=result.data[:300] if result.data else "",
            ))
            if not result.success:
                logger.warning(f"Tool '{result.toolName}' failed: {result.error}")
            yield AgentEvent(
                type=AgentEventTypeEnum.TOOL_RESULT,
                data={
                    "toolName": result.toolName,
                    "success": result.success,
                    "data": result.data[:500] if result.data else "",
                    "error": result.error
                }
            )
            # Tools may emit extra SSE events (e.g. progress from sub-agents);
            # forward only those whose type maps to a known event enum.
            if result.sideEvents:
                for sideEvt in result.sideEvents:
                    evtType = sideEvt.get("type", "")
                    try:
                        evtEnum = AgentEventTypeEnum(evtType)
                    except (ValueError, KeyError):
                        continue
                    yield AgentEvent(
                        type=evtEnum,
                        data=sideEvt.get("data"),
                        content=sideEvt.get("content"),
                    )

        # Check if requestToolbox was called -- refresh tool definitions for next round
        _toolboxEscalated = any(
            r.toolName == "requestToolbox" and r.success for r in results
        )
        if _toolboxEscalated:
            tools = toolRegistry.getTools(toolSet=activeToolSet)
            toolDefinitions = toolRegistry.formatToolsForFunctionCalling(toolSet=activeToolSet)
            toolsText = "" if toolDefinitions else toolRegistry.formatToolsForPrompt(toolSet=activeToolSet)
            logger.info("Toolbox escalation: refreshed tool definitions (%d tools)", len(tools))

        # Add tool results to conversation
        toolResultMessages = [
            {
                "toolCallId": r.toolCallId,
                "toolName": r.toolName,
                "content": r.data if r.success else f"Error: {r.error}",
            }
            for r in results
        ]
        conversation.addToolResults(toolResultMessages)

        # Persist round memories (file refs, tool results, decisions)
        if persistRoundMemoryFn:
            try:
                await persistRoundMemoryFn(
                    toolCalls=toolCalls,
                    results=results,
                    textContent=textContent,
                    roundNumber=state.currentRound,
                )
            except Exception as memErr:
                logger.warning(f"RoundMemory persist failed (non-blocking): {memErr}")

        roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
        trace.rounds.append(roundLog)

    # maxRounds reached
    if state.currentRound >= state.maxRounds and state.status == AgentStatusEnum.RUNNING:
        state.status = AgentStatusEnum.MAX_ROUNDS_REACHED
        state.abortReason = f"Maximum rounds ({state.maxRounds}) reached"
        yield AgentEvent(
            type=AgentEventTypeEnum.FINAL,
            content=_buildProgressSummary(state, "Maximum rounds reached.")
        )

    # Agent summary
    trace.completedAt = getUtcTimestamp()
    trace.status = state.status
    trace.totalRounds = state.currentRound
    trace.totalToolCalls = state.totalToolCalls
    trace.totalCostCHF = state.totalCostCHF
    trace.abortReason = state.abortReason
    artifactSummary = _buildArtifactSummary(trace.rounds)
    yield AgentEvent(
        type=AgentEventTypeEnum.AGENT_SUMMARY,
        data={
            "rounds": state.currentRound,
            "totalAiCalls": state.totalAiCalls,
            "totalToolCalls": state.totalToolCalls,
            "costCHF": round(state.totalCostCHF, 4),
            "processingTime": round(state.totalProcessingTime, 2),
            "status": state.status.value,
            "abortReason": state.abortReason,
            "artifacts": artifactSummary,
        }
    )


async def _checkBudget(config: AgentConfig, getWorkflowCostFn: Callable[[], Awaitable[float]]) -> bool:
    """Check if workflow budget is exceeded. Returns True if exceeded.

    A cost-lookup failure is treated as "not exceeded" so a transient billing
    outage doesn't abort a running agent.
    """
    if config.maxCostCHF is None:
        return False
    try:
        currentCost = await getWorkflowCostFn()
        return currentCost > config.maxCostCHF
    except Exception as e:
        logger.warning(f"Could not check workflow cost: {e}")
        return False


async def _executeToolCalls(toolCalls: List[ToolCallRequest], toolRegistry: ToolRegistry,
                            context: Dict[str, Any]) -> List[ToolResult]:
    """Execute tool calls: readOnly tools in parallel, others sequentially.

    Tool calls with _parseError (truncated JSON from LLM) are short-circuited
    with an error result so the agent can retry. Results are returned in the
    same order as ``toolCalls``.
    """
    results: Dict[str, ToolResult] = {}
    # Short-circuit calls whose arguments could not be parsed.
    for tc in toolCalls:
        if "_parseError" in tc.args:
            results[tc.id] = ToolResult(
                toolCallId=tc.id,
                toolName=tc.name,
                success=False,
                data="",
                error=tc.args["_parseError"],
                durationMs=0,
            )
    activeCalls = [tc for tc in toolCalls if tc.id not in results]
    activeReadOnly = [tc for tc in activeCalls if toolRegistry.isReadOnly(tc.name)]
    activeWrite = [tc for tc in activeCalls if not toolRegistry.isReadOnly(tc.name)]
    # Read-only tools have no side effects, so they can run concurrently.
    if activeReadOnly:
        readResults = await asyncio.gather(*[
            toolRegistry.dispatch(tc, context) for tc in activeReadOnly
        ])
        for tc, result in zip(activeReadOnly, readResults):
            results[tc.id] = result
    # Writing tools run sequentially to preserve ordering of side effects.
    for tc in activeWrite:
        results[tc.id] = await toolRegistry.dispatch(tc, context)
    return [results[tc.id] for tc in toolCalls]


def _repairTruncatedJson(raw: str) -> Optional[Dict[str, Any]]:
    """Repair truncated JSON using the shared jsonUtils toolbox.

    Uses closeJsonStructures which handles open strings, brackets, braces, and
    trailing commas with stack-based structure tracking. Returns parsed dict on
    success, None if unrecoverable.
    """
    if not raw or not raw.strip().startswith("{"):
        return None
    try:
        closed = closeJsonStructures(raw)
        return json.loads(closed)
    except Exception:
        # closeJsonStructures may raise arbitrary errors on pathological input;
        # json.JSONDecodeError is an Exception subclass, so one catch covers both.
        return None


def _validateRepairedToolArgs(toolName: str, args: Dict[str, Any]) -> Optional[str]:
    """After closeJsonStructures + json.loads, args can be syntactically valid
    but useless (truncation cut off before required fields). Return a
    user-facing _parseError message, or None if OK.

    Without this, renderDocument runs with missing `content` and only returns
    "content is required", hiding the real cause (output token limit).
    """
    if toolName == "renderDocument":
        content = args.get("content")
        sourceFileId = args.get("sourceFileId")
        hasInline = isinstance(content, str) and bool(content.strip())
        hasFile = isinstance(sourceFileId, str) and bool(sourceFileId.strip())
        if not hasInline and not hasFile:
            return (
                "Your tool call JSON was repaired after truncation, but neither `content` nor `sourceFileId` is usable. "
                "Large documents must not be inlined in the tool call (output limit).\n"
                "Preferred: writeFile(mode='create') + writeFile(mode='append') to build a .md file, then "
                "renderDocument(sourceFileId=, outputFormat='pdf', title='...') — the tool call stays small.\n"
                "Alternatives: replaceInFile for edits; shorter outline first."
            )
    return None


def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
    """Parse tool calls from AI response. Supports native function calling and text-based fallback."""
    toolCalls = []
    # Native function calling: check response metadata
    if hasattr(aiResponse, 'toolCalls') and aiResponse.toolCalls:
        for tc in aiResponse.toolCalls:
            rawArgs = tc["function"]["arguments"]
            if isinstance(rawArgs, str):
                rawArgs = rawArgs.strip()
                try:
                    parsedArgs = json.loads(rawArgs) if rawArgs else {}
                except json.JSONDecodeError:
                    # Arguments were cut off mid-JSON (token limit). Try to repair.
                    parsedArgs = _repairTruncatedJson(rawArgs)
                    if parsedArgs is None:
                        logger.warning(f"Unrecoverable truncated JSON for '{tc['function']['name']}': {rawArgs[:200]}")
                        parsedArgs = {"_parseError": (
                            "Your tool call arguments were truncated (output cut off by token limit). "
                            "Do not put the full document body in renderDocument JSON.\n"
                            "1. writeFile(create) + writeFile(append) to a .md file, then "
                            "renderDocument(sourceFileId=, outputFormat=..., title=...) — tiny tool call.\n"
                            "2. Or replaceInFile for targeted edits.\n"
                            "3. Or split into multiple smaller files."
                        )}
                    else:
                        logger.info(f"Repaired truncated JSON for '{tc['function']['name']}'")
                        # Repair may yield valid-but-useless args; validate them.
                        repairIssue = _validateRepairedToolArgs(tc["function"]["name"], parsedArgs)
                        if repairIssue:
                            logger.warning(
                                f"Repaired JSON for '{tc['function']['name']}' still invalid for execution: {repairIssue[:80]}..."
                            )
                            parsedArgs = {"_parseError": repairIssue}
            else:
                # Some providers deliver arguments as an already-parsed dict.
                parsedArgs = rawArgs if rawArgs else {}
            toolCalls.append(ToolCallRequest(
                id=tc.get("id", str(len(toolCalls))),
                name=tc["function"]["name"],
                args=parsedArgs,
            ))
        return toolCalls

    # Text-based fallback: parse ```tool_call blocks
    content = aiResponse.content or ""
    pattern = r"```tool_call\s*\n\s*tool:\s*(\S+)\s*\n\s*args:\s*(\{.*?\})\s*\n\s*```"
    matches = re.finditer(pattern, content, re.DOTALL)
    for match in matches:
        toolName = match.group(1).strip()
        argsStr = match.group(2).strip()
        try:
            args = json.loads(argsStr)
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse tool args for '{toolName}': {argsStr}")
            args = {}
        toolCalls.append(ToolCallRequest(name=toolName, args=args))
    return toolCalls


def _extractTextContent(aiResponse: AiCallResponse) -> str:
    """Extract text content from AI response, removing tool_call blocks."""
    content = aiResponse.content or ""
    cleaned = re.sub(r"```tool_call\s*\n.*?\n\s*```", "", content, flags=re.DOTALL)
    return cleaned.strip()


def _formatAssistantToolCalls(toolCalls: List[ToolCallRequest]) -> List[Dict[str, Any]]:
    """Format tool calls for the conversation history (OpenAI tool_calls format)."""
    return [
        {
            "id": tc.id,
            "type": "function",
            "function": {
                "name": tc.name,
                "arguments": json.dumps(tc.args)
            }
        }
        for tc in toolCalls
    ]


def _buildProgressSummary(state: AgentState, reason: str) -> str:
    """Build a human-readable summary of agent progress for graceful termination."""
    return (
        f"{reason}\n\n"
        f"Progress after {state.currentRound} rounds:\n"
        f"- AI calls: {state.totalAiCalls}\n"
        f"- Tool calls: {state.totalToolCalls}\n"
        f"- Cost: {state.totalCostCHF:.4f} CHF\n"
        f"- Processing time: {state.totalProcessingTime:.1f}s"
    )


# Tool-name sets used by classifyToolResult to pick a RoundMemory type.
_FILE_REF_TOOLS = {"readFile", "readContentObjects", "describeImage", "listFiles"}
_DATA_SOURCE_TOOLS = {"browseDataSource", "searchDataSource", "downloadFromDataSource"}
_DECISION_TOOLS = {"writeFile", "replaceInFile"}


def classifyToolResult(
    tc: ToolCallRequest, result: ToolResult
) -> Optional[Dict[str, Any]]:
    """Classify a successful tool result into a RoundMemory dict.

    Returns a dict with keys {memoryType, key, summary, fullData, fileIds}
    or None if the result is not worth persisting.
    """
    name = tc.name
    data = result.data or ""
    # Tiny results carry no reusable information.
    if len(data) < 50:
        return None
    truncSummary = data[:2000]
    # Keep full payload only when it is small enough to store whole.
    fullData = data if len(data) < 8000 else None
    fileId = tc.args.get("fileId", "")
    fileIds = [fileId] if fileId else []
    if name in _FILE_REF_TOOLS:
        return {
            "memoryType": "file_ref",
            "key": f"{name}:{fileId}" if fileId else name,
            "summary": truncSummary,
            "fullData": fullData,
            "fileIds": fileIds,
        }
    if name in _DATA_SOURCE_TOOLS:
        dsId = tc.args.get("dataSourceId", "") or tc.args.get("featureDataSourceId", "")
        path = tc.args.get("path", "")
        return {
            "memoryType": "data_source_ref",
            "key": f"{name}:{dsId}:{path}" if dsId else name,
            "summary": truncSummary,
            "fullData": fullData,
            "fileIds": fileIds,
        }
    if name in _DECISION_TOOLS:
        return {
            "memoryType": "decision",
            "key": f"{name}:{fileId}" if fileId else name,
            "summary": truncSummary,
            "fullData": None,
            "fileIds": fileIds,
        }
    if name == "queryFeatureInstance":
        return {
            "memoryType": "tool_result",
            "key": f"queryFeatureInstance:{tc.args.get('query', '')[:60]}",
            "summary": truncSummary,
            "fullData": fullData,
            "fileIds": [],
        }
    # Generic fallback: persist only substantial results.
    if len(data) > 500:
        return {
            "memoryType": "tool_result",
            "key": f"{name}:{tc.id}",
            "summary": truncSummary,
            "fullData": fullData,
            "fileIds": fileIds,
        }
    return None


# Tools whose successful calls count as file-system artifacts of a run.
_ARTIFACT_TOOLS = {"writeFile", "replaceInFile", "deleteFile", "renameFile", "copyFile",
                   "createFolder", "deleteFolder", "renderDocument", "generateImage"}


def _buildArtifactSummary(roundLogs: List[AgentRoundLog]) -> str:
    """Extract file operations and key results from all agent rounds.

    Produces a concise summary persisted as _workflowArtifacts so follow-up
    rounds have immediate context (file IDs, names, actions).
    """
    ops = []
    for log in roundLogs:
        for tc in log.toolCalls:
            if tc.toolName not in _ARTIFACT_TOOLS or not tc.success:
                continue
            ops.append(f"- {tc.resultData}" if tc.resultData else f"- {tc.toolName}")
    if not ops:
        return ""
    return "File operations in this run:\n" + "\n".join(ops)