# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""
Adapter to use AI Center as a LangChain-compatible chat model.

Maps LangChain message format to AI Center requests and responses.
"""

import logging
from typing import Any, AsyncIterator, List, Optional

from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, ProcessingModeEnum

logger = logging.getLogger(__name__)


class AICenterChatModel(BaseChatModel):
    """
    Adapter to use AI Center as a LangChain chat model.

    Converts LangChain messages to AI Center format and back.
    """

    # BaseChatModel is a Pydantic model, so configuration must be declared
    # as fields; assigning undeclared attributes on self raises an error.
    services: Any
    system_prompt: str = ""
    temperature: float = 0.2

    def __init__(
        self,
        services,
        system_prompt: str = "",
        temperature: float = 0.2,
        **kwargs,
    ):
        """
        Initialize the AI Center chat model adapter.

        Args:
            services: Services instance with AI access
            system_prompt: System prompt prepended to every request
            temperature: Temperature for AI calls
        """
        super().__init__(
            services=services,
            system_prompt=system_prompt,
            temperature=temperature,
            **kwargs,
        )

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this LLM type."""
        return "ai_center"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Synchronous generation is not supported; use the async path
        (ainvoke/agenerate), which routes to _agenerate.
        """
        raise NotImplementedError("AICenterChatModel is async-only; use ainvoke/agenerate")

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Generate a chat response using AI Center.

        Args:
            messages: List of LangChain messages
            stop: Optional list of stop sequences (accepted for interface
                compatibility, currently not forwarded to AI Center)
            run_manager: Optional callback manager
            **kwargs: Additional arguments

        Returns:
            ChatResult with the generated message
        """
        # Convert LangChain messages to the AI Center prompt format
        prompt_parts = []

        # Add the system prompt if present
        if self.system_prompt:
            prompt_parts.append(self.system_prompt)

        # Flatten the messages into a text transcript
        for msg in messages:
            if isinstance(msg, SystemMessage):
                # System content normally lives in system_prompt; only add
                # it here if no system prompt was configured
                if not self.system_prompt:
                    prompt_parts.append(f"System: {msg.content}")
            elif isinstance(msg, HumanMessage):
                prompt_parts.append(f"User: {msg.content}")
            elif isinstance(msg, AIMessage):
                prompt_parts.append(f"Assistant: {msg.content}")
            else:
                # Generic message types: fall back to the raw content
                prompt_parts.append(str(msg.content))

        # Combine into a single prompt
        full_prompt = "\n\n".join(prompt_parts)
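
        # For illustration (example transcript, not from the source), a
        # two-turn exchange with a configured system prompt flattens to:
        #
        #   You are a helpful assistant.
        #
        #   User: What changed in the last run?
        #
        #   Assistant: Two files were updated.
        #
        #   User: Summarize the diff.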

        # Create the AI Center request
        ai_request = AiCallRequest(
            prompt=full_prompt,
            options=AiCallOptions(
                resultFormat="txt",
                operationType=OperationTypeEnum.DATA_ANALYSE,
                processingMode=ProcessingModeEnum.DETAILED,
                temperature=self.temperature,
            ),
        )

        # Call AI Center
        try:
            await self.services.ai.ensureAiObjectsInitialized()
            ai_response = await self.services.ai.callAi(ai_request)

            # Extract the content, tolerating plain-string responses
            content = ai_response.content if hasattr(ai_response, "content") else str(ai_response)

            # Wrap the response in LangChain's result types
            generation = ChatGeneration(message=AIMessage(content=content))
            return ChatResult(generations=[generation])

        except Exception as e:
            logger.error(f"Error calling AI Center: {e}", exc_info=True)
            # Surface the failure as a regular message instead of raising,
            # so downstream chain steps keep running
            error_message = AIMessage(content=f"Error: {str(e)}")
            generation = ChatGeneration(message=error_message)
            return ChatResult(generations=[generation])

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """
        Stream the chat response. AI Center does not expose token streaming,
        so the full response is yielded as a single chunk.

        Args:
            messages: List of LangChain messages
            stop: Optional list of stop sequences
            run_manager: Optional callback manager
            **kwargs: Additional arguments

        Yields:
            ChatGenerationChunk instances (currently exactly one)
        """
        # Implement the private _astream hook rather than overriding the
        # public astream, whose Runnable signature differs; BaseChatModel
        # then exposes streaming correctly.
        # TODO: Implement true incremental streaming if AI Center adds support
        result = await self._agenerate(messages, stop, run_manager, **kwargs)
        if result.generations:
            yield ChatGenerationChunk(
                message=AIMessageChunk(content=result.generations[0].message.content)
            )
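

# Usage sketch (illustrative only; `services` stands for the application's
# Services instance, wired up elsewhere, and the prompt text is made up):
#
#   model = AICenterChatModel(services=services, system_prompt="You are concise.")
#   reply = await model.ainvoke([HumanMessage(content="Summarize the data.")])
#   print(reply.content)
#
# Only the async entry points (ainvoke, agenerate, astream) are usable,
# since the synchronous _generate deliberately raises NotImplementedError.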