gateway/modules/features/chatbotV2/chatbotV2.py

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Chatbot V2 domain logic - simple chat LangGraph with context injection.
Uses chunk-based retrieval with retry: when AI says "nicht enthalten",
tries the next chunk batch until content is found or all chunks searched.
"""
import asyncio
import logging
import re
from typing import Annotated, Optional, TYPE_CHECKING, List, Dict, Any

from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel

from modules.features.chatbotV2.bridges import AICenterChatModel, ChatbotV2Checkpointer
from modules.features.chatbotV2.contextChunkRetrieval import (
    chunk_sections,
    chunk_text_blocks,
    get_ordered_chunks_for_question,
    get_chunk_batch,
    response_indicates_not_found,
    DEFAULT_CHUNK_SIZE,
    DEFAULT_CHUNK_OVERLAP,
)

if TYPE_CHECKING:
    from modules.features.chatbotV2.config import ChatbotV2Config

logger = logging.getLogger(__name__)


class ChatState(BaseModel):
    """State for Chatbot V2 chat session."""
    messages: Annotated[list[BaseMessage], add_messages]
    # Optional context for chunk retrieval (passed at invoke, not persisted)
    chatbotv2_context: Optional[Dict[str, Any]] = None
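    # Sketch of the chatbotv2_context dict that chat_node reads below. The key
    # names are taken from this module; the example values are hypothetical, and
    # every key is optional (missing keys fall back to the defaults shown):
    #   {
    #       "ctx_dict": {"sections": [...], "textBlocks": [...]},
    #       "user_question": "<current user question>",
    #       "base_prompt": "Answer based on the provided context.",
    #       "max_context_chars": DEFAULT_MAX_CONTEXT_CHARS,
    #       "chunk_size": DEFAULT_CHUNK_SIZE,
    #       "chunk_overlap": DEFAULT_CHUNK_OVERLAP,
    #   }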


# Default max context chars (~20k tokens) - fits GPT 25k limit with room for prompt + response
DEFAULT_MAX_CONTEXT_CHARS = 60_000


def _build_system_prompt_from_chunks(
    base_prompt: str,
    chunks: List[Dict[str, Any]],
) -> str:
    """Build system prompt from a list of chunks."""
    if not chunks:
        return base_prompt
    header = "\n\n--- DOCUMENT CONTEXT (use this to answer user questions) ---\n"
    parts = [base_prompt, header]
    current_file = None
    for chunk in chunks:
        fn = chunk.get("fileName", "document")
        if fn != current_file:
            parts.append(f"\n### {fn}\n")
            current_file = fn
        parts.append(chunk.get("text", ""))
    return "\n".join(parts)
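

# Illustrative call (hypothetical values; real chunk dicts come from
# contextChunkRetrieval and also carry e.g. "chunkIndex"):
#
#   _build_system_prompt_from_chunks(
#       "Answer based on the provided context.",
#       [{"fileName": "report.pdf", "text": "Kapitel 1 ..."},
#        {"fileName": "report.pdf", "text": "Kapitel 2 ..."}],
#   )
#
# returns the base prompt, the DOCUMENT CONTEXT header, a single "### report.pdf"
# heading, and both chunk texts, joined with newlines.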


def build_context_system_prompt(
    base_prompt: str,
    extracted_context: dict,
    user_question: str,
    max_context_chars: Optional[int] = None,
    chunk_size: Optional[int] = None,
    chunk_overlap: Optional[int] = None,
) -> str:
    """
    Build system prompt with first chunk batch (for single-call use).
    For the retry loop, use get_ordered_chunks_for_question + get_chunk_batch
    + _build_system_prompt_from_chunks.
    """
    chunks = _get_all_chunks(extracted_context, chunk_size, chunk_overlap)
    if not chunks:
        return base_prompt
    ordered = get_ordered_chunks_for_question(chunks, user_question or "")
    selected = get_chunk_batch(ordered, 0, max_context_chars or DEFAULT_MAX_CONTEXT_CHARS)
    return _build_system_prompt_from_chunks(base_prompt, selected)
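

# Single-call usage sketch (extracted_context is hypothetical here; it is any
# dict of the shape _get_all_chunks expects, i.e. {"sections": [...], "textBlocks": [...]}):
#
#   prompt = build_context_system_prompt(
#       base_prompt="Answer based on the provided context.",
#       extracted_context=extracted_context,
#       user_question="Wie hoch ist die Vertragsstrafe?",
#   )
#
# Unlike the retry loop in create_chat_graph, this orders chunks by the question
# (get_ordered_chunks_for_question) and injects only the first batch that fits
# max_context_chars.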


def _get_all_chunks(
    extracted_context: dict,
    chunk_size: Optional[int],
    chunk_overlap: Optional[int],
) -> List[Dict[str, Any]]:
    """Get all chunks from extracted context, preferring sections over text blocks."""
    sections = extracted_context.get("sections", [])
    text_blocks = extracted_context.get("textBlocks", [])
    if not sections and not text_blocks:
        return []
    cs = chunk_size if chunk_size and chunk_size > 0 else DEFAULT_CHUNK_SIZE
    co = chunk_overlap if chunk_overlap is not None and chunk_overlap >= 0 else DEFAULT_CHUNK_OVERLAP
    if sections:
        return chunk_sections(sections, chunk_size=cs, chunk_overlap=co)
    return chunk_text_blocks(text_blocks, chunk_size=cs, chunk_overlap=co)


def create_chat_graph(
    model: AICenterChatModel,
    memory: ChatbotV2Checkpointer,
) -> CompiledStateGraph:
    """
    Create the chat graph with a retry loop: when the AI answers that the content
    was not found, the next chunk batch is tried until an answer is found or all
    chunks are exhausted.
    Context params are passed via state.chatbotv2_context at invoke time
    (see the sketch on ChatState.chatbotv2_context).
    """
    async def chat_node(state: ChatState) -> dict:
        # State can be a dict (LangGraph) or the Pydantic model; read the fields
        # directly instead of model_dump() so messages stay BaseMessage objects
        if isinstance(state, dict):
            msgs = state.get("messages", [])
            ctx = state.get("chatbotv2_context") or {}
        else:
            msgs = getattr(state, "messages", []) or []
            ctx = getattr(state, "chatbotv2_context", None) or {}
        if not msgs:
            return {}
        ctx_dict = ctx.get("ctx_dict", {})
        user_question = ctx.get("user_question", "")
        base_prompt = ctx.get("base_prompt", "Answer based on the provided context.")
        max_chars = ctx.get("max_context_chars") or DEFAULT_MAX_CONTEXT_CHARS
        chunk_size = ctx.get("chunk_size") or DEFAULT_CHUNK_SIZE
        chunk_overlap = ctx.get("chunk_overlap")
        if max_chars <= 0:
            max_chars = DEFAULT_MAX_CONTEXT_CHARS
        if chunk_overlap is None or chunk_overlap < 0:
            chunk_overlap = DEFAULT_CHUNK_OVERLAP
        user_msgs = [m for m in msgs if not isinstance(m, SystemMessage)]
        if not user_msgs:
            return {}
        # Get chunks - use DOCUMENT ORDER for retry (batch 0 = start, batch 1 = next part, etc.)
        chunks = _get_all_chunks(ctx_dict, chunk_size, chunk_overlap)
        if not chunks:
            logger.warning("No chunks from ctx_dict - sections=%s, textBlocks=%s",
                           len(ctx_dict.get("sections", [])), len(ctx_dict.get("textBlocks", [])))
        # Always use document order (chunkIndex) for systematic search through the entire document
        ordered = sorted(chunks, key=lambda c: c.get("chunkIndex", 0))
        batch_index = 0
        last_response = None
        logger.info("Chunk retrieval: %d chunks total, max_chars=%d, will try batches until found or exhausted",
                    len(ordered), max_chars)
        while True:
            batch = get_chunk_batch(ordered, batch_index, max_chars) if ordered else []
            if not batch:
                # No more chunks - return the last response or a final message
                if last_response:
                    return {"messages": [last_response]}
                return {"messages": [AIMessage(
                    content="Ich habe das gesamte Dokument durchsucht, konnte aber keine "
                            "passende Information zu Ihrer Frage finden. Bitte formulieren Sie die Frage "
                            "ggf. anders oder prüfen Sie, ob das Dokument die gewünschten Angaben enthält."
                )]}
            system_prompt = _build_system_prompt_from_chunks(base_prompt, batch)
            window = [SystemMessage(content=system_prompt)] + user_msgs
            response = None
            for attempt in range(3):  # Max 3 attempts (initial + 2 retries on rate limit)
                try:
                    response = await model.ainvoke(window)
                    break
                except Exception as exc:
                    err_str = str(exc).lower()
                    if ("429" in err_str or "rate limit" in err_str) and attempt < 2:
                        wait_secs = 6
                        match = re.search(r"try again in ([\d.]+)s", err_str, re.IGNORECASE)
                        if match:
                            wait_secs = max(6, int(float(match.group(1))) + 1)
                        logger.warning("Rate limit hit on chunk batch %d, waiting %ds before retry (attempt %d/3)",
                                       batch_index, wait_secs, attempt + 1)
                        await asyncio.sleep(wait_secs)
                    else:
                        if "No suitable model found" in str(exc):
                            return {"messages": [AIMessage(
                                content="Es tut mir leid, derzeit steht kein passendes KI-Modell zur Verfügung. "
                                        "Bitte versuchen Sie es später erneut."
                            )]}
                        raise
            if response is None:
                return {"messages": [AIMessage(
                    content="Ein Fehler ist aufgetreten. Bitte versuchen Sie es später erneut."
                )]}
            content = response.content if hasattr(response, "content") else str(response)
            if response_indicates_not_found(content):
                logger.info("Chunk batch %d: AI said not found (%d chars), trying next batch",
                            batch_index, len(content))
                batch_index += 1
                last_response = response
                await asyncio.sleep(5)  # Pause before next batch to avoid rate limits
                continue
            logger.info("Chunk batch %d: Found answer (%d chars)", batch_index, len(content))
            return {"messages": [response]}
    workflow = StateGraph(ChatState)
    workflow.add_node("chat", chat_node)
    workflow.add_edge(START, "chat")
    workflow.add_edge("chat", END)
    return workflow.compile(checkpointer=memory)
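

# Usage sketch (a minimal example; how AICenterChatModel / ChatbotV2Checkpointer
# are constructed and the session/thread id are assumptions, not defined here):
#
#   graph = create_chat_graph(model=model, memory=memory)
#   result = await graph.ainvoke(
#       {
#           "messages": [HumanMessage(content="Wie hoch ist die Vertragsstrafe?")],
#           "chatbotv2_context": {
#               "ctx_dict": extracted_context,  # {"sections": [...], "textBlocks": [...]}
#               "user_question": "Wie hoch ist die Vertragsstrafe?",
#           },
#       },
#       config={"configurable": {"thread_id": session_id}},
#   )
#   answer = result["messages"][-1].content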