fixed chat history context
This commit is contained in:
parent
6154eb2553
commit
f3de454b66
4 changed files with 93 additions and 3 deletions
|
|
@ -9,7 +9,7 @@ and Playground into a single agent-driven workspace.
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, List
|
from typing import Dict, Optional, List
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, UploadFile, File
|
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request, UploadFile, File
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
|
@ -146,6 +146,37 @@ def _buildDataSourceContext(chatService, dataSourceIds: List[str]) -> str:
|
||||||
return "\n".join(parts) if found else ""
|
return "\n".join(parts) if found else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _loadConversationHistory(chatInterface, workflowId: str, currentPrompt: str) -> List[Dict[str, str]]:
|
||||||
|
"""Load prior messages from DB for follow-up context, excluding the current prompt."""
|
||||||
|
try:
|
||||||
|
rawMessages = chatInterface.getMessages(workflowId) or []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load conversation history: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
history = []
|
||||||
|
for msg in rawMessages:
|
||||||
|
if isinstance(msg, dict):
|
||||||
|
role = msg.get("role", "")
|
||||||
|
content = msg.get("message", "") or msg.get("content", "")
|
||||||
|
else:
|
||||||
|
role = getattr(msg, "role", "")
|
||||||
|
content = getattr(msg, "message", "") or getattr(msg, "content", "")
|
||||||
|
if role in ("user", "assistant") and content:
|
||||||
|
history.append({"role": role, "content": content})
|
||||||
|
|
||||||
|
if not history:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Drop the last user message if it matches the current prompt (already added by the agent loop)
|
||||||
|
if history[-1]["role"] == "user" and history[-1]["content"].strip() == currentPrompt.strip():
|
||||||
|
history = history[:-1]
|
||||||
|
|
||||||
|
if history:
|
||||||
|
logger.info(f"Loaded {len(history)} prior messages for workflow {workflowId}")
|
||||||
|
return history
|
||||||
|
|
||||||
|
|
||||||
async def _deriveWorkflowName(prompt: str, aiService) -> str:
|
async def _deriveWorkflowName(prompt: str, aiService) -> str:
|
||||||
"""Use AI to generate a concise workflow title from the user prompt."""
|
"""Use AI to generate a concise workflow title from the user prompt."""
|
||||||
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
|
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum
|
||||||
|
|
@ -320,16 +351,25 @@ async def _runWorkspaceAgent(
|
||||||
if dsInfo:
|
if dsInfo:
|
||||||
enrichedPrompt = f"{prompt}\n\n[Active Data Sources]\n{dsInfo}"
|
enrichedPrompt = f"{prompt}\n\n[Active Data Sources]\n{dsInfo}"
|
||||||
|
|
||||||
|
conversationHistory = _loadConversationHistory(chatInterface, workflowId, prompt)
|
||||||
|
|
||||||
|
accumulatedText = ""
|
||||||
|
messagePersisted = False
|
||||||
|
|
||||||
async for event in agentService.runAgent(
|
async for event in agentService.runAgent(
|
||||||
prompt=enrichedPrompt,
|
prompt=enrichedPrompt,
|
||||||
fileIds=fileIds,
|
fileIds=fileIds,
|
||||||
workflowId=workflowId,
|
workflowId=workflowId,
|
||||||
userLanguage=userLanguage,
|
userLanguage=userLanguage,
|
||||||
|
conversationHistory=conversationHistory,
|
||||||
):
|
):
|
||||||
if eventManager.is_cancelled(queueId):
|
if eventManager.is_cancelled(queueId):
|
||||||
logger.info(f"Agent cancelled by user for workflow {workflowId}")
|
logger.info(f"Agent cancelled by user for workflow {workflowId}")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if event.type == AgentEventTypeEnum.CHUNK and event.content:
|
||||||
|
accumulatedText += event.content
|
||||||
|
|
||||||
sseEvent = {
|
sseEvent = {
|
||||||
"type": event.type.value if hasattr(event.type, "value") else event.type,
|
"type": event.type.value if hasattr(event.type, "value") else event.type,
|
||||||
"workflowId": workflowId,
|
"workflowId": workflowId,
|
||||||
|
|
@ -337,6 +377,7 @@ async def _runWorkspaceAgent(
|
||||||
if event.content:
|
if event.content:
|
||||||
sseEvent["content"] = event.content
|
sseEvent["content"] = event.content
|
||||||
if event.type == AgentEventTypeEnum.MESSAGE:
|
if event.type == AgentEventTypeEnum.MESSAGE:
|
||||||
|
accumulatedText += event.content
|
||||||
sseEvent["item"] = {
|
sseEvent["item"] = {
|
||||||
"id": f"msg-{workflowId}-{id(event)}",
|
"id": f"msg-{workflowId}-{id(event)}",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
|
@ -349,16 +390,30 @@ async def _runWorkspaceAgent(
|
||||||
await eventManager.emit_event(queueId, sseEvent["type"], sseEvent)
|
await eventManager.emit_event(queueId, sseEvent["type"], sseEvent)
|
||||||
|
|
||||||
if event.type in (AgentEventTypeEnum.FINAL, AgentEventTypeEnum.ERROR):
|
if event.type in (AgentEventTypeEnum.FINAL, AgentEventTypeEnum.ERROR):
|
||||||
if event.content:
|
finalContent = event.content or accumulatedText
|
||||||
|
if finalContent:
|
||||||
try:
|
try:
|
||||||
chatInterface.createMessage({
|
chatInterface.createMessage({
|
||||||
"workflowId": workflowId,
|
"workflowId": workflowId,
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"message": event.content,
|
"message": finalContent,
|
||||||
})
|
})
|
||||||
|
messagePersisted = True
|
||||||
except Exception as msgErr:
|
except Exception as msgErr:
|
||||||
logger.error(f"Failed to persist assistant message: {msgErr}")
|
logger.error(f"Failed to persist assistant message: {msgErr}")
|
||||||
|
|
||||||
|
# Persist any streamed content that wasn't saved via FINAL (e.g. cancellation)
|
||||||
|
if not messagePersisted and accumulatedText.strip():
|
||||||
|
try:
|
||||||
|
chatInterface.createMessage({
|
||||||
|
"workflowId": workflowId,
|
||||||
|
"role": "assistant",
|
||||||
|
"message": accumulatedText,
|
||||||
|
})
|
||||||
|
logger.info(f"Persisted partial assistant response ({len(accumulatedText)} chars) for workflow {workflowId}")
|
||||||
|
except Exception as msgErr:
|
||||||
|
logger.error(f"Failed to persist partial assistant message: {msgErr}")
|
||||||
|
|
||||||
logger.info(f"Agent loop completed for workflow {workflowId}, sending 'complete' event")
|
logger.info(f"Agent loop completed for workflow {workflowId}, sending 'complete' event")
|
||||||
await eventManager.emit_event(queueId, "complete", {
|
await eventManager.emit_event(queueId, "complete", {
|
||||||
"type": "complete",
|
"type": "complete",
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ async def runAgentLoop(
|
||||||
mandateId: str = "",
|
mandateId: str = "",
|
||||||
aiCallStreamFn: Callable = None,
|
aiCallStreamFn: Callable = None,
|
||||||
userLanguage: str = "",
|
userLanguage: str = "",
|
||||||
|
conversationHistory: List[Dict[str, Any]] = None,
|
||||||
) -> AsyncGenerator[AgentEvent, None]:
|
) -> AsyncGenerator[AgentEvent, None]:
|
||||||
"""Run the agent loop. Yields AgentEvent for each step (SSE-ready).
|
"""Run the agent loop. Yields AgentEvent for each step (SSE-ready).
|
||||||
|
|
||||||
|
|
@ -56,6 +57,7 @@ async def runAgentLoop(
|
||||||
buildRagContextFn: Optional async function to build RAG context before each round
|
buildRagContextFn: Optional async function to build RAG context before each round
|
||||||
mandateId: Mandate ID for RAG scoping
|
mandateId: Mandate ID for RAG scoping
|
||||||
userLanguage: ISO 639-1 language code for agent responses
|
userLanguage: ISO 639-1 language code for agent responses
|
||||||
|
conversationHistory: Prior messages [{role, content/message}] for follow-up context
|
||||||
"""
|
"""
|
||||||
state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
|
state = AgentState(workflowId=workflowId, maxRounds=config.maxRounds)
|
||||||
trace = AgentTrace(
|
trace = AgentTrace(
|
||||||
|
|
@ -69,6 +71,8 @@ async def runAgentLoop(
|
||||||
|
|
||||||
systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
|
systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage)
|
||||||
conversation = ConversationManager(systemPrompt)
|
conversation = ConversationManager(systemPrompt)
|
||||||
|
if conversationHistory:
|
||||||
|
conversation.loadHistory(conversationHistory)
|
||||||
conversation.addUserMessage(prompt)
|
conversation.addUserMessage(prompt)
|
||||||
|
|
||||||
while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
|
while state.status == AgentStatusEnum.RUNNING and state.currentRound < state.maxRounds:
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ FIRST_SUMMARY_ROUND = 4
|
||||||
META_SUMMARY_ROUND = 7
|
META_SUMMARY_ROUND = 7
|
||||||
KEEP_RECENT_MESSAGES = 4
|
KEEP_RECENT_MESSAGES = 4
|
||||||
MAX_ESTIMATED_TOKENS = 60000
|
MAX_ESTIMATED_TOKENS = 60000
|
||||||
|
_MAX_HISTORY_MESSAGES = 40
|
||||||
|
_MAX_HISTORY_MSG_CHARS = 12000
|
||||||
|
|
||||||
|
|
||||||
class ConversationManager:
|
class ConversationManager:
|
||||||
|
|
@ -33,6 +35,32 @@ class ConversationManager:
|
||||||
self._lastSummarizedRound: int = 0
|
self._lastSummarizedRound: int = 0
|
||||||
self._ragContextInjected: bool = False
|
self._ragContextInjected: bool = False
|
||||||
|
|
||||||
|
def loadHistory(self, messages: List[Dict[str, Any]]):
    """Seed the conversation with earlier stored messages for follow-up turns.

    Accepts messages in the DB's ``{role, content/message}`` shape. Overlong
    messages are truncated and the total count capped so the context window
    stays manageable. Must be called BEFORE addUserMessage with the current
    prompt.
    """
    if not messages:
        return

    loaded = 0
    for entry in messages[-_MAX_HISTORY_MESSAGES:]:
        role = entry.get("role", "")
        text = entry.get("content", "") or entry.get("message", "") or ""
        # Only real dialogue turns with visible content belong in context.
        if role not in ("user", "assistant") or not text.strip():
            continue
        if len(text) > _MAX_HISTORY_MSG_CHARS:
            # Keep the head of the message and mark the cut with an ellipsis.
            text = text[:_MAX_HISTORY_MSG_CHARS] + "…"
        self._messages.append({"role": role, "content": text})
        loaded += 1

    if loaded:
        logger.info(f"Loaded {loaded} history messages into conversation context")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def messages(self) -> List[Dict[str, Any]]:
|
def messages(self) -> List[Dict[str, Any]]:
|
||||||
"""Current messages for the next AI call (internal markers stripped)."""
|
"""Current messages for the next AI call (internal markers stripped)."""
|
||||||
|
|
|
||||||
|
|
@ -109,6 +109,7 @@ class AgentService:
|
||||||
workflowId: str = None,
|
workflowId: str = None,
|
||||||
additionalTools: List[Dict[str, Any]] = None,
|
additionalTools: List[Dict[str, Any]] = None,
|
||||||
userLanguage: str = "",
|
userLanguage: str = "",
|
||||||
|
conversationHistory: List[Dict[str, Any]] = None,
|
||||||
) -> AsyncGenerator[AgentEvent, None]:
|
) -> AsyncGenerator[AgentEvent, None]:
|
||||||
"""Run an agent with the given prompt and tools.
|
"""Run an agent with the given prompt and tools.
|
||||||
|
|
||||||
|
|
@ -120,6 +121,7 @@ class AgentService:
|
||||||
workflowId: Workflow ID for tracking and billing
|
workflowId: Workflow ID for tracking and billing
|
||||||
additionalTools: Extra tool definitions to register dynamically
|
additionalTools: Extra tool definitions to register dynamically
|
||||||
userLanguage: ISO 639-1 language code; falls back to user.language from profile
|
userLanguage: ISO 639-1 language code; falls back to user.language from profile
|
||||||
|
conversationHistory: Prior messages for follow-up context
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
AgentEvent for each step (SSE-ready)
|
AgentEvent for each step (SSE-ready)
|
||||||
|
|
@ -154,6 +156,7 @@ class AgentService:
|
||||||
mandateId=self.services.mandateId or "",
|
mandateId=self.services.mandateId or "",
|
||||||
aiCallStreamFn=aiCallStreamFn,
|
aiCallStreamFn=aiCallStreamFn,
|
||||||
userLanguage=resolvedLanguage,
|
userLanguage=resolvedLanguage,
|
||||||
|
conversationHistory=conversationHistory,
|
||||||
):
|
):
|
||||||
if event.type == AgentEventTypeEnum.AGENT_SUMMARY:
|
if event.type == AgentEventTypeEnum.AGENT_SUMMARY:
|
||||||
await self._persistTrace(workflowId, event.data or {})
|
await self._persistTrace(workflowId, event.data or {})
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue