251 lines
8 KiB
Python
251 lines
8 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
CommCoach Context Retrieval.
|
|
Intent detection, retrieval strategies, and context assembly for intelligent session continuity.
|
|
"""
|
|
|
|
import re
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Optional, Dict, Any, List, Tuple
|
|
from enum import Enum
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Retrieval configuration.
# Default number of most recent completed-session summaries placed in the
# prompt (default `limit` of buildSessionSummariesForPrompt).
PREVIOUS_SESSION_SUMMARIES_COUNT: int = 5
# Rolling-overview tuning knobs; not referenced by the functions in this
# module — presumably consumed by the caller that maintains a rolling overview
# (TODO confirm exact semantics there).
ROLLING_OVERVIEW_SESSION_THRESHOLD: int = 10
ROLLING_OVERVIEW_EVERY_N_SESSIONS: int = 10
# Upper bound on hits returned by topic search (keyword default and RAG limit).
TOPIC_SEARCH_MAX_RESULTS: int = 5
|
|
|
|
|
|
class RetrievalIntent(str, Enum):
    """Kind of context retrieval a user message asks for.

    Subclasses ``str`` so members compare equal to their plain string values.
    """

    NORMAL = "normal"
    SUMMARIZE_ALL = "summarize_all"
    RECALL_SESSION = "recall_session"
    RECALL_TOPIC = "recall_topic"


# The patterns below are written with ASCII transliterations of German
# umlauts ("ueberblick", "gespraech"); folding the input the same way lets
# messages typed with real umlauts ("Überblick", "Gespräch") match too.
_UMLAUT_FOLD = str.maketrans({"ä": "ae", "ö": "oe", "ü": "ue", "ß": "ss"})

_SUMMARIZE_PATTERNS = tuple(
    re.compile(p, re.IGNORECASE)
    for p in (
        # "zusammenfass" and "ueberblick" are word stems: \w* lets them match
        # "zusammenfassen", "Zusammenfassung", "Überblicks", ... (a bare \b
        # right after the stem would only match the non-word "zusammenfass").
        r"\b(fasse|zusammenfass\w*|ueberblick\w*|gesamte?r?\s*chat|alles\s+zusammen)\b",
        r"\b(summari[sz]e|summary\s+of\s+all|complete\s+summary)\b",
        r"gesamten\s+chat",
    )
)

_RECALL_SESSION_PATTERNS = tuple(
    re.compile(p, re.IGNORECASE)
    for p in (
        # Dotted date after "am"/"vom" (also covers "Session vom 12.03.2025").
        r"\b(am|vom)\s*(\d{1,2})\.(\d{1,2})\.(\d{2,4})",
        # Relative references: "last week", "last month", "yesterday".
        r"\b(letzte\s+woche|voriger\s+monat|gestern)\b",
        r"\b(session|gespraech)\s+(vom|from)\s+(\d{4}-\d{2}-\d{2})",
    )
)

_RECALL_TOPIC_PATTERNS = tuple(
    re.compile(p, re.IGNORECASE)
    for p in (
        r"\b(erinnerst\s+du\s+dich|damals\s+als|thema\s+.*\s+von|ueber\s+was\s+haben\s+wir)\b",
        r"\b(was\s+war\s+.*\s+nochmal|thema\s+.*\s+besprochen)\b",
        r"\b(recall|remember|vor\s+\d+\s+sessions?)\b",
    )
)

# Evaluated in this order; the first intent with a matching pattern wins.
_INTENT_RULES = (
    (RetrievalIntent.SUMMARIZE_ALL, _SUMMARIZE_PATTERNS),
    (RetrievalIntent.RECALL_SESSION, _RECALL_SESSION_PATTERNS),
    (RetrievalIntent.RECALL_TOPIC, _RECALL_TOPIC_PATTERNS),
)


def detectIntent(userMessage: str) -> RetrievalIntent:
    """Classify a user message into a :class:`RetrievalIntent`.

    Lightweight keyword/regex heuristics (German and English). The message is
    lower-cased and German umlauts/ß are transliterated (ä→ae, ö→oe, ü→ue,
    ß→ss) so that e.g. both "Überblick" and "Ueberblick" are recognised.
    Rules are checked in priority order SUMMARIZE_ALL, RECALL_SESSION,
    RECALL_TOPIC.

    Args:
        userMessage: Raw user text; ``None`` is treated as empty.

    Returns:
        The first matching intent, or ``RetrievalIntent.NORMAL`` when the
        stripped message is shorter than 3 characters or nothing matches.
    """
    text = (userMessage or "").strip().lower()
    if len(text) < 3:
        return RetrievalIntent.NORMAL
    text = text.translate(_UMLAUT_FOLD)
    for intent, patterns in _INTENT_RULES:
        if any(p.search(text) for p in patterns):
            return intent
    return RetrievalIntent.NORMAL
|
|
|
|
|
|
def _parseDateFromMessage(text: str) -> Optional[datetime]:
|
|
"""Extract date from user message. Returns date or None."""
|
|
text = text.strip()
|
|
patterns = [
|
|
(r"(\d{1,2})\.(\d{1,2})\.(\d{2,4})", lambda m: (int(m[1]), int(m[2]), int(m[3]))),
|
|
(r"(\d{4})-(\d{2})-(\d{2})", lambda m: (int(m[3]), int(m[2]), int(m[1]))),
|
|
]
|
|
for pattern, extractor in patterns:
|
|
match = re.search(pattern, text)
|
|
if match:
|
|
try:
|
|
day, month, year = extractor(match)
|
|
if year < 100:
|
|
year += 2000
|
|
return datetime(year, month, day)
|
|
except (ValueError, IndexError):
|
|
pass
|
|
return None
|
|
|
|
|
|
def findSessionByDate(
    sessions: List[Dict[str, Any]],
    targetDate: Optional[datetime],
) -> Optional[Dict[str, Any]]:
    """Return the completed session whose date is closest to ``targetDate``.

    Only sessions with ``status == "completed"`` are considered. A session's
    date comes from the first non-empty of ``startedAt``, ``endedAt``,
    ``createdAt``. Values may be ISO-8601 strings (a trailing ``Z`` is
    accepted) or ``datetime`` objects. Sessions whose timestamp is missing or
    unparseable are skipped. Comparison is by calendar date only (time of
    day is ignored). On a tie, the session appearing first in ``sessions``
    wins.

    NOTE(review): an offset-aware timestamp is compared by its own calendar
    date (UTC for ``Z``); no conversion to the user's local timezone happens —
    confirm that is intended.

    Args:
        sessions: Session dicts.
        targetDate: Day to look for (a ``date`` is accepted too); ``None``
            yields ``None``.

    Returns:
        The best-matching session dict, or ``None`` if there is none.
    """
    if not targetDate or not sessions:
        return None

    targetDay = targetDate.date() if isinstance(targetDate, datetime) else targetDate
    bestMatch: Optional[Dict[str, Any]] = None
    bestDiff: Optional[int] = None

    for session in sessions:
        if session.get("status") != "completed":
            continue
        raw = session.get("startedAt") or session.get("endedAt") or session.get("createdAt")
        if not raw:
            continue
        if isinstance(raw, datetime):
            sessionDay = raw.date()
        else:
            try:
                sessionDay = datetime.fromisoformat(str(raw).replace("Z", "+00:00")).date()
            except ValueError:
                continue  # unparseable timestamp: ignore this session
        diff = abs((sessionDay - targetDay).days)
        if bestDiff is None or diff < bestDiff:
            bestDiff, bestMatch = diff, session
            if diff == 0:
                break  # an exact-day hit cannot be beaten; first one wins

    return bestMatch
|
|
|
|
|
|
def searchSessionsByTopic(
    sessions: List[Dict[str, Any]],
    query: str,
    maxResults: int = TOPIC_SEARCH_MAX_RESULTS,
) -> List[Dict[str, Any]]:
    """Rank completed sessions by keyword overlap with ``query``.

    Phase 5: keyword-based search over ``keyTopics`` and ``summary``.
    Phase 7 is planned to add an embedding-search fallback; for now this is
    keyword only.

    Scoring (case-insensitive, substring matching):

    * Only query words of at least 3 characters count.
    * Each query word adds 1 if it occurs in the session ``summary``.
    * Each query word adds 2 for every ``keyTopics`` entry containing it.

    ``keyTopics`` may be a list or a JSON string encoding a list. Anything
    else, including malformed JSON, is treated as "no topics".

    Args:
        sessions: Session dicts; only ``status == "completed"`` ones are scored.
        query: Free-text search query.
        maxResults: Maximum number of sessions returned. ``<= 0`` returns
            ``[]``; ``None`` returns all matches.

    Returns:
        Sessions with a positive score, highest score first (ties keep the
        input order), at most ``maxResults`` of them.
    """
    import json

    if not query or not sessions:
        return []
    if maxResults is not None and maxResults <= 0:
        return []

    # Filter short words once instead of re-checking them for every session.
    queryWords = {w for w in re.findall(r"\w+", query.lower()) if len(w) >= 3}
    if not queryWords:
        return []

    scored: List[Tuple[int, Dict[str, Any]]] = []
    for session in sessions:
        if session.get("status") != "completed":
            continue
        summary = (session.get("summary") or "").lower()

        keyTopics: List[str] = []
        rawTopics = session.get("keyTopics")
        if rawTopics:
            try:
                parsed = json.loads(rawTopics) if isinstance(rawTopics, str) else rawTopics
            except ValueError:  # includes json.JSONDecodeError
                parsed = None  # malformed JSON: score the summary only
            if isinstance(parsed, list):
                keyTopics = [str(t).lower() for t in parsed]

        score = 0
        for word in queryWords:
            if word in summary:
                score += 1
            score += 2 * sum(1 for topic in keyTopics if word in topic)
        if score > 0:
            scored.append((score, session))

    # sort() is stable (also with reverse=True), so equal scores keep input order.
    scored.sort(key=lambda item: item[0], reverse=True)
    return [session for _, session in scored[:maxResults]]
|
|
|
|
|
|
def searchSessionsByTopicRag(
    query: str,
    userId: str,
    instanceId: str,
    mandateId: Optional[str] = None,
    queryVector: Optional[List[float]] = None,
) -> List[Dict[str, Any]]:
    """Search using platform RAG (semantic search across mandate-wide knowledge data).

    Requires a pre-computed ``queryVector`` (embedding). The caller is
    responsible for generating the embedding via ``AiService.callEmbedding``
    before invoking this. ``query`` itself is only used in log messages.

    The knowledge interface's ``semanticSearch`` is called with
    ``isSysAdmin=False``, scoped by ``userId``, ``featureInstanceId=instanceId``
    and ``mandateId``, with ``limit=TOPIC_SEARCH_MAX_RESULTS``.

    Returns:
        A list of ``{"source": "rag", "content", "fileName", "score"}`` dicts.
        Returns ``[]`` when no vector is given or when the search fails.
        Failures are logged as warnings, never raised: RAG is best-effort.
    """
    if not queryVector:
        logger.warning("searchSessionsByTopicRag called without queryVector, skipping RAG search")
        return []
    try:
        from modules.interfaces.interfaceDbKnowledge import getInterface as _getKnowledgeInterface

        knowledgeDb = _getKnowledgeInterface()

        results = knowledgeDb.semanticSearch(
            queryVector=queryVector,
            userId=userId,
            featureInstanceId=instanceId,
            mandateId=mandateId,
            isSysAdmin=False,
            limit=TOPIC_SEARCH_MAX_RESULTS,
        )

        formatted = []
        for r in (results or []):
            # Rows may be plain dicts or model objects exposing model_dump().
            rData = r if isinstance(r, dict) else r.model_dump() if hasattr(r, "model_dump") else {}
            # A non-dict contextRef must not abort the whole result list.
            contextRef = rData.get("contextRef")
            if not isinstance(contextRef, dict):
                contextRef = {}
            formatted.append({
                "source": "rag",
                "content": rData.get("data") or rData.get("summary") or "",
                "fileName": contextRef.get("containerPath") or "RAG-Ergebnis",
                "score": rData.get("_score") or 0,
            })
        return formatted
    except Exception as e:
        # (query or "") keeps the handler itself from raising on a None query.
        logger.warning("RAG search failed for query '%s': %s", (query or "")[:50], e)
        return []
|
|
|
|
|
|
def buildSessionSummariesForPrompt(
    sessions: List[Dict[str, Any]],
    excludeSessionId: Optional[str] = None,
    limit: int = PREVIOUS_SESSION_SUMMARIES_COUNT,
) -> List[Dict[str, Any]]:
    """Build the list of previous-session summaries (with date) for the prompt.

    Selects sessions that are ``completed`` and have a non-empty ``summary``,
    excluding the one whose ``id`` equals ``excludeSessionId``. With the
    default ``None``, nothing is excluded. Results are ordered newest first
    by ``startedAt`` (falling back to ``createdAt``) and truncated to
    ``limit``.

    Each item is ``{"summary", "date", "sessionId", "keyTopics"}``.

    * ``date`` is ``DD.MM.YYYY``, or ``""`` when the timestamp is missing or
      unparseable.
    * ``keyTopics`` is passed through unchanged (raw value from the session).

    NOTE(review): the sort compares raw timestamp values. Mixing ``str`` and
    ``datetime`` values across sessions would raise ``TypeError`` — confirm the
    storage layer returns a single type.

    Args:
        sessions: Session dicts; ``None``/empty yields ``[]``.
        excludeSessionId: Id of a session to leave out (e.g. the current one).
        limit: Maximum number of items. ``<= 0`` yields ``[]``; ``None``
            means no limit.
    """
    if not sessions:
        return []
    if limit is not None and limit <= 0:
        return []

    completed = [
        s for s in sessions
        if s.get("status") == "completed"
        and s.get("summary")
        # Only exclude when an id was actually given; otherwise sessions
        # without an "id" would be dropped because None == None.
        and (excludeSessionId is None or s.get("id") != excludeSessionId)
    ]
    completed.sort(key=lambda x: x.get("startedAt") or x.get("createdAt") or "", reverse=True)

    result = []
    for s in completed[:limit]:
        raw = s.get("startedAt") or s.get("createdAt") or ""
        dateStr = ""
        if isinstance(raw, datetime):
            dateStr = raw.strftime("%d.%m.%Y")
        elif raw:
            try:
                dateStr = datetime.fromisoformat(str(raw).replace("Z", "+00:00")).strftime("%d.%m.%Y")
            except ValueError:
                pass  # unparseable timestamp: keep the summary, omit the date
        result.append({
            "summary": s.get("summary", ""),
            "date": dateStr,
            "sessionId": s.get("id"),
            "keyTopics": s.get("keyTopics"),
        })
    return result
|