fixed chat history context
This commit is contained in:
parent
6154eb2553
commit
f3de454b66
4 changed files with 93 additions and 3 deletions
|
|
@ -9,7 +9,7 @@ and Playground into a single agent-driven workspace.
|
|||
import logging
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional, List
|
||||
from typing import Dict, Optional, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
|
|
@ -146,6 +146,37 @@ def _buildDataSourceContext(chatService, dataSourceIds: List[str]) -> str:
|
|||
return "\n".join(parts) if found else ""
|
||||
|
||||
|
||||
def _loadConversationHistory(chatInterface, workflowId: str, currentPrompt: str) -> List[Dict[str, str]]:
|
||||
"""Load prior messages from DB for follow-up context, excluding the current prompt."""
|
||||
try:
|
||||
rawMessages = chatInterface.getMessages(workflowId) or []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load conversation history: {e}")
|
||||
return []
|
||||
|
||||
history = []
|
||||
for msg in rawMessages:
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("message", "") or msg.get("content", "")
|
||||
else:
|
||||
role = getattr(msg, "role", "")
|
||||
content = getattr(msg, "message", "") or getattr(msg, "content", "")
|
||||
if role in ("user", "assistant") and content:
|
||||
history.append({"role": role, "content": content})
|
||||
|
||||
if not history:
|
||||
return []
|
||||
|
||||
# Drop the last user message if it matches the current prompt (already added by the agent loop)
|
||||
if history[-1]["role"] == "user" and history[-1]["content"].strip() == currentPrompt.strip():
|
||||
history = history[:-1]
|
||||
|
||||
if history:
|
||||
logger.info(f"Loaded {len(history)} prior messages for workflow {workflowId}")
|
||||
return history
|
||||
|
||||
|
||||
async def _deriveWorkflowName(prompt: str, aiService) -> str:
|
||||
"""Use AI to generate a concise workflow title from the user prompt."""
|
||||
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
|
||||
|
|
@ -320,16 +351,25 @@ async def _runWorkspaceAgent(
|
|||
if dsInfo:
|
||||
enrichedPrompt = f"{prompt}\n\n[Active Data Sources]\n{dsInfo}"
|
||||
|
||||
conversationHistory = _loadConversationHistory(chatInterface, workflowId, prompt)
|
||||
|
||||
accumulatedText = ""
|
||||
messagePersisted = False
|
||||
|
||||
async for event in agentService.runAgent(
|
||||
prompt=enrichedPrompt,
|
||||
fileIds=fileIds,
|
||||
workflowId=workflowId,
|
||||
userLanguage=userLanguage,
|
||||
conversationHistory=conversationHistory,
|
||||
):
|
||||
if eventManager.is_cancelled(queueId):
|
||||
logger.info(f"Agent cancelled by user for workflow {workflowId}")
|
||||
break
|
||||
|
||||
if event.type == AgentEventTypeEnum.CHUNK and event.content:
|
||||
accumulatedText += event.content
|
||||
|
||||
sseEvent = {
|
||||
"type": event.type.value if hasattr(event.type, "value") else event.type,
|
||||
"workflowId": workflowId,
|
||||
|
|
@ -337,6 +377,7 @@ async def _runWorkspaceAgent(
|
|||
if event.content:
|
||||
sseEvent["content"] = event.content
|
||||
if event.type == AgentEventTypeEnum.MESSAGE:
|
||||
accumulatedText += event.content
|
||||
sseEvent["item"] = {
|
||||
"id": f"msg-{workflowId}-{id(event)}",
|
||||
"role": "assistant",
|
||||
|
|
@ -349,16 +390,30 @@ async def _runWorkspaceAgent(
|
|||
await eventManager.emit_event(queueId, sseEvent["type"], sseEvent)
|
||||
|
||||
if event.type in (AgentEventTypeEnum.FINAL, AgentEventTypeEnum.ERROR):
|
||||
if event.content:
|
||||
finalContent = event.content or accumulatedText
|
||||
if finalContent:
|
||||
try:
|
||||
chatInterface.createMessage({
|
||||
"workflowId": workflowId,
|
||||
"role": "assistant",
|
||||
"message": event.content,
|
||||
"message": finalContent,
|
||||
})
|
||||
messagePersisted = True
|
||||
except Exception as msgErr:
|
||||
logger.error(f"Failed to persist assistant message: {msgErr}")
|
||||
|
||||
# Persist any streamed content that wasn't saved via FINAL (e.g. cancellation)
|
||||
if not messagePersisted and accumulatedText.strip():
|
||||
try:
|
||||
chatInterface.createMessage({
|
||||
"workflowId": workflowId,
|
||||
"role": "assistant",
|
||||
"message": accumulatedText,
|
||||
})
|
||||
logger.info(f"Persisted partial assistant response ({len(accumulatedText)} chars) for workflow {workflowId}")
|
||||
except Exception as msgErr:
|
||||
logger.error(f"Failed to persist partial assistant message: {msgErr}")
|
||||
|
||||
logger.info(f"Agent loop completed for workflow {workflowId}, sending 'complete' event")
|
||||
await eventManager.emit_event(queueId, "complete", {
|
||||
"type": "complete",
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ async def runAgentLoop(
|
|||
mandateId: str = "",
|
||||
aiCallStreamFn: Callable = None,
|
||||
userLanguage: str = "",
|
||||
conversationHistory: List[Dict[str, Any]] = None,
|
||||
) -> AsyncGenerator[AgentEvent, None]:
|
||||
"""Run the agent loop. Yields AgentEvent for each step (SSE-ready).
|
||||
|
||||
|
|
@ -56,6 +57,7 @@ async def runAgentLoop(
|
|||
buildRagContextFn: Optional async function to build RAG context before each round
|
||||
mandateId: Mandate ID for RAG scoping
|
||||
userLanguage: ISO 639-1 language code for agent responses
|
||||
conversationHistory: Prior messages [{role, content/message}] for follow-up context
|
||||
"""
|
||||
state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
|
||||
trace = AgentTrace(
|
||||
|
|
@ -69,6 +71,8 @@ async def runAgentLoop(
|
|||
|
||||
systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
|
||||
conversation = ConversationManager(systemPrompt)
|
||||
if conversationHistory:
|
||||
conversation.loadHistory(conversationHistory)
|
||||
conversation.addUserMessage(prompt)
|
||||
|
||||
while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ FIRST_SUMMARY_ROUND = 4
|
|||
META_SUMMARY_ROUND = 7
|
||||
KEEP_RECENT_MESSAGES = 4
|
||||
MAX_ESTIMATED_TOKENS = 60000
|
||||
_MAX_HISTORY_MESSAGES = 40
|
||||
_MAX_HISTORY_MSG_CHARS = 12000
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
|
|
@ -33,6 +35,32 @@ class ConversationManager:
|
|||
self._lastSummarizedRound: int = 0
|
||||
self._ragContextInjected: bool = False
|
||||
|
||||
def loadHistory(self, messages: List[Dict[str, Any]]):
    """Load prior conversation messages for follow-up context.

    Accepts messages with {role, content/message} format (as stored in DB).
    Only user/assistant turns with non-blank content are kept; oversized
    messages are clipped and only the most recent entries are retained so
    the context window stays bounded. Must be called BEFORE addUserMessage
    with the current prompt.
    """
    if not messages:
        return

    count = 0
    for entry in messages[-_MAX_HISTORY_MESSAGES:]:
        role = entry.get("role", "")
        text = entry.get("content", "") or entry.get("message", "") or ""
        if role in ("user", "assistant") and text.strip():
            if len(text) > _MAX_HISTORY_MSG_CHARS:
                text = text[:_MAX_HISTORY_MSG_CHARS] + "…"
            self._messages.append({"role": role, "content": text})
            count += 1

    if count:
        logger.info(f"Loaded {count} history messages into conversation context")
|
||||
|
||||
@property
|
||||
def messages(self) -> List[Dict[str, Any]]:
|
||||
"""Current messages for the next AI call (internal markers stripped)."""
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ class AgentService:
|
|||
workflowId: str = None,
|
||||
additionalTools: List[Dict[str, Any]] = None,
|
||||
userLanguage: str = "",
|
||||
conversationHistory: List[Dict[str, Any]] = None,
|
||||
) -> AsyncGenerator[AgentEvent, None]:
|
||||
"""Run an agent with the given prompt and tools.
|
||||
|
||||
|
|
@ -120,6 +121,7 @@ class AgentService:
|
|||
workflowId: Workflow ID for tracking and billing
|
||||
additionalTools: Extra tool definitions to register dynamically
|
||||
userLanguage: ISO 639-1 language code; falls back to user.language from profile
|
||||
conversationHistory: Prior messages for follow-up context
|
||||
|
||||
Yields:
|
||||
AgentEvent for each step (SSE-ready)
|
||||
|
|
@ -154,6 +156,7 @@ class AgentService:
|
|||
mandateId=self.services.mandateId or "",
|
||||
aiCallStreamFn=aiCallStreamFn,
|
||||
userLanguage=resolvedLanguage,
|
||||
conversationHistory=conversationHistory,
|
||||
):
|
||||
if event.type == AgentEventTypeEnum.AGENT_SUMMARY:
|
||||
await self._persistTrace(workflowId, event.data or {})
|
||||
|
|
|
|||
Loading…
Reference in a new issue