commit e6e35286fe
257 changed files with 44879 additions and 27909 deletions
3 .gitignore (vendored)
@@ -167,4 +167,5 @@ cython_debug/
# local data
gwserver/_database*
gwserver/results/*
*.log.*
test-chat
365 app.py
|
|
@@ -1,19 +1,75 @@
|
|||
import os
|
||||
import sys
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
os.environ["NUMEXPR_MAX_THREADS"] = "12"
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Depends, Body, status, Response
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.security import HTTPBearer
|
||||
from contextlib import asynccontextmanager
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from datetime import timedelta
|
||||
import pathlib
|
||||
from datetime import datetime
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from modules.shared.eventManagement import eventManager
|
||||
from modules.features import featuresLifecycle as featuresLifecycle
|
||||
|
||||
class DailyRotatingFileHandler(RotatingFileHandler):
|
||||
"""
|
||||
A rotating file handler that automatically switches to a new file when the date changes.
|
||||
The log file name includes the current date and switches at midnight.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, logDir, filenamePrefix, maxBytes=10485760, backupCount=5, **kwargs
|
||||
):
|
||||
self.logDir = logDir
|
||||
self.filenamePrefix = filenamePrefix
|
||||
self.currentDate = None
|
||||
self.currentFile = None
|
||||
|
||||
# Initialize with today's file
|
||||
self._updateFileIfNeeded()
|
||||
|
||||
# Call parent constructor with current file
|
||||
super().__init__(
|
||||
self.currentFile, maxBytes=maxBytes, backupCount=backupCount, **kwargs
|
||||
)
|
||||
|
||||
def _updateFileIfNeeded(self):
|
||||
"""Update the log file if the date has changed"""
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
|
||||
if self.currentDate != today:
|
||||
self.currentDate = today
|
||||
newFile = os.path.join(self.logDir, f"{self.filenamePrefix}_{today}.log")
|
||||
|
||||
if self.currentFile != newFile:
|
||||
self.currentFile = newFile
|
||||
return True
|
||||
return False
|
||||
|
||||
def emit(self, record):
|
||||
"""Emit a log record, switching files if date has changed"""
|
||||
# Check if we need to switch to a new file
|
||||
if self._updateFileIfNeeded():
|
||||
# Close current file and open new one
|
||||
if self.stream:
|
||||
self.stream.close()
|
||||
self.stream = None
|
||||
|
||||
# Update the baseFilename for the parent class
|
||||
self.baseFilename = self.currentFile
|
||||
# Reopen the stream
|
||||
if not self.delay:
|
||||
self.stream = self._open()
|
||||
|
||||
# Call parent emit method
|
||||
super().emit(record)
|
||||
|
||||
|
||||
def initLogging():
|
||||
"""Initialize logging with configuration from APP_CONFIG"""
|
||||
|
|
@@ -21,29 +77,45 @@ def initLogging():
|
|||
logLevelName = APP_CONFIG.get("APP_LOGGING_LOG_LEVEL", "WARNING")
|
||||
logLevel = getattr(logging, logLevelName)
|
||||
|
||||
# Get log directory from config
|
||||
logDir = APP_CONFIG.get("APP_LOGGING_LOG_DIR", "./")
|
||||
if not os.path.isabs(logDir):
|
||||
# If relative path, make it relative to the gateway directory
|
||||
gatewayDir = os.path.dirname(os.path.abspath(__file__))
|
||||
logDir = os.path.join(gatewayDir, logDir)
|
||||
|
||||
# Ensure log directory exists
|
||||
os.makedirs(logDir, exist_ok=True)
|
||||
|
||||
# Create formatters - using single line format
|
||||
consoleFormatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
|
||||
# File formatter with more detailed error information but still single line
|
||||
fileFormatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s - %(pathname)s:%(lineno)d - %(funcName)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S")
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
)
|
||||
|
||||
# Add filter to exclude Chrome DevTools requests
|
||||
class ChromeDevToolsFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
return not (isinstance(record.msg, str) and
|
||||
('.well-known/appspecific/com.chrome.devtools.json' in record.msg or
|
||||
'Request: /index.html' in record.msg))
|
||||
return not (
|
||||
isinstance(record.msg, str)
|
||||
and (
|
||||
".well-known/appspecific/com.chrome.devtools.json" in record.msg
|
||||
or "Request: /index.html" in record.msg
|
||||
)
|
||||
)
|
||||
|
||||
# Add filter to exclude all httpcore loggers (including sub-loggers)
|
||||
class HttpcoreStarFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
return not (record.name == 'httpcore' or record.name.startswith('httpcore.'))
|
||||
return not (
|
||||
record.name == "httpcore" or record.name.startswith("httpcore.")
|
||||
)
|
||||
|
||||
# Add filter to exclude HTTP debug messages
|
||||
class HTTPDebugFilter(logging.Filter):
|
||||
|
|
@@ -51,14 +123,14 @@ def initLogging():
|
|||
if isinstance(record.msg, str):
|
||||
# Filter out HTTP debug messages
|
||||
http_debug_patterns = [
|
||||
'receive_response_body.started',
|
||||
'receive_response_body.complete',
|
||||
'response_closed.started',
|
||||
'_send_single_request',
|
||||
'httpcore.http11',
|
||||
'httpx._client',
|
||||
'HTTP Request',
|
||||
'multipart.multipart'
|
||||
"receive_response_body.started",
|
||||
"receive_response_body.complete",
|
||||
"response_closed.started",
|
||||
"_send_single_request",
|
||||
"httpcore.http11",
|
||||
"httpx._client",
|
||||
"HTTP Request",
|
||||
"multipart.multipart",
|
||||
]
|
||||
return not any(pattern in record.msg for pattern in http_debug_patterns)
|
||||
return True
|
||||
|
|
@@ -70,8 +142,39 @@ def initLogging():
|
|||
# Remove only emojis, preserve other Unicode characters like quotes
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
# Remove emoji characters specifically
|
||||
record.msg = ''.join(char for char in record.msg if unicodedata.category(char) != 'So' or not (0x1F600 <= ord(char) <= 0x1F64F or 0x1F300 <= ord(char) <= 0x1F5FF or 0x1F680 <= ord(char) <= 0x1F6FF or 0x1F1E0 <= ord(char) <= 0x1F1FF or 0x2600 <= ord(char) <= 0x26FF or 0x2700 <= ord(char) <= 0x27BF))
|
||||
record.msg = "".join(
|
||||
char
|
||||
for char in record.msg
|
||||
if unicodedata.category(char) != "So"
|
||||
or not (
|
||||
0x1F600 <= ord(char) <= 0x1F64F
|
||||
or 0x1F300 <= ord(char) <= 0x1F5FF
|
||||
or 0x1F680 <= ord(char) <= 0x1F6FF
|
||||
or 0x1F1E0 <= ord(char) <= 0x1F1FF
|
||||
or 0x2600 <= ord(char) <= 0x26FF
|
||||
or 0x2700 <= ord(char) <= 0x27BF
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
# Add filter to normalize problematic unicode (e.g., arrows) to ASCII for terminals like cp1252
|
||||
class UnicodeArrowFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
if isinstance(record.msg, str):
|
||||
translation_map = {
|
||||
"\u2192": "->", # rightwards arrow
|
||||
"\u2190": "<-", # leftwards arrow
|
||||
"\u2194": "<->", # left right arrow
|
||||
"\u21D2": "=>", # rightwards double arrow
|
||||
"\u21D0": "<=", # leftwards double arrow
|
||||
"\u21D4": "<=>", # left right double arrow
|
||||
"\u00AB": "<<", # left-pointing double angle quotation mark
|
||||
"\u00BB": ">>", # right-pointing double angle quotation mark
|
||||
}
|
||||
for u, ascii_eq in translation_map.items():
|
||||
record.msg = record.msg.replace(u, ascii_eq)
|
||||
return True
|
||||
|
||||
# Configure handlers based on config
|
||||
|
|
@@ -85,35 +188,30 @@ def initLogging():
|
|||
consoleHandler.addFilter(HttpcoreStarFilter())
|
||||
consoleHandler.addFilter(HTTPDebugFilter())
|
||||
consoleHandler.addFilter(EmojiFilter())
|
||||
consoleHandler.addFilter(UnicodeArrowFilter())
|
||||
handlers.append(consoleHandler)
|
||||
|
||||
# Add file handler if enabled
|
||||
if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True):
|
||||
# Get log file path and ensure it's absolute
|
||||
logFile = APP_CONFIG.get("APP_LOGGING_LOG_FILE", "app.log")
|
||||
if not os.path.isabs(logFile):
|
||||
# If relative path, make it relative to the gateway directory
|
||||
gatewayDir = os.path.dirname(os.path.abspath(__file__))
|
||||
logFile = os.path.join(gatewayDir, logFile)
|
||||
|
||||
# Ensure log directory exists
|
||||
logDir = os.path.dirname(logFile)
|
||||
if logDir:
|
||||
os.makedirs(logDir, exist_ok=True)
|
||||
|
||||
rotationSize = int(APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)) # Default: 10MB
|
||||
# Create daily application log file with automatic date switching
|
||||
rotationSize = int(
|
||||
APP_CONFIG.get("APP_LOGGING_ROTATION_SIZE", 10485760)
|
||||
) # Default: 10MB
|
||||
backupCount = int(APP_CONFIG.get("APP_LOGGING_BACKUP_COUNT", 5))
|
||||
|
||||
fileHandler = RotatingFileHandler(
|
||||
logFile,
|
||||
maxBytes=rotationSize,
|
||||
backupCount=backupCount
|
||||
|
||||
fileHandler = DailyRotatingFileHandler(
|
||||
logDir=logDir,
|
||||
filenamePrefix="log_app",
|
||||
maxBytes=rotationSize,
|
||||
backupCount=backupCount,
|
||||
encoding="utf-8",
|
||||
)
|
||||
fileHandler.setFormatter(fileFormatter)
|
||||
fileHandler.addFilter(ChromeDevToolsFilter())
|
||||
fileHandler.addFilter(HttpcoreStarFilter())
|
||||
fileHandler.addFilter(HTTPDebugFilter())
|
||||
fileHandler.addFilter(EmojiFilter())
|
||||
fileHandler.addFilter(UnicodeArrowFilter())
|
||||
handlers.append(fileHandler)
|
||||
|
||||
# Configure the root logger
|
||||
|
|
@@ -122,101 +220,162 @@ def initLogging():
|
|||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt=APP_CONFIG.get("APP_LOGGING_DATE_FORMAT", "%Y-%m-%d %H:%M:%S"),
|
||||
handlers=handlers,
|
||||
force=True # Force reconfiguration of the root logger
|
||||
force=True, # Force reconfiguration of the root logger
|
||||
)
|
||||
|
||||
# Silence noisy third-party libraries - use the same level as the root logger
|
||||
noisyLoggers = ["httpx", "httpcore", "urllib3", "asyncio", "fastapi.security.oauth2", "msal"]
|
||||
noisyLoggers = [
|
||||
"httpx",
|
||||
"httpcore",
|
||||
"urllib3",
|
||||
"asyncio",
|
||||
"fastapi.security.oauth2",
|
||||
"msal",
|
||||
]
|
||||
for loggerName in noisyLoggers:
|
||||
logging.getLogger(loggerName).setLevel(logging.WARNING)
|
||||
|
||||
# Log the current logging configuration
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Logging initialized with level {logLevelName}")
|
||||
logger.info(f"Log file: {logFile if APP_CONFIG.get('APP_LOGGING_FILE_ENABLED', True) else 'disabled'}")
|
||||
logger.info(f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}")
|
||||
logger.info(f"Log directory: {logDir}")
|
||||
|
||||
if APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", True):
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
appLogFile = os.path.join(logDir, f"log_app_{today}.log")
|
||||
logger.info(f"Application log file: {appLogFile} (auto-switches daily)")
|
||||
else:
|
||||
logger.info("Application log file: disabled")
|
||||
|
||||
logger.info(
|
||||
f"Console logging: {'enabled' if APP_CONFIG.get('APP_LOGGING_CONSOLE_ENABLED', True) else 'disabled'}"
|
||||
)
|
||||
|
||||
|
||||
def makeSqlalchemyDbUrl() -> str:
|
||||
host = APP_CONFIG.get("SQLALCHEMY_DB_HOST", "localhost")
|
||||
port = APP_CONFIG.get("SQLALCHEMY_DB_PORT", "5432")
|
||||
db = APP_CONFIG.get("SQLALCHEMY_DB_DATABASE", "project_gateway")
|
||||
user = APP_CONFIG.get("SQLALCHEMY_DB_USER", "postgres")
|
||||
pwd = quote_plus(APP_CONFIG.get("SQLALCHEMY_DB_PASSWORD_SECRET", ""))
|
||||
# On Windows, prefer asyncpg to avoid psycopg + ProactorEventLoop incompatibility
|
||||
if sys.platform == "win32":
|
||||
return f"postgresql+asyncpg://{user}:{pwd}@{host}:{port}/{db}"
|
||||
return f"postgresql+psycopg://{user}:{pwd}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
# Initialize logging
|
||||
initLogging()
|
||||
logger = logging.getLogger(__name__)
|
||||
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
||||
|
||||
|
||||
# Define lifespan context manager for application startup/shutdown events
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup logic
|
||||
logger.info("Application is starting up")
|
||||
|
||||
# Initialize root interface to ensure database is properly set up
|
||||
from modules.interfaces.interfaceAppObjects import getRootInterface
|
||||
getRootInterface()
|
||||
|
||||
# Setup APScheduler for JIRA sync
|
||||
scheduler = AsyncIOScheduler(timezone=ZoneInfo("Europe/Zurich"))
|
||||
try:
|
||||
from modules.services.serviceDeltaSync import perform_sync_jira_delta_group
|
||||
# Schedule sync every 20 minutes (at minutes 00, 20, 40)
|
||||
scheduler.add_job(
|
||||
perform_sync_jira_delta_group,
|
||||
CronTrigger(minute="0,20,40"),
|
||||
id="jira_delta_group_sync",
|
||||
replace_existing=True,
|
||||
coalesce=True,
|
||||
max_instances=1,
|
||||
misfire_grace_time=1800,
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("APScheduler started (jira_delta_group_sync every 20 minutes at 00, 20, 40)")
|
||||
|
||||
# Run initial sync on startup (non-blocking failure)
|
||||
try:
|
||||
logger.info("Running initial JIRA sync on app startup...")
|
||||
await perform_sync_jira_delta_group()
|
||||
logger.info("Initial JIRA sync completed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Initial JIRA sync failed: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize scheduler or JIRA sync: {str(e)}")
|
||||
|
||||
|
||||
# --- Init Managers ---
|
||||
await featuresLifecycle.start()
|
||||
eventManager.start()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown logic
|
||||
|
||||
# --- Stop Managers ---
|
||||
eventManager.stop()
|
||||
await featuresLifecycle.stop()
|
||||
logger.info("Application has been shut down")
|
||||
try:
|
||||
if 'scheduler' in locals() and scheduler.running:
|
||||
scheduler.shutdown(wait=False)
|
||||
logger.info("APScheduler stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"Error shutting down scheduler: {str(e)}")
|
||||
|
||||
|
||||
# START APP
|
||||
app = FastAPI(
|
||||
title="PowerOn | Data Platform API",
|
||||
description=f"Backend API for the Multi-Agent Platform by ValueOn AG ({instanceLabel})",
|
||||
lifespan=lifespan
|
||||
lifespan=lifespan,
|
||||
swagger_ui_init_oauth={
|
||||
"usePkceWithAuthorizationCodeGrant": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Configure OpenAPI security scheme for Swagger UI
|
||||
# This adds the "Authorize" button to the /docs page
|
||||
securityScheme = HTTPBearer()
|
||||
app.openapi_schema = None # Reset schema to regenerate with security
|
||||
|
||||
|
||||
def customOpenapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
|
||||
openapiSchema = get_openapi(
|
||||
title=app.title,
|
||||
version="1.0.0",
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
# Add security scheme definition
|
||||
openapiSchema["components"]["securitySchemes"] = {
|
||||
"BearerAuth": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "JWT",
|
||||
"description": "Enter your JWT token (obtained from login endpoint or browser cookies)",
|
||||
}
|
||||
}
|
||||
|
||||
# Apply security globally to all endpoints
|
||||
# Individual endpoints can override this if needed
|
||||
openapiSchema["security"] = [{"BearerAuth": []}]
|
||||
|
||||
app.openapi_schema = openapiSchema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = customOpenapi
|
||||
|
||||
|
||||
# Parse CORS origins from environment variable
|
||||
def get_allowed_origins():
|
||||
origins_str = APP_CONFIG.get("APP_ALLOWED_ORIGINS", "http://localhost:8080")
|
||||
def getAllowedOrigins():
|
||||
originsStr = APP_CONFIG.get("APP_ALLOWED_ORIGINS", "http://localhost:8080")
|
||||
# Split by comma and strip whitespace
|
||||
origins = [origin.strip() for origin in origins_str.split(",")]
|
||||
origins = [origin.strip() for origin in originsStr.split(",")]
|
||||
logger.info(f"CORS allowed origins: {origins}")
|
||||
return origins
|
||||
|
||||
|
||||
# CORS configuration using environment variables
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins= get_allowed_origins(),
|
||||
allow_origins=getAllowedOrigins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"],
|
||||
max_age=86400 # Increased caching for preflight requests
|
||||
max_age=86400, # Increased caching for preflight requests
|
||||
)
|
||||
|
||||
# CSRF protection middleware
|
||||
from modules.security.csrf import CSRFMiddleware
|
||||
from modules.security.tokenRefreshMiddleware import (
|
||||
TokenRefreshMiddleware,
|
||||
ProactiveTokenRefreshMiddleware,
|
||||
)
|
||||
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# Token refresh middleware (silent refresh for expired OAuth tokens)
|
||||
app.add_middleware(TokenRefreshMiddleware, enabled=True)
|
||||
|
||||
# Proactive token refresh middleware (refresh tokens before they expire)
|
||||
app.add_middleware(
|
||||
ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5
|
||||
)
|
||||
|
||||
# Include all routers
|
||||
|
||||
from modules.routes.routeAdmin import router as generalRouter
|
||||
app.include_router(generalRouter)
|
||||
|
||||
|
|
@@ -232,6 +391,9 @@ app.include_router(userRouter)
|
|||
from modules.routes.routeDataFiles import router as fileRouter
|
||||
app.include_router(fileRouter)
|
||||
|
||||
from modules.routes.routeDataNeutralization import router as neutralizationRouter
|
||||
app.include_router(neutralizationRouter)
|
||||
|
||||
from modules.routes.routeDataPrompts import router as promptRouter
|
||||
app.include_router(promptRouter)
|
||||
|
||||
|
|
@@ -241,6 +403,9 @@ app.include_router(connectionsRouter)
|
|||
from modules.routes.routeWorkflows import router as workflowRouter
|
||||
app.include_router(workflowRouter)
|
||||
|
||||
from modules.routes.routeChatPlayground import router as chatPlaygroundRouter
|
||||
app.include_router(chatPlaygroundRouter)
|
||||
|
||||
from modules.routes.routeSecurityLocal import router as localRouter
|
||||
app.include_router(localRouter)
|
||||
|
||||
|
|
@@ -253,9 +418,15 @@ app.include_router(googleRouter)
|
|||
from modules.routes.routeVoiceGoogle import router as voiceGoogleRouter
|
||||
app.include_router(voiceGoogleRouter)
|
||||
|
||||
from modules.routes.routeVoiceStreaming import router as voiceStreamingRouter
|
||||
app.include_router(voiceStreamingRouter)
|
||||
|
||||
# Admin security routes (token listing and revocation, logs, db tools)
|
||||
from modules.routes.routeSecurityAdmin import router as adminSecurityRouter
|
||||
app.include_router(adminSecurityRouter)
|
||||
|
||||
from modules.routes.routeSharepoint import router as sharepointRouter
|
||||
app.include_router(sharepointRouter)
|
||||
|
||||
from modules.routes.routeDataAutomation import router as automationRouter
|
||||
app.include_router(automationRouter)
|
||||
|
||||
from modules.routes.routeAdminAutomationEvents import router as adminAutomationEventsRouter
|
||||
app.include_router(adminAutomationEventsRouter)
|
||||
|
||||
|
|
|
|||
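The app.py changes above introduce two small, self-contained pieces of logic: DailyRotatingFileHandler names its log file after the current date and switches files at midnight, and makeSqlalchemyDbUrl() picks the PostgreSQL driver by platform. The sketch below reproduces both for illustration only; it is not part of the commit, and the directory, prefix, and connection values are placeholder assumptions.

```python
import os
import sys
from datetime import datetime
from urllib.parse import quote_plus

# File name as computed in DailyRotatingFileHandler._updateFileIfNeeded():
logDir = "./logs"        # placeholder for APP_LOGGING_LOG_DIR
prefix = "log_app"       # filenamePrefix passed by initLogging()
today = datetime.now().strftime("%Y%m%d")
print(os.path.join(logDir, f"{prefix}_{today}.log"))
# e.g. ./logs/log_app_20250917.log -- the date part tracks the current day, so a new
# file is opened after midnight while maxBytes/backupCount rotation still applies.

# Driver selection as in makeSqlalchemyDbUrl(): asyncpg on Windows, psycopg elsewhere.
host, port, db = "localhost", "5432", "project_gateway"   # placeholder connection values
user, pwd = "postgres", "secret"                          # placeholder credentials
driver = "asyncpg" if sys.platform == "win32" else "psycopg"
print(f"postgresql+{driver}://{user}:{quote_plus(pwd)}@{host}:{port}/{db}")
```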
49 config.ini
|
|
@@ -5,21 +5,6 @@
|
|||
Auth_ALGORITHM = HS256
|
||||
Auth_TOKEN_TYPE = bearer
|
||||
|
||||
# OpenAI configuration
|
||||
Connector_AiOpenai_API_URL = https://api.openai.com/v1/chat/completions
|
||||
Connector_AiOpenai_API_SECRET = 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
|
||||
Connector_AiOpenai_MODEL_NAME = gpt-4o
|
||||
Connector_AiOpenai_TEMPERATURE = 0.2
|
||||
Connector_AiOpenai_MAX_TOKENS = 2000
|
||||
|
||||
# Anthropic configuration
|
||||
Connector_AiAnthropic_API_URL = https://api.anthropic.com/v1/messages
|
||||
Connector_AiAnthropic_API_SECRET_OLD = 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
|
||||
Connector_AiAnthropic_API_SECRET = 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
|
||||
Connector_AiAnthropic_MODEL_NAME = claude-3-5-sonnet-20241022
|
||||
Connector_AiAnthropic_TEMPERATURE = 0.2
|
||||
Connector_AiAnthropic_MAX_TOKENS = 2000
|
||||
|
||||
# File management configuration
|
||||
File_Management_MAX_UPLOAD_SIZE_MB = 50
|
||||
File_Management_CLEANUP_INTERVAL = 240
|
||||
|
|
@@ -36,33 +21,6 @@ Security_LOCK_DURATION_MINUTES = 30
|
|||
# Content Neutralization configuration
|
||||
Content_Neutralization_ENABLED = False
|
||||
|
||||
# Agent Mail configuration
|
||||
Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c
|
||||
Service_MSFT_CLIENT_SECRET = Kxf8Q~2lJIteZ~JaI32kMf1lfaWKATqxXiNiFbzV
|
||||
Service_MSFT_TENANT_ID = common
|
||||
|
||||
# Google Service configuration
|
||||
Service_GOOGLE_CLIENT_ID = 354925410565-aqs2b2qaiqmm73qpjnel6al8eid78uvg.apps.googleusercontent.com
|
||||
Service_GOOGLE_CLIENT_SECRET = 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
|
||||
|
||||
# Tavily Web Search configuration
|
||||
Connector_WebTavily_API_KEY = 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
|
||||
|
||||
# Google Cloud Speech Services configuration
|
||||
Connector_GoogleSpeech_API_KEY = {
|
||||
"type": "service_account",
|
||||
"project_id": "poweronid",
|
||||
"private_key_id": "88db66e4248326e9baeac4231bc196fd46a9a441",
|
||||
"private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTnJuxA+xBL3LA\nPgFILYCsGuppkkdO6d153Q36f2jTj6zpH3OhKMVsaaTBknG2o2+D0Whlk6Yh5rOw\nkWzpMC3y81leRLm5kucERMkBUgd2GL4v16k6m+QGuC3BFlt/XeyuckJNW0V6v/Dy\n3+bSYM7/5o1ftPNWJeAIEWoE/V4wKCYde8RE4Vp1LO5YwhgcM4rRuPmF2OhekpA+\npteYwkY/8/gTTRpZIc8OTsBYRbaMwsjoDj5riuL3boVtkwZwKRb+ZLvupXeU7Ds7\n1305odTcZUwnImHiHfuq83ZJViQiLRNhUAFnQIXPrYLwEpCmzRBGzYHaRlb69ga/\nzqUbKnclAgMBAAECggEAH6W9qHehubioPMAJM7Y6bC2KU/JLNS4csBZd+idb52gG\nwBwIEFjR+H4ZjymhAA4+pe7c4h7MKyh0RI/l7eoFX98Cb+rEq/r1udm1BhGH3s2h\n2UiI8qRQh1YRjF2/nrN5VjhDBOFa6W9opaopZy/l8AzsT8f21zIgPen8z8o6GpFg\n64fJFcbqCGk2ykN2+x2pIOT04tmCszrfbXZP8LEs4xrUB/XwlHL1vT/M3EWIKbnj\njDaIMjw7q/KRgNUvmKS6SU9b3fnOLcQCz9f5cKdiWACKIU/UvuiWhWJ9ou6BWLWU\nva1A6Fi4XJjhW7s3po58/ioQfl0A9p/L92lGg4ST8QKBgQDx8LIM1g0dh9Ql6LmH\nBUGCOewNNXTs+y3ZznUfvVMoyyZK5w/pzeUvkmOwzbRGnZJ9WyCghq8aezyEpo2D\nPL7Odf988IeHmvhyZIM4PLJYgDvSwGXyf/gh6gJkf/4wpx+tx/yQYNBm3Rht7sA0\npSaLehK0E0kW1uyBzHGKgyQOhwKBgQDf6LiZ7hSQqh54vIU1XMDRth0UOo/s/HGi\nDoij29KjmHjLkm8vOlCo83e79X0WhcnyB5kM7nWFegwcM1PJ0Dl8gidUuTlOVDtM\n5u2AaxDoyXAUL457U5dGFAIW+R653ZDkzMfCglacP8HixXEyIpL1cTLqiCAgzszS\nLcSWwoAr8wKBgQC4CGm3X97sFpTmHSd6sCHLaDnJNl9xoAKZifUHpqCqCBVhpm8x\nXp+11vmj1GULzfJPDlE8Khbp4tH+6R39tOhC7fjgVaoSGWxgv1odHfZfYXOf9R/X\nHUZmrbUSM1XsNkPfkZ7pR+teQ1HA1Xo40WMHd1zgw0a2a9fNR/EZ9nUn4wKBgGaK\nUEgGNRrPHadTRnnaoV8o1IZYD2OLdIqvtzm7SOqsv90SkaKCRUAqR5InaYKwAHy7\nqAa5Cc73xqX/h4arujff7x0ouiq5/nJIa0ndPmAtKAvGf6zQ6j0ompBkxAKAioON\nmInmYL2roSI2I5G/LagDkDrB3lzH+Brk5NvZ9RKrAoGAGox462GGGb/NbGdDkahN\ndifzYYvq4FPiWFFo0ynKAulxCBWLXO/N45XNuAyen433d8eREcAYz1Dzax44+MdQ\nHo9dU7YcZvFyt6iZsYeQF8dluHui3vzMpUe0KbqpZC5KMOSw53ZdNIwzo8NTAK59\n+uv3dHGj7sS8fhDo3yCifzc=\n-----END PRIVATE KEY-----\n",
|
||||
"client_email": "poweron-voice-services@poweronid.iam.gserviceaccount.com",
|
||||
"client_id": "116641749406798186404",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/poweron-voice-services%40poweronid.iam.gserviceaccount.com",
|
||||
"universe_domain": "googleapis.com"
|
||||
}
|
||||
|
||||
# Web Search configuration
|
||||
Web_Search_MAX_QUERY_LENGTH = 400
|
||||
Web_Search_MAX_RESULTS = 20
|
||||
|
|
@@ -71,4 +29,9 @@ Web_Search_MIN_RESULTS = 1
|
|||
# Web Crawl configuration
|
||||
Web_Crawl_TIMEOUT = 30
|
||||
Web_Crawl_MAX_RETRIES = 3
|
||||
Web_Crawl_RETRY_DELAY = 2
|
||||
|
||||
# Web Research configuration
|
||||
Web_Research_MAX_DEPTH = 2
|
||||
Web_Research_MAX_LINKS_PER_DOMAIN = 4
|
||||
Web_Research_CRAWL_TIMEOUT_MINUTES = 10
|
||||
Binary file not shown.
67 env_dev.env
|
|
@@ -4,51 +4,33 @@
|
|||
APP_ENV_TYPE = dev
|
||||
APP_ENV_LABEL = Development Instance Patrick
|
||||
APP_API_URL = http://localhost:8000
|
||||
|
||||
# Database Configuration for Application
|
||||
# JSON File Storage (current)
|
||||
# DB_APP_HOST=D:/Temp/_powerondb
|
||||
# DB_APP_DATABASE=app
|
||||
# DB_APP_USER=dev_user
|
||||
# DB_APP_PASSWORD_SECRET=dev_password
|
||||
APP_KEY_SYSVAR = D:/Athi/Local/Web/poweron/local/key.txt
|
||||
APP_INIT_PASS_ADMIN_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEeFFtRGtQeVUtcjlrU3dab1ZxUm9WSks0MlJVYUtERFlqUElHemZrOGNENk1tcmJNX3Vxc01UMDhlNU40VzZZRVBpUGNmT3podzZrOGhOeEJIUEt4eVlSWG5UYXA3d09DVXlLT21Kb1JYSUU9
|
||||
APP_INIT_PASS_EVENT_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpERzZjNm56WGVBdjJTeG5Udjd6OGQwUVotYXUzQjJ1YVNyVXVBa3NZVml3ODU0MVNkZjhWWmJwNUFkc19BcHlHMTU1Q3BRcHU0cDBoZkFlR2l6UEZQU3d2U3MtMDh5UDZteGFoQ0EyMUE1ckE9
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_APP_HOST=localhost
|
||||
DB_APP_DATABASE=poweron_app_dev
|
||||
DB_APP_DATABASE=poweron_app
|
||||
DB_APP_USER=poweron_dev
|
||||
DB_APP_PASSWORD_SECRET=dev_password
|
||||
DB_APP_PASSWORD_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEcUIxNEFfQ2xnS0RrSC1KNnUxTlVvTGZoMHgzaEI4Z3NlVzVROTVLak5Ubi1vaEZubFZaMTFKMGd6MXAxekN2d2NvMy1hRjg2UVhybktlcFA5anZ1WjFlQmZhcXdwaGhWdzRDc3ExeUhzWTg9
|
||||
DB_APP_PORT=5432
|
||||
|
||||
# Database Configuration Chat
|
||||
# JSON File Storage (current)
|
||||
# DB_CHAT_HOST=D:/Temp/_powerondb
|
||||
# DB_CHAT_DATABASE=chat
|
||||
# DB_CHAT_USER=dev_user
|
||||
# DB_CHAT_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_CHAT_HOST=localhost
|
||||
DB_CHAT_DATABASE=poweron_chat_dev
|
||||
DB_CHAT_DATABASE=poweron_chat
|
||||
DB_CHAT_USER=poweron_dev
|
||||
DB_CHAT_PASSWORD_SECRET=dev_password
|
||||
DB_CHAT_PASSWORD_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpERFNzNVhoalpCR0QxYXAwdEpXWXVVOTdZdWtqWW5FNXFGcFl2amNYLWYwYl9STXltRlFxLWNzVWlMVnNYdXk0RklnRExFT0FaQjg2aGswNnhhSGhCN29KN2VEb2FlUV9NTlV3b0tLelplSVU9
|
||||
DB_CHAT_PORT=5432
|
||||
|
||||
# Database Configuration Management
|
||||
# JSON File Storage (current)
|
||||
# DB_MANAGEMENT_HOST=D:/Temp/_powerondb
|
||||
# DB_MANAGEMENT_DATABASE=management
|
||||
# DB_MANAGEMENT_USER=dev_user
|
||||
# DB_MANAGEMENT_PASSWORD_SECRET=dev_password
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_MANAGEMENT_HOST=localhost
|
||||
DB_MANAGEMENT_DATABASE=poweron_management_dev
|
||||
DB_MANAGEMENT_DATABASE=poweron_management
|
||||
DB_MANAGEMENT_USER=poweron_dev
|
||||
DB_MANAGEMENT_PASSWORD_SECRET=dev_password
|
||||
DB_MANAGEMENT_PASSWORD_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEUldqSTVpUnFqdGhITDYzT3RScGlMYVdTMmZhOXdudDRCc3dhdllOd3l6MS1vWHY2MjVsTUF1Sk9saEJOSk9ONUlBZjQwb2c2T1gtWWJhcXFzVVVXd01xc0U0b0lJX0JyVDRxaDhNS01JcWs9
|
||||
DB_MANAGEMENT_PORT=5432
|
||||
|
||||
# Security Configuration
|
||||
APP_JWT_SECRET_SECRET=rotated_jwt_secret_2025_09_17_f8a3b6c2-7d4e-45b6-9a1f-3c0b9a1d2e7f
|
||||
APP_JWT_KEY_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpERjlrSktmZHVuQnJ1VVJDdndLaUcxZGJsT2ZlUFRlcFdOZ001RnlzM2FhLWhRV2tjWWFhaWQwQ3hkcUFvbThMcndxSjFpYTdfRV9OZGhTcksxbXFTZWg5MDZvOHpCVXBHcDJYaHlJM0tyNWRZckZsVHpQcmxTZHJoZUs1M3lfU2ljRnJaTmNSQ0w0X085OXI0QW80M2xfQnJqZmZ6VEh3TUltX0xzeE42SGtZPQ==
|
||||
APP_TOKEN_EXPIRY=300
|
||||
|
||||
# CORS Configuration
|
||||
|
|
@@ -56,7 +38,7 @@ APP_ALLOWED_ORIGINS=http://localhost:8080,https://playground.poweron-center.net
|
|||
|
||||
# Logging configuration
|
||||
APP_LOGGING_LOG_LEVEL = DEBUG
|
||||
APP_LOGGING_LOG_FILE = poweron.log
|
||||
APP_LOGGING_LOG_DIR = D:/Athi/Local/Web/poweron/local/logs
|
||||
APP_LOGGING_FORMAT = %(asctime)s - %(levelname)s - %(name)s - %(message)s
|
||||
APP_LOGGING_DATE_FORMAT = %Y-%m-%d %H:%M:%S
|
||||
APP_LOGGING_CONSOLE_ENABLED = True
|
||||
|
|
@@ -66,4 +48,29 @@ APP_LOGGING_BACKUP_COUNT = 5
|
|||
|
||||
# Service Redirects
|
||||
Service_MSFT_REDIRECT_URI = http://localhost:8000/api/msft/auth/callback
|
||||
Service_GOOGLE_REDIRECT_URI = http://localhost:8000/api/google/auth/callback
|
||||
|
||||
# AI configuration
|
||||
Connector_AiOpenai_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEajBuZmtYTVdqLTBpQm9KZ2pCXzRCV3VhZzlYTEhKb1FqWXNrV3lyb25uZUN1WVVQUEY3dGYtejludV9MNGlKeVREanZGOGloV09mY2ttQ3k5SjBFOGFac2ZQTkNKNUZWVnRINVQyeWhsR2wyYnVrRDNzV2NqSHB0ajQ4UWtGeGZtbmR0Q3VvS0hDZlphVmpSc2Z6RG5nPT0=
|
||||
Connector_AiAnthropic_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpENmFBWG16STFQUVZxNzZZRzRLYTA4X3lRanF1VkF4cU45OExNMzlsQmdISGFxTUxud1dXODBKcFhMVG9KNjdWVnlTTFFROVc3NDlsdlNHLUJXeG41NDBHaXhHR0VHVWl5UW9RNkVWbmlhakRKVW5pM0R4VHk0LUw0TV9LdkljNHdBLXJua21NQkl2b3l4UkVkMGN1YjBrMmJEeWtMay1jbmxrYWJNbUV0aktCXzU1djR2d2RSQXZORTNwcG92ZUVvVGMtQzQzTTVncEZTRGRtZUFIZWQ0dz09
|
||||
Connector_AiPerplexity_API_SECRET = DEV_ENC:Z0FBQUFBQm82Mzk2Q1MwZ0dNcUVBcUtuRDJIcTZkMXVvYnpjM3JEMzJiT1NKSHljX282ZDIyZTJYc09VSTdVNXAtOWU2UXp5S193NTk5dHJsWlFjRjhWektFOG1DVGY4ZUhHTXMzS0RPN1lNcF9nSlVWbW5BZ1hkZDVTejl6bVZNRFVvX29xamJidWRFMmtjQmkyRUQ2RUh6UTN1aWNPSUJBPT0=
|
||||
Connector_AiTavily_API_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEQTdnUHMwd2pIaXNtMmtCTFREd0pyQXRKb1F5eGtHSnkyOGZiUnlBOFc0b3Vzcndrc3ViRm1nMDJIOEZKYWxqdWNkZGh5N0Z4R0JlQmxXSG5pVnJUR2VYckZhMWNMZ1FNeXJ3enJLVlpiblhOZTNleUg3ZzZyUzRZanFSeDlVMkI=
|
||||
|
||||
# Agent Mail configuration
|
||||
Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c
|
||||
Service_MSFT_CLIENT_SECRET = DEV_ENC:Z0FBQUFBQm83T29rV1pQelMtc1p1MXR4NTFpa19CTEhHQ0xfNmdPUmZqcWp5UHBMS0hYTGl4c1pPdmhTNTJVWUl5WnlnUUZhV0VTRzVCb0d5YjR1NnZPZk5CZ0dGazNGdUJVbjkxeVdrYlNiVjJUYzF2aVFtQnVxTHFqTTJqZlF0RTFGNmE1OGN1TEk=
|
||||
Service_MSFT_TENANT_ID = common
|
||||
|
||||
# Google Service configuration
|
||||
Service_GOOGLE_CLIENT_ID = 354925410565-aqs2b2qaiqmm73qpjnel6al8eid78uvg.apps.googleusercontent.com
|
||||
Service_GOOGLE_CLIENT_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpETDJhbGVQMHlFQzNPVFI1ZzBMa3pNMGlQUHhaQm10eVl1bFlSeTBybzlTOWE2MURXQ0hkRlo0NlNGbHQxWEl1OVkxQnVKYlhhOXR1cUF4T3k0WDdscktkY1oyYllRTmdDTWpfbUdwWGtSd1JvNlYxeTBJdEtaaS1vYnItcW0yaFM=
|
||||
|
||||
# Google Cloud Speech Services configuration
|
||||
Connector_GoogleSpeech_API_KEY_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpETk5FWWM3Q0JKMzhIYTlyMkhuNjA4NlF4dk82U2NScHhTVGY3UG83NkhfX3RrcWVtWWcyLXRjU1dTT21zWEl6YWRMMUFndXpsUnJOeHh3QThsNDZKRXROTzdXRUdsT0JZajZJNVlfb0gtMXkwWm9DOERPVnpjU0pyUEZfOGJsUnprT3ltMVVhalUyUm9hMUFtZEtHUnJqOGZ4dEZjZm5SWVVTckVCWnY1UkdVSHVmUlgwbnAyc0xDQW84R3ViSko5OHVCVWZRUVNiaG1pVFB6X3EwS0FPd2dUYjhiSmRjcXh2WEZiXzI4SFZqT21tbDduUWRyVWdFZXpmcVM5ZDR0VWtzZnF5UER6cGwwS2JlLV9CSTZ0Z0IyQ1h0YW9TcmhRTXZEckp4bWhmTkt6UTNYMk4zVkpnbUJmaDIxZnoyR2dWTEYwTUFEV0w2eUdUUGpoZk9XRkt4RVF1Z1NPdUpBeTcyWV9PY1Ffd2s0ZEdVekxGekhoeEl4TmNqaXYtbUJuSVdycFducERWdWtZajZnX011Q2w4eE9VMTBqQ1ZxRmdScWhXY1E3WWhzX1JZcHhxam9FbDVPN3Q1MWtrMUZuTUg3LVFQVHp1T1hpQWNDMzEzekVJWk9ybl91YUVjSkFob1VaMi1ONEtuMnRSOEg1S3QybUMwbVZDejItajBLTjM2Zy1hNzZQMW5LLVVDVGdFWm5BZUxNeEFnUkZzU3dxV0lCUlc0LWo4b05GczVpOGZSV2ZxbFBwUml6OU5tYjdnTks3Y3hrVEZVTHlmc1NPdFh4WE5pWldEZklOQUxBbjBpMTlkX3FFQVJ6c2NSZGdzTThycE92VW82enZKamhiRGFnU25aZGlHZHhZd2lUUmhuTVptNjhoWVlJQkxIOEkzbzJNMjZCZFJyM25tdXBnQ2ZWaHV3b2p6UWJpdk9xUEhBc1dyTlNmeF9wbm5yYUhHV01UZnVXWDFlNzBkdXlWUWhvcmJpSmljbmE3LUpUZEg4VzRwZ2JVSjdYUm1sODViQXVxUzdGTmZFbVpiN2V1YW5XV3U4b2VRWmxldGVGVHZsSldoekhVLU9wZ2V0cGZIYkNqM2pXVGctQVAyUm4xTHhpd1VVLXFhcnVEV21Rby1hbTlqTl84TjVveHdYTExUVkhHQ0ltaTB2WXJnY1NQVE5PbWg3ejgySElYc1JSTlQ3NDlFUWR6STZVUjVqaXFRN200NF9LY1ljQ0R2UldlWUtKY1NQVnJ4QXRyYTBGSWVuenhyM0Z0cWtndTd1eG8xRzY5a2dNZ1hkQm5MV3BHVzA2N1QwUkd6WlRGYTZQOUhnVWQ2S0Y5U0s1dXFNVXh5Q2pLWVUxSUQ2MlR1ak52NmRIZ2hlYTk1SGZGWS1RV3hWVU9rR3d1Rk9MLS11REZXbzhqMHpsSm1HYW1jMUNLT29YOHZsRWNaLTVvOFpmT3l3MHVwaERTT0dNLWFjcGRYZ25qT2szTkVFUnRFR3JWYS1aNXFIRnMyalozTlQzNFF2NXJLVHVPVF9zdTF6ZjlkbzJ4RFc2ZENmNFFxZDZzTzhfMUl0bW96V0lPZkh1dXFYZlEteFBlSG84Si1FNS1TTi1OMkFnX2pOYW8xY3MxMVJnVC02MDUyaXZfMEVHWDQtVlRpcENmV0h3V0dCWEFRS2prQXdNRlQ5dnRFVHU0Q1dNTmh0SlBCaU55bFMydWM1TTFFLW96ODBnV3dNZHFZTWZhRURYSHlrdzF3RlRuWDBoQUhSOUJWemtRM3pxcDJFbGJoaTJ3ZktRTlJxbXltaHBoZXVJVDlxS3cxNWo2c0ZBV0NzaUstRWdsMW1xLXFkanZGYUFiU0tSLXFQa0tkcDFoMV9kak41ZjQ0R214UmtOR1ZBanRuemY3Mmw1SkZ5aDZodGIzT3N2aV85MW9kcld6c0g0ZDgtTWo3b3Y3VjJCRnR2U2tMVm9rUXNVRnVHbzZXVTZ6RmI2RkNmajBfMWVnODVFbnpkT0oyci15czJHU0p1cUowTGZJMzVnd3hIRjQyTVhKOGRkcFRKdVpyQ3Yzd01Jb1lSajFmV0paeEV0cjk1SmpmdWpDVFJMUmMtUFctOGhaTmlKQXNRVlVUNlhJemxudHZCR056SVlBb3NOTEYxRTRLaFlVd2d3TWtxVlB6ZEtQLTkxOGMyY3N0a2pYRFUweDBNaGhja2xSSklPOUZla1dKTWRNbG8tUGdSNEV5cW90OWlOZFlIUExBd3U2b2hyS1owbXVMM3p0Qm41cUtzWUxYNzB1N3JpUTNBSGdsT0NuamNTb1lIbXR4MG1sakNPVkxBUXRLVE1xX0YxWDhOcERIY1lTQVFqS01CaXZKNllFaXlIR0JsM1pKMmV1OUo3TGI1WkRaVnYxUTl1LTM0SU1qN1V1b0RCT0x0VHNLTmNLZnk1S0MxYnBBcm03WnVua0xqaEhGUzhOU253ZkppRzdudXBSVlMxeFVOSWxtZ1o2RVBSQUhEUEFuQ1hxSVZMME4yWUtaU3VyRGo3RkUyRUNjT0pNcE1BdE1ZRzdXVl8ydUtXZjdMdHdEVW4teHUtTi1HSGliLUxud21TX0NtcGVkRFBHNkZ1WTlNczR4OUJfUVluc1BoV09oWS1scUdsNnB5d1U5M1huX3k4QzAyNldtb2hybktYN2xKZ1NTNWFsaWwzV3pCRVhkaGR5eTNlV1d6ZzFfaFZTT0E4UjRpQ3pKdEZxUlJ6UFZXM3laUndyWEk2NlBXLUpoajVhZzVwQXpWVzUtVjVNZFBwdWdQa3AxZC1KdGdqNnhibjN4dmFYb2cxcEVwc1g5R09zRUdINUZtOE5QRjVUU0dpZy1QVl9odnFtVDNuWFZLSURtMXlSMlhRNTBWSVFJbEdOOWpfVWV0SmdRWDdlUXZZWE8xRUxDN1I0aEN6MHYwNzM1cmpJS0ZpMnBYWkxfb3FsbEV1VnlqWGxqdVJ6SHlwSjAzRlMycTBaQ295NXNnZERpUnJQcjhrUUd3bkI4bDVzRmxQblhkaFJPTTdISnVUQmhET3BOMTM4bjVvUEc2VmZhb2lrR1FyTUl2RWNEeGg0U0dsNnV6eU5zOUxiNDY5SXBxR0hBS00wOTgyWTFnWkQyaEtLVUloT3ZxZGh0RWVGRmJzenFsaUtfZENQM0JzdkVVeTdXR3hUSmJST1NBMUI1NkVFWncwNW5JZVVLX1p1RXdqVnFfQWpvQ08yQjZhN1NkTkpTSnUxOVRXZXE0WFEtZWxhZW1NNXYtQ2sya0VGLURmS01lMkctNVY3c2ZhN0ZGRFgwWHlabTFkeS1hcUZ1dDZ3cnpPQ3hha2IzVE11M0pqbklmU0diczBqTFBNZC1QZGp6VzNTSnJVSjJoWkJUQjVORG4tYUJmMEJtSUNUdVpEaGt6OTM3TjFOdVhXUHItZjRtZ25nU3NhZC1sVTVXNTRDTmxZbnlfeHNsdkpuMXhUYnE1MnpVQ0ZOclRWM1M4eHdXTzR
XbFRZZVQtTS1iRVdXVWZMSGotcWg3MUxUYTFnSEEtanBCRHlZRUNIdGdpUFhsYjdYUndCZnRITzhMZVJ1dHFoVlVNb0duVjlxd0U4OGRuQVV3MG90R0hiYW5MWkxWVklzbWFRNzBfSUNrdzc5bVdtTXg0dExEYnRCaDI3c1I4TWFwLXZKR0wxSjRZYjZIV3ZqZjNqTWhFT0RGSDVMc1A1UzY2bDBiMGFSUy1fNVRQRzRJWDVydUpqb1ZfSHNVbldVeUN2YlAxSW5WVDdxVzJ1WHpLeUdmb0xWMDNHN05oQzY3YnhvUUdhS2xaOHNidkVvbTZtSHFlblhOYmwyR3NQdVJDRUdxREhWdF9ZcXhwUWxHc2hyLW5vUGhIUVhJNUNhY0hFU0ptVnI0TFVhZDE1TFBBUEstSkRoZWJ5MHJhUmZrR1ZrRlFtRGpxS1pOMmFMQjBsdjluY3FiYUU4eGJVVXlZVEpuNWdHVVhJMGtwaTdZR2NDbXd2eHpOQ09SeTV6N1BaVUpsR1pQVDBZcElJUUt6VnVpQmxSYnE4Y1BCWV9IRWdVV0p3enBGVHItdnBGN3NyNWFBWmkySnByWThsbDliSlExQmp3LVlBaDIyZXp6UnR6cU9rTzJmTDBlSVpON0tiWllMdm1oME1zTFl2S2ZYYllhQlY2VHNZRGtHUDY4U1lIVExLZTU4VzZxSTZrZHl1ZTBDc0g4SjI4WGYyZHV1bm9wQ3R2Z09ld1ZmUkN5alJGeHZKSHl1bWhQVXpNMzdjblpLcUhfSm02Qlh5S1FVN3lIcHl0NnlRPT0=
|
||||
|
||||
# Feature SyncDelta JIRA configuration
|
||||
Feature_SyncDelta_JIRA_DELTA_TOKEN_SECRET = DEV_ENC:Z0FBQUFBQm8xSUpEbm0yRUJ6VUJKbUwyRW5kMnRaNW4wM2YxMkJUTXVXZUdmdVRCaUZIVHU2TTV2RWZLRmUtZkcwZE4yRUNlNDQ0aUJWYjNfdVg5YjV5c2JwMHhoUUYxZWdkeS11bXR0eGxRLWRVaVU3cUVQZWJlNDRtY1lWUDdqeDVFSlpXS0VFX21WajlRS3lHQjc0bS11akkybWV3QUFlR2hNWUNYLUdiRjZuN2dQODdDSExXWG1Dd2ZGclI2aUhlSWhETVZuY3hYdnhkb2c2LU1JTFBvWFpTNmZtMkNVOTZTejJwbDI2eGE0OS1xUlIwQnlCSmFxRFNCeVJNVzlOMDhTR1VUamx4RDRyV3p6Tk9qVHBrWWdySUM3TVRaYjd3N0JHMFhpdzFhZTNDLTFkRVQ2RVE4U19COXRhRWtNc0NVOHRqUS1CRDFpZ19xQmtFLU9YSDU3TXBZQXpVcld3PT0=
|
||||
|
||||
# Debug Configuration
|
||||
APP_DEBUG_CHAT_WORKFLOW_ENABLED = True
|
||||
APP_DEBUG_CHAT_WORKFLOW_DIR = ./test-chat
|
||||
38 env_int.env
|
|
@@ -4,30 +4,33 @@
|
|||
APP_ENV_TYPE = int
|
||||
APP_ENV_LABEL = Integration Instance
|
||||
APP_API_URL = https://gateway-int.poweron-center.net
|
||||
APP_KEY_SYSVAR = CONFIG_KEY
|
||||
APP_INIT_PASS_ADMIN_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjWm41MWZ4TUZGaVlrX3pWZWNwakJsY3Facm0wLVZDd1VKeTFoZEVZQnItcEdUUnVJS1NXeDBpM2xKbGRsYmxOSmRhc29PZjJSU2txQjdLbUVrTTE1NEJjUXBHbV9NOVJWZUR3QlJkQnJvTEU9
|
||||
APP_INIT_PASS_EVENT_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjdmtrakgxa0djekZVNGtTZV8wM2I5UUpCZllveVBMWXROYk5yS3BiV3JEelJSM09VYTRONHpnY3VtMGxDRk5JTEZSRFhtcDZ0RVRmZ1RicTFhb3c5dVZRQ1o4SmlkLVpPTW5MMTU2eTQ0Vkk9
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_APP_HOST=gateway-int-server.postgres.database.azure.com
|
||||
DB_APP_DATABASE=poweron_app
|
||||
DB_APP_USER=heeshkdlby
|
||||
DB_APP_PASSWORD_SECRET=VkAjgECESbEVQ$Tu
|
||||
DB_APP_PASSWORD_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjb2dka2pnN0tUbW1EU0w1Rk1jNERKQ0Z1U3JkVDhuZWZDM0g5M0kwVDE5VHdubkZna3gtZVAxTnl4MDdrR1c1ZXJ3ejJHYkZvcGUwbHJaajBGOWJob0EzRXVHc0JnZkJyNGhHZTZHOXBxd2c9
|
||||
DB_APP_PORT=5432
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_CHAT_HOST=gateway-int-server.postgres.database.azure.com
|
||||
DB_CHAT_DATABASE=poweron_chat
|
||||
DB_CHAT_USER=heeshkdlby
|
||||
DB_CHAT_PASSWORD_SECRET=VkAjgECESbEVQ$Tu
|
||||
DB_CHAT_PASSWORD_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjczYzOUtTa21MMGJVTUQ5UmFfdWc3YlhCbWZOeXFaNEE1QzdJV3BLVjhnalBkLVVCMm5BZzdxdlFXQXc2RHYzLWtPSFZkZE1iWG9rQ1NkVWlpRnF5TURVbnl1cm9iYXlSMGYxd1BGYVc0VDA9
|
||||
DB_CHAT_PORT=5432
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_MANAGEMENT_HOST=gateway-int-server.postgres.database.azure.com
|
||||
DB_MANAGEMENT_DATABASE=poweron_management
|
||||
DB_MANAGEMENT_USER=heeshkdlby
|
||||
DB_MANAGEMENT_PASSWORD_SECRET=VkAjgECESbEVQ$Tu
|
||||
DB_MANAGEMENT_PASSWORD_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjTnJKNlJMNmEwQ0Y5dVNrR3pkZk9SQXVvLTRTNW9lQ1g3TTE5cFhBNTd5UENqWW9qdWd3NWNseWhnUHJveDJyd1Z3X1czS3VuZnAwZHBXYVNQWlZsRy12ME42NndEVlR5X3ZPdFBNNmhLYm89
|
||||
DB_MANAGEMENT_PORT=5432
|
||||
|
||||
# Security Configuration
|
||||
APP_JWT_SECRET_SECRET=rotated_jwt_secret_2025_09_17_2c5f8e7a-1b3d-49c7-ae5d-9f0a2c3d4b5e
|
||||
APP_JWT_KEY_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjNUctb2RwU25iR3ZnanBOdHZhWUtIajZ1RnZzTEp4aDR0MktWRjNoeVBrY1Npd1R0VE9YVHp3M2w1cXRzbUxNaU82QUJvaDNFeVQyN05KblRWblBvbWtoT0VXbkNBbDQ5OHhwSUFnaDZGRG10Vmgtdm1YUkRsYUhFMzRVZURmSFlDTFIzVWg4MXNueDZyMGc5aVpFdWRxY3dkTExGM093ZTVUZVl5LUhGWnlRPQ==
|
||||
APP_TOKEN_EXPIRY=300
|
||||
|
||||
# CORS Configuration
|
||||
|
|
@@ -35,7 +38,7 @@ APP_ALLOWED_ORIGINS=http://localhost:8080,https://playground.poweron-center.net,
|
|||
|
||||
# Logging configuration
|
||||
APP_LOGGING_LOG_LEVEL = DEBUG
|
||||
APP_LOGGING_LOG_FILE = /home/site/wwwroot/poweron.log
|
||||
APP_LOGGING_LOG_DIR = /home/site/wwwroot/
|
||||
APP_LOGGING_FORMAT = %(asctime)s - %(levelname)s - %(name)s - %(message)s
|
||||
APP_LOGGING_DATE_FORMAT = %Y-%m-%d %H:%M:%S
|
||||
APP_LOGGING_CONSOLE_ENABLED = True
|
||||
|
|
@@ -46,3 +49,28 @@ APP_LOGGING_BACKUP_COUNT = 5
|
|||
# Service Redirects
|
||||
Service_MSFT_REDIRECT_URI = https://gateway-int.poweron-center.net/api/msft/auth/callback
|
||||
Service_GOOGLE_REDIRECT_URI = https://gateway-int.poweron-center.net/api/google/auth/callback
|
||||
|
||||
# AI configuration
|
||||
Connector_AiOpenai_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjSDBNYkptSkQxTUotYVVpZVNZc0dxNGNwSEtkOEE0T3RZWjROTEhSRlRXdlZmQUxxZ0w3Y0xOV2JNV19LNF9yTUZiU1pUNG15U2VDUDdSVlI4VlpnR3JXVFFtcXBaTEZiaUtSclVFd0lCZG1rWVhra1dfWTVQOTBEYUU0MjByYVNEMTFmeXNOcmpUT216MmJKdlVPeW5nPT0=
|
||||
Connector_AiAnthropic_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjT1ZlRWVJdVZMT3ljSFJDcFdxRFBRVkZhS204NnN5RDBlQ0tpenhTM0FFVktuWW9mWHNwRWx2dHB0eDBSZ0JFQnZKWlp6c01pVGREWHd1eGpERnU0Q2xhaks1clQ1ZXVsdnd2ZzhpNXNQS1BhY3FjSkdkVEhHalNaRGR4emhpakZncnpDQUVxOHVXQzVUWmtQc0FsYmFwTF9TSG5FOUFtWk5Ick1NcHFvY2s1T1c2WXlRUFFJZnh6TWhuaVpMYmppcDR0QUx0a0R6RXlwbGRYb1R4dzJkUT09
|
||||
Connector_AiPerplexity_API_SECRET = INT_ENC:Z0FBQUFBQm82Mzk2UWZJdUFhSW8yc3RKc0tKRXphd0xWMkZOVlFpSGZ4SGhFWnk0cTF5VjlKQVZjdS1QSWdkS0pUSWw4OFU5MjUxdTVQel9aeWVIZTZ5TXRuVmFkZG0zWEdTOGdHMHpsTzI0TGlWYURKU1Q0VVpKTlhxUk5FTmN6SUJScDZ3ZldIaUJZcWpaQVRiSEpyQm9tRTNDWk9KTnZBPT0=
|
||||
Connector_AiTavily_API_SECRET = INT_ENC:Z0FBQUFBQm8xSVRkdkJMTDY0akhXNzZDWHVYSEt1cDZoOWEzSktneHZEV2JndTNmWlNSMV9KbFNIZmQzeVlrNE5qUEIwcUlBSGM1a0hOZ3J6djIyOVhnZzI3M1dIUkdicl9FVXF3RGktMmlEYmhnaHJfWTdGUkktSXVUSGdQMC1vSEV6VE8zR2F1SVk=
|
||||
|
||||
# Agent Mail configuration
|
||||
Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c
|
||||
Service_MSFT_CLIENT_SECRET = INT_ENC:Z0FBQUFBQm83T29rMDZvcV9qTG5xb1FzUkdqS1llbzRxSEJXbmpONFFtcUtfZXdtZjQybmJSMjBjMEpnRVhiOGRuczZvVFBFdVVTQV80SG9PSnRQTEpLdVViNm5wc2E5aGRLWjZ4TGF1QjVkNmdRSzBpNWNkYXVublFYclVEdEM5TVBBZWVVMW5RVWk=
|
||||
Service_MSFT_TENANT_ID = common
|
||||
|
||||
# Google Service configuration
|
||||
Service_GOOGLE_CLIENT_ID = 354925410565-aqs2b2qaiqmm73qpjnel6al8eid78uvg.apps.googleusercontent.com
|
||||
Service_GOOGLE_CLIENT_SECRET = INT_ENC:Z0FBQUFBQm8xSVRjNThGeVRNd3hacThtRnE0bzlDa0JPUWQyaEd6QjlFckdsMGZjRlRfUks2bXV3aDdVRTF3LVRlZVY5WjVzSXV4ZGNnX002RDl3dkNYdGFzZkxVUW01My1wTHRCanVCLUozZEx4TlduQlB5MnpvNTR2SGlvbFl1YkhzTEtsSi1SOEo=
|
||||
|
||||
# Google Cloud Speech Services configuration
|
||||
Connector_GoogleSpeech_API_KEY_SECRET = INT_ENC:Z0FBQUFBQm8xSVRkNmVXZ1pWcHcydTF2MXF0ZGJoWHBydF85bTczTktiaEJ3Wk1vMW1mZVhDSG1yd0ZxR2ZuSGJTX0N3MWptWXFJTkNTWjh1SUVVTXI4UDVzcGdLMkU5SHJ2TUpkRlRoRWdnSldtYjNTQkh4UDJHY2xmdTdZQ1ZiMTZZcGZxS3RzaHdjV3dtVkZUcEpJcWx0b2xuQVR6ZmpoVFZPY1hNMTV2SnhDaC1IZEh4UUpLTy1ILXA4RG1zamJTbUJ4X0t2M2NkdzJPbEJxSmFpRzV3WC0wZThoVzlxcmpHZ3ZkLVlVY3REZk1vV19WQ05BOWN6cnJ4MWNYYnNiQ0FQSUVnUlpfM3BhMnlsVlZUOG5wM3pzM1lSN1UzWlZKUXRLczlHbjI1LTFvSUJ4SlVXMy1BNk43bE5Hb0RfTTVlWk9oZnFIaVg0SW5pbm9EcXRTTzU1RFlYY3dTcnpKWWNyNjN5T1BGZ0FmX253cEFncmhvZVRuM05KYzhkOEhFMFJsc2NBSEwzZVZ1R0JMOGxsekVwUE55alZaRXFrdzNWWVNGWXNmbnhKeWhQSFo2VXBTUlRPeHdvdVdncEFuOWgydEtsSUFneUN6cGVaTnBSdjNCdVJseGJFdmlMc203UFhLVlYyTENkaGg2dVN6Z2xwT1ZmTmN5bVZGUkM3ZWcyVkt2ckFUVVd3WFFwYnJjNVRobEh2SkVJbXRwUUpEOFJKQ1NUc0Q4NHNqUFhPSDh5cTV6MEcwSDEwRUJCQ2JiTTJlOE5nd3pMMkJaQ1dVYjMwZVVWWnlETmp2dkZ3aXEtQ29WNkxZTFkzYUkxdTlQUU1OTnhWWU12YU9MVnJQa1d2ZjRtUlhneTNubEMxTmp1eUNPOThSMlB3Y1F0T2tCdFNsNFlKalZPV25yR2QycVBUb096RmZ1V0FTaGsxLV9FWDBmenBIOXpMdGpLcUc0TWRoY2hlMFhYTzlET1ZRekw0ZHNwUVBQdVJBX2h6Q2ZzWVZJWTNybTJiekp3WmhmWF9SUFBXQzlqUjctcVlHWWVMZWVQallzR0JGTVF0WmtnWlg1aTM1bFprNVExZXY5dnNvWF93UjhwbkJ3RzNXaVJ2d2RRU3JJVlBvaVh4eTlBRUtqWkJia3dJQVVBV2Nqdm9FUTRUVW1TaHp2ZUwxT0N2ZndxQ2Nka1RYWXF0LWxIWFE0dTFQcVhncFFPM0hFdUUtYlFnemx3WkF4bjA1aDFULUdrZlVZbEJtRGRCdjJyVkdJSXozd0I0dF9zbWhOeHFqRDA4T1NVaWR5cjBwSVgwbllPU294NjZGTnM1bFhIdGpNQUxFOENWd3FCbGpSRFRmRXotQnU0N2lCVEU5RGF6Qi10S2U2NGdadDlrRjZtVE5oZkw5ZWFjXzhCTmxXQzNFTFgxRXVYY3J3YkxnbnlBSm9PY3h4MlM1NVFQbVNDRW5Ld1dvNWMxSmdoTXJuaE1pT2VFeXYwWXBHZ29MZDVlN2lwUUNIeGNCVVdQVi1rRXdJMWFncUlPTXR0MmZVQ1l0d09mZTdzWGFBWUJMUFd3b0RSOU8zeER2UWpNdzAxS0ZJWnB5S3FJdU9wUDJnTTNwMWw3VFVqVXQ3ZGZnU1RkUktkc0NhUHJ0SGFxZ0lVWDEzYjNtU2JfMGNWM1Y0dHlCTzNESEdENC1jUWF5MVppRzR1QlBNSUJySjFfRi1ENHEwcmJ4S3hQUFpXVHA0TG9DZWdoUlo5WnNSM1lCZm1KbEs2ak1yUUU4Wk9JcVJGUkJwc0NvUkMyTjhoTWxtZmVQeDREZVRKZkhYN2duLVNTeGZzdFdBVnhEandJSXB5QjM0azF0ckI3Tk1wSzFhNGVOUVRrNjU0cG9JQ29pN09xOFkwR1lMTlktaGp4TktxdTVtTnNEcldsV2pEZm5nQWpJc2hxY0hjQnVSWUR5VVdaUXBHWUloTzFZUC1oNzJ4UjZ1dnpLcDJxWEZtQlNIMWkzZ0hXWXdKeC1iLXdZWVJhcU04VFlpMU5pd2ZIdTdCdkVWVFVBdmJuRk16bEFFQTh4alBrcTV2RzliT2hGdTVPOXlRMjFuZktiRTZIamQ1VFVqS0hRTXhxcU1mdkgyQ1NjQmZfcjl4c3NJd0RIeDVMZUFBbHJqdEJxWWl3aWdGUEQxR3ZnMkNGdVB4RUxkZi1xOVlFQXh1NjRfbkFEaEJ5TVZlUGFrWVhSTVRPeGxqNlJDTHNsRWRrei1pYjhnUmZrb3BvWkQ2QXBzYjFHNXZoWU1LSExhLWtlYlJTZlJmYUM5Y1Rhb1pkMVYyWTByM3NTS0VXMG1ybm1BTVN2QXRYaXZqX2dKSkZrajZSS2cyVlNOQnd5Y29zMlVyaWlNbTJEb3FuUFFtbWNTNVpZTktUenFZSl91cVFXZjRkQUZyYmtPczU2S1RKQ19ONGFOTHlwX2hOOEE1UHZEVjhnT0xxRjMxTEE4SHhRbmlmTkZwVXJBdlJDbU5oZS05SzI4QVhEWDZaN2ZiSlFwUGRXSnB5TE9MZV9ia3pYcmZVa1dicG5FMHRXUFZXMWJQVDAwOEdDQzJmZEl0ZDhUOEFpZXZWWXl5Q2xwSmFienNCMldlb2NKb2ZRYV9KbUdHRzNUcjU1VUFhMzk1a2J6dDVuNTl6NTdpM0hGa3k0UWVtbF9pdDVsQVp2cndDLUU5dnNYOF9CLS0ySXhBSFdCSnpqV010bllBb3U0cEZZYVF5R2tSNFM5NlRhdS1fb1NqbDBKMkw0V2N0VEZhNExtQlR3ckZ3cVlCeHVXdXJ6X0s4cEtsaG5rVUxCN2RRbHQxTmcyVFBqYUxyOHJzeFBXVUJaRHpXbUoxdHZzMFBzQk1UTUFvX1pGNFNMNDFvZWdTdEUtMUNKMXNIeVlvQk1CeEdpZVdmN0tsSDVZZHJXSGt5c2o2MHdwSTZIMVBhRzM1eU43Q2FtcVNidExxczNJeUx5U2RuUG5EeHpCTlg2SV9WNk1ET3BRNXFuc0pNWlVvZUYtY21oRGtJSmwxQ09QbHBUV3BuS3B5NE9RVkhfellqZjJUQ0diSV94QlhQWmdaaC1TRWxsMUVWSXB0aE1McFZDZDNwQUVKZ2t5cXRTXzlRZVJwN0pZSnJSV21XMlh0TzFRVEl0c2I4QjBxOGRCYkNxek04a011X1lrb2poQ3h2LUhKTGJiUlhneHp5QWFBcE5nMElkNTVzM3JGOWtUQ19wNVBTaVVHUHFDNFJnNXJaWDNBSkMwbi1WbTdtSnFySkhNQl9ZQjZrR2xDcXhTRExhMmNHcGlyWjR3ZU9SSjRZd1l4ZjVPeHNiYk53SW5SYnZPTzNkd1lnZmFseV9tQ3BxM3lNYVBHT0J0elJnMTByZ3VHemxta0tVQzZZRllmQ2VLZ1ZCNDhUUTc3LWNCZXBMekFwWW1fQkQ1NktzNGFMYUdYTU0xbXprY1FONUNlUHNMY3h2NFJMMmhNa3VNdzF4TVFWQk9odnJUMjFJMVd3Z2N6Sms
5aEM2SWlWZFViZ0JWTEpUWWM5NmIzOS1oQmRqdkt1NUUycFlVcUxERUZGbnZqTUxIYnJmMDBHZDEzbnJsWEEzSUo3UmNPUDg1dnRUU1FzcWtjTWZwUG9zM0JTY3RqMDdST2UxcXFTM0d0bGkwdFhnMk5LaUlxNWx3V1pLaVlLUFJXZzBzVl9Ia1V1OHdYUEFWOU50UndycGtCdzM0Q0NQamp2VTNqbFBLaGhsbUk5dUI5MjU5OHVySk1oY0drUWtXUloyVVRvOWJmbUVYRzFVeWNQczh2NXJCeVppRlZiWDNJaDhOSmRmX2lURTNVS3NXQXFZT1QtUmdvMWJoVWYxU3lqUUJhbzEyX3I3TXhwbm9wc1FoQ1ZUTlNBRjMyQTBTY2tzbHZ3RFUtTjVxQ0o1QXRTVks2WENwMGZCRGstNU1jN3FhUFJCQThyaFhhMVRsbnlSRXNGRmt3Yk01X21ldmV3bTItWm1JaGpZQWZROEFtT1d1UUtPQlhYVVFqT2NxLUxQenJHX3JfMEdscDRiMXcyZ1ZmU3NFMzVoelZJaDlvT0ZoRGQ2bmtlM0M5ZHlCd2ZMbnRZRkZUWHVBUEx4czNfTmtMckh5eXZrZFBzOEItOGRYOEhsMzBhZ0xlOWFjZzgteVBsdnpPT1pYdUxnbFNXYnhKaVB6QUxVdUJCOFpvU2x2c1FHZV94MDBOVWJhYkxISkswc0U5UmdPWFJLXzZNYklHTjN1QzRKaldKdEVHb0pOU284N3c2LXZGMGVleEZ5NGZ6OGV1dm1tM0J0aTQ3VFlNOEJrdEh3PT0=
|
||||
|
||||
# Feature SyncDelta JIRA configuration
|
||||
Feature_SyncDelta_JIRA_DELTA_TOKEN_SECRET = INT_ENC:Z0FBQUFBQm8xSVRkTUNsWm4wX0p6eXFDZmJ4dFdHNEs1MV9MUzdrb3RzeC1jVWVYZ0REWHRyZkFiaGZLcUQtTXFBZzZkNzRmQ0gxbEhGbUNlVVFfR1JEQTc0aldkZkgyWnBOcjdlUlZxR0tDTEdKRExULXAyUEtsVmNTMkRKU1BJNnFiM0hlMXo4YndMcHlRMExtZDQ3Zm9vNFhMcEZCcHpBPT0=
|
||||
|
||||
# Debug Configuration
|
||||
APP_DEBUG_CHAT_WORKFLOW_ENABLED = FALSE
|
||||
APP_DEBUG_CHAT_WORKFLOW_DIR = ./test-chat
|
||||
38 env_prod.env
|
|
@@ -3,31 +3,34 @@
|
|||
# System Configuration
|
||||
APP_ENV_TYPE = prod
|
||||
APP_ENV_LABEL = Production Instance
|
||||
APP_KEY_SYSVAR = CONFIG_KEY
|
||||
APP_INIT_PASS_ADMIN_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pSXoyVEVwNDZ6cmthQTROUkxGUjh1UWF2UU5zaWRuX3p2aHJCVFo2NEstR0RqdnQ5clZmeVliRlhHZGFHTlhZV2dzMmRPZFVEemVlSHd5VHR3cmpNUXRaRlhZSFZ6d1dsX2Y5Zl9lOXdYdEU9
|
||||
APP_INIT_PASS_EVENT_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5peGNMWExjWGZxQ2VndXVOSUVGcWhQTWd0N3d0blU3bGJvNjgzNVVNNktCQnZlTEtVckV5RUtQMjMwRTBkdmxEMlZwX0k1M1hlOFFNY3hjaWsyd2JmRGl2UWxfSXEwenVnQ3NmaTlxckp2VXM9
|
||||
APP_API_URL = https://gateway-prod.poweron-center.net
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_APP_HOST=gateway-prod-server.postgres.database.azure.com
|
||||
DB_APP_DATABASE=poweron_app
|
||||
DB_APP_USER=gzxxmcrdhn
|
||||
DB_APP_PASSWORD_SECRET=prod_password_very_secure.2025
|
||||
DB_APP_PASSWORD_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pVmtwYWZQakdWZnJPamVlRWJPa0tnc3daSVVHejVrQ0x1VFZZbHhVSkk0S2tFWl92T2NwWURBMU9UbFROMHZ2TkNKZFlEWjhJZDZ0bnFndC1oYjhNRW1VLWpEYnlDNEJwcGVKckpUVlp6YTg9
|
||||
DB_APP_PORT=5432
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_CHAT_HOST=gateway-prod-server.postgres.database.azure.com
|
||||
DB_CHAT_DATABASE=poweron_chat
|
||||
DB_CHAT_USER=gzxxmcrdhn
|
||||
DB_CHAT_PASSWORD_SECRET=prod_password_very_secure.2025
|
||||
DB_CHAT_PASSWORD_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pZVZnTzBPTDY1Q3c2U1pDV0lxbXhoWnlYSXRDWVhIeGJwSkdNMzMxR2h5a1FRN00xcWtYUE4ySGpqRllSaGM5SmRZZk9Bd2trVDJNZDdWcEFIbTJtel91MHpsazlTQnRsV2docGdBc0RVeEU9
|
||||
DB_CHAT_PORT=5432
|
||||
|
||||
# PostgreSQL Storage (new)
|
||||
DB_MANAGEMENT_HOST=gateway-prod-server.postgres.database.azure.com
|
||||
DB_MANAGEMENT_DATABASE=poweron_management
|
||||
DB_MANAGEMENT_USER=gzxxmcrdhn
|
||||
DB_MANAGEMENT_PASSWORD_SECRET=prod_password_very_secure.2025
|
||||
DB_MANAGEMENT_PASSWORD_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pQXdaRnVEQUx2MmU5ck9XZzNfaGVoRXlYMlVjSVM5dWNTekhmR2VYNkd6WVhELUlkLWdFWWRWQ1JJLWZ4WUNwclZVRlg3ZHBCS0xwM1laNklTaEs1czFDRTMxYlV2TWNueEJlTHFyNEt4aVk9
|
||||
DB_MANAGEMENT_PORT=5432
|
||||
|
||||
# Security Configuration
|
||||
APP_JWT_SECRET_SECRET=rotated_jwt_secret_2025_09_17_prod_e1a9c4d7-6b8f-4f2e-9c1a-7e3d2a1b9c5f
|
||||
APP_JWT_KEY_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pY3JfX1R3cEJhTjAzZGx2amtRSE4yVzZhMmY3a3FHam9BdzBxVWd5R0FRSW1KbmNGS3JDMktKTWptZm4wYmZZZTVDQkh3NVlxSW1MZEdiVWdORng4dm0xV08wZDh0YlBNQTdEbmlnVWduMzNWY1RPX1BqaGtnOTc2ZWNBTnNnd1AtaTNRUExpRThVdzNmdVFHM2hkTjFjcW0ya2szMWNaT3VDeDhXMlJ1NDM4PQ==
|
||||
APP_TOKEN_EXPIRY=300
|
||||
|
||||
# CORS Configuration
|
||||
|
|
@@ -35,7 +38,7 @@ APP_ALLOWED_ORIGINS=http://localhost:8080,https://playground.poweron-center.net,
|
|||
|
||||
# Logging configuration
|
||||
APP_LOGGING_LOG_LEVEL = DEBUG
|
||||
APP_LOGGING_LOG_FILE = /home/site/wwwroot/poweron.log
|
||||
APP_LOGGING_LOG_DIR = /home/site/wwwroot/
|
||||
APP_LOGGING_FORMAT = %(asctime)s - %(levelname)s - %(name)s - %(message)s
|
||||
APP_LOGGING_DATE_FORMAT = %Y-%m-%d %H:%M:%S
|
||||
APP_LOGGING_CONSOLE_ENABLED = True
|
||||
|
|
@@ -46,3 +49,28 @@ APP_LOGGING_BACKUP_COUNT = 5
|
|||
# Service Redirects
|
||||
Service_MSFT_REDIRECT_URI = https://gateway-prod.poweron-center.net/api/msft/auth/callback
|
||||
Service_GOOGLE_REDIRECT_URI = https://gateway-prod.poweron-center.net/api/google/auth/callback
|
||||
|
||||
# AI configuration
|
||||
Connector_AiOpenai_API_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pU05XM2hMaExPMnpYeFpwRVhyYl9JZmRITmlmRDlWOUJSSWE4NTFLZUptSkJhNlEycHBLZmh3WFA2ZmU5VmxHZks1UUNVOUZnckZNdXZ2MTY2dFg1Nl8yWDRrcTRlT0tHYkhyRGZINTEzU25iYVFRMzJGeUZIdlc4LU9GbmpQYmtmU3lJT2VVZ1UzLVd3R25ZQ092SUVnPT0=
|
||||
Connector_AiAnthropic_API_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pNTA1RkZ3UllCOXVsNVZzbkw2Rkl1TWxCZ0wwWEVXUm9ReUhBcVl1cGFUdW9FRVh4elVxR0x3NVRxZkc4SkxHVFdzSU1YNG5Rb0FqSHJhdElwWm1iLWdubTVDcUl3UkVjVHNoU0xLa0ZTSFlfTlJUVXg4cVVwUWdlVDBTSFU5SnBzS0ZnVjlQcmtiNzV2UTNMck1IakZ0OWlubUtlWDZnMk4yX2JsZ1U4Wm1yT29fM2d2NVBNOWNBbWtTRWNyQ2tZNjhwSVF6bG5SU3dTenR2MzA3Z19NUT09
|
||||
Connector_AiPerplexity_API_SECRET = PROD_ENC:Z0FBQUFBQm82Mzk2Q1FGRkJEUkI4LXlQbHYzT2RkdVJEcmM4WGdZTWpJTEhoeUF1NW5LUVpJdDBYN3k1WFN4a2FQSWJSQmd0U0xJbzZDTmFFN05FcXl0Z3V1OEpsZjYydV94TXVjVjVXRTRYSWdLMkd5XzZIbFV6emRCZHpuOUpQeThadE5xcDNDVGV1RHJrUEN0c1BBYXctZFNWcFRuVXhRPT0=
|
||||
Connector_AiTavily_API_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pMjhJNS1CZFJubUlkN3ZrTUoxR0Y1QzJFWEJSMk0wQkI0UndqOW1UelVieWhGaTVBcHoxRXo1VjRzVVRROHFIeHMyS3Q5cDZCeUlEMzE1ZlhVTmNveFk5VmFQMm80NTRyVW1TZHVsR3dUN0RtMnd4LW1VWlpqOXJPeXZBTmg4OEM=
|
||||
|
||||
# Agent Mail configuration
|
||||
Service_MSFT_CLIENT_ID = c7e7112d-61dc-4f3a-8cd3-08cc4cd7504c
|
||||
Service_MSFT_CLIENT_SECRET = PROD_ENC:Z0FBQUFBQm83T29rSzdYLTRydXN5V3lQLXhmRjMyQ1FOaGpuek45QllaX1REN2s5aWNIUl81NGlrYlJTeFV0RlRZd0xPcm5uMDM4QlpibHJQbm5XZTlWeWxfcWNVdFpCUHI2amh0MVBnZ21IN2ptSkhWLTVfaHEwNmI5SEtiS05pQmt5eV8yMnhLMEc=
|
||||
Service_MSFT_TENANT_ID = common
|
||||
|
||||
# Google Service configuration
|
||||
Service_GOOGLE_CLIENT_ID = 354925410565-aqs2b2qaiqmm73qpjnel6al8eid78uvg.apps.googleusercontent.com
|
||||
Service_GOOGLE_CLIENT_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pV2JEV0lNUXhwa1VTUGh2RWcyYnJHSFQyTmdBOEhwRkJWc3MwOFZlcHJGUmlGOVVFbG1XalNyUXVuaExESy1xeFNIQlRiSFVIWTB6Rm1fNFg0OHZZSkF4ZlBIcFZDMjZHcFRERXJ0WlVFclhHa29Za1BqWGxsM05NZGFRc1BLZnE=
|
||||
|
||||
# Google Cloud Speech Services configuration
|
||||
Connector_GoogleSpeech_API_KEY_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pNjlJdmFMeERXUUQzR0duRUY4cGRZRzdwQlpnVFAzSzQ5cHZNRnVUZ0xWd3dQMHR3QjVsdF92NmdUQlJGRk1RcG1RYWZzcE9RbEhjQmR5Yk5Ud3ZKTW5jbmpEVGJ2ZkxVeVJpcUxaT2lNREFXaks5WHg5aVlHcXlUZldMdnZGYklHWjlJOWJ6Wm5RSkNmdm5feENjS1E0QUVXTTE5SW5sNFBEeTJ1RjRmVm9SQUNIYmF2U1U2dklsbTVlWFpCcHMwTFF1SUg5NmNfcWhQRFlpeWt0U19HMXNuUHd2RFdrVl9XdUFaY0hWdVBPYWlybU1CdGlCN1A0RzZBbi1IUVJ1TWMxTE9Ea09sTURhcDFZb1JIUW1zUFJybW15MDcxOUtfVXA2N0xwMnFrczA1YTJaN05pRHhOYWNzMjVmUHdhbVdlemF3TEIzN0pJaVo3bGJBMXJnZmNYTXVJVDdmYkRXWTlBT2F2NmN4eTlteUI1SlJTOXc2WWFWUTBCZTJBVHRLVDhEVjBFeHE0Nmk1YkxYd3N3RXgtVUdGdlZFSmk4dHM0QjFmbktsQTctbmJMT0MtMDlKS1pUR0pELXBxckhULUUycjlBZmVJQjFrM0xEUm50U2ZabExtVjZ1WWZ1WnlobUZIOVlndjNydUZfczJUWVVRZURTd1lYazllaER4VU10cXUyVS1ZNG9Ha2hnbTAzOEpGMklFSWpWeVV5eFB2UlVWYmJJakZnOVM2R2lJSXRSM3VzVEZZNUVpNmVjRzdXRUJsT2hzcjhZWERFeGV5c1dFQVM3dkhGY2Q3ckNBRDZCcVdhZnZkdzM3QVNpODZYWE81TEIyZGUycldkSVRvbm5hR3Jib2UzOEtXdUpHQ2FyWDQtMDdQbC1ycEdfUzdXd0U2dHFIVjhoRDJ0YkNsWUpva1dzOGNPdXRpZjVwUldtT3FVN3RrZUhTN3JfX1M3LU9PaXZELWkzRmtMbjgxZGZ6ZjVJNW9RZW1nM2hqUXo4Z2I5Z2tSVTVMdUNLblRxOGQ1Y3F4SGZIbWo4YkFBV3FIbjB6LUxGNHdsQWgxQUM4bzVrblBObFFfVWNaQ3QwejQ1eGFlSXVIcXlyVEZEdzVKNV9pd2o4RW1UVjlqb3VMWnF0V1JTcWF1R0RjdUNjM2lLUHRqZDl2WWtXUnhmbVdxeHA3REFHTkdkMjM4LTllajBWQnd3RHlFSVdiUThfQnduOVFJdmR6OUVGN1lOYjBqclhadHozX21kRzlUT2EtWVBkYWFRSjRGdW80dmlEUTVrVjhWbjJYNGtCeGNtNzRHQXJsRlZyWjBYdHltVDM2MV9IT0RFT2dLLTVBREtsS09HdUxrODRLcEQ1TmRoVDh6WmgybGc5MzgtbmJSYThQd3FFaUcxbmg3eE95RkJVX2hHM20wT1k2c21qd24wSkFWNGROaklQeHZrc21PdTVsdHVxR0pxd3Ztb1NQVHEtd25URHRNa1pqa3BLdVdkTnNFeDNManJST0dOb1RWM2hqekxFTlFSZkd6TlZBY1VQT1NFOVlDQzlPQWVlVXQ4MW0wdGkzd0Myam1lSWE2aEtVVTVNc3N3dENpa1BWRl9ZQ3daYllONWRmRUF0THpleFRmdWRqTFM2aldmLUFuZzFGdkFQNHR6d21SdzRGQ0Q4cU8yV0xGUTVUY01TZlYxSzZ4cmtfUGZvVDhmYmNBX1pibTVTcl9lenJoME9KSnBucUxPRU1PRXBmLWFENEgwRWZOU0RvRDlvQk9ueVp0dXJrUVgtQUk5VldVbV9MS19PYmlua3liWl80Z2hMcFRnTXBkZDA3enIxRWFzaU56TEZKa0hPQUtNY0dCY1pnQ2V3Zml6ZFczWFBESUlLd3BSVEs5ZXlGLUpINDRsd1NBVjBkR1dvbE8wLWZBeEhFQ0hvY3E5UGJsTDdteGdSRjBIZTRobXpsd29PMmhKQkxXY3Znd2FMdWtZU1VkQlVRZXlSZ3FaVnNqcXpwR3N3SktOTDA3aUZIcE9TR1VDcXdaTDhQX2E5VDlwckoyX0xlNmFQcnoydEkwc0s1S08yaVlsM0pwYktUVWl3LU5hQzF2UVZNSm9ZR3QyQWdrUXB2a25QNzhkVEFOYmZ0b1BmTXRCMmVQZTAtYzdOeUlBYlNINlZNZW1nUTFfSV92UlJiWGt6Qms1c1hBc3kzZkVRMzEwNVJDOS1JeVg4YWtVeUJyOTZPQ0FnSUs1Z25sMlY0S1V1c0dIWEpuX2pMQmZ4Z29SY1U0bVZscXNWcjJwRy1UZEFYSXBzQURGblRTelBybU5BeDF6N3hZLXZwSHBkMmlzbHZWN2JkU3hRcE0zQ0hna3QwYWlJX3hBdGcxUHdGRE55cndUNHRvbXU5VTRMRmZDRjhvXzIwajI1Y0RCcmR2OV94cS1XYkNwalNHS2lObHlkNGZBbklycnZMSlJYVnlfakRXb1ZfWUo2MGxzYUNIektYeENGTkUzMUJXRE9WRHRrY2o5UFJHckZza2RQbjNPUkstbG9GZG4yNmxKeEdtbHo4WDZFc0lvT01wZkxuN29ycXl3X1hTN1prRGdvWG9hRFYwNzBwVVpuMW0wQlZYbGZxZjFQUHp2XzBQT3Fqa3lzejVKZmJDMG0wRzhqWV9HY1dxaXB2VFNQUzV2LUJSOXRFRUllak83cUI3RGUtYVBJakF1YUVOV0otT1BxUHJqS0NLdFVHc0tsT2RGcWd6UTU4Yi1kc0JZS1VPT1NXSlc3TDM5ZDVEZlRDOURZU1hMT0YxZ25ndVBUaG1VcGsxWFZSS1RxT1ZZTU1vclZjVU5iYmZMd0VBTXlvdTE0YjdoclZ6ZnNKMmE2Yy1ORmNCMnJNX3dwcVJSN2RSd2d6aENLRXQyTjhkcDlLTFVZMHBydFowNTJoZm1mVHNRVHI1YjhTNnl1Vll4dFZhenZfa0dybk9KYVh6LUluSUo0djUzRFNEdzBoVGt5UU9tMlg5UnBLbk9WaEhoU2txY2tUSXJmemlmNEExb3Q1blI5bE9adHluWVI3NXZQNUtXdmpra05aNy15dTBXdlVqcXhteFVqSXFxNnlQR2FGeVNONkx3NVpQUk1FNk5yTUY4T1hQV1FCdm9PYzdFTGl4QXZkODltSlprbGJ6cWREcEM1VlNwN3V5aWdWYXNkekk4X3U0cjJjZ1k2X190cmNnMlpMQVlLdExxM3pFNkZudVFKci1CalE1U3kzdmotQ01LV0ZzWnp0VUxRblhkdlN6VG1MWHNQdGlrNmF4RnFtd0c3UXNqZFVRZTRFMGl1NFU5T2k3VEpjZXA1U052VkJtdUhDWEpTaDRGQnM0SDQwY2IxdDVNbUtELTQ0R0s0OHpfTHdFOHZ0VmRMTC1FUVpPSkJ4QXRWNnl5MURUdjVyUk53emRwbDBxUnloUmlheXhKY3RBUG1mX3JxM2w0VlZvcE40b2ROeG15NS01RFlvUHdoYllLNVhCZUNEd0dwQn
FCLVdZU0RhVEFzR2gxTVpub3FGRnl4VDNiSVZrTnpMQUlxeGJGQzh5WlNZR2NKbklHRVRTaVJ2REduN0hXaGo5MHFGb1FOa0U5TUFwQ09zOXVWMnRRNVlJWmZpaTUxLWFIeWR0UEFtaVNDX1k5Q1p3Y2V4ckVXQVBRYzV1eGwwMWd0SE15WUxiYzUyLTUzTGlyTUhZUDFlRTFjcFpieWQwU0pxRWJXSE53Nkd5aHp5T28wZVd6Z1phLTQ4TmgxU3hvNHpySzExUk5WZlFFS3VpOXNHMDdZU0gzSGxYUlU4WmgwNUlPdlhQcUI0cGtITmQ4SlByczN0THUxNHc0a21vUEp6S1hLNnFRNmFfdlpmUWpJQ1VNYXVEOW1abzlsd2RoRG5pVXRVbjBKV2RFTGFEa3ZYTHByOTJjalc1b3hTWkFmS2RPdVlTUTVkRkpSTnZsMWtnYWZEUm1SR3lBemdON2xiN3pkZlNfX2NSYU5wWHNybHh4V0lnNHJjQ2NON1hiRHMycUdmNC1kay13bUE0OTBPN0xmNDA1NlQxVmRySEJvM1VUN2Y2Sl9KX2pZVHRPWEdfR2RYNUoxY01Va3pXb2VBd3lZb3BSXzU5NVJfWlhEYXFSVDJrUnFHWG42RVZJUVQ2RlJWUEkyQnRnREI3eHNiRERiQ3FUczJsRTBDZ3pUUGZPcjExZUFKc21QUWxVYVBmV2hPZXRGd3lJX3ZTczhCVG1jWFVwanhIZHlyTTdiR2c5cTBVSXBRV1U4ZExtWWdub1pTSHU0cU5aYWJVWmExbXI0MjE3WUVnPT0=
|
||||
|
||||
# Feature SyncDelta JIRA configuration
|
||||
Feature_SyncDelta_JIRA_DELTA_TOKEN_SECRET = PROD_ENC:Z0FBQUFBQm8xSU5pTDhnTVNzRUhScU8wYnZsZk52bHFkSWxLc18xQmtCeC1HbnNwTzVBbXRNTmQzRjZYaGE2MVlCNGtnWDk1T2I5VXVKNHpKU1VRbXEyN2tRWUJnU2ltZE5qZ3lmNEF6Z1hMTTEwZkk2NUNBYjhmVTJEcWpRUW9HNEVpSGFWdjBWQXQ3eUtHUTFJS3U5QWpaeno0RFNhMUxnPT0=
|
||||
|
||||
# Debug Configuration
|
||||
APP_DEBUG_CHAT_WORKFLOW_ENABLED = FALSE
|
||||
APP_DEBUG_CHAT_WORKFLOW_DIR = ./test-chat
|
||||
102 modules/aicore/aicoreBase.py (new file)
|
|
@@ -0,0 +1,102 @@
|
|||
"""
|
||||
Base connector interface for AI connectors.
|
||||
All AI connectors should inherit from this class.
|
||||
|
||||
IMPORTANT: Model Registration Requirements
|
||||
- Each model must have a unique displayName across all connectors
|
||||
- The displayName is used as the unique identifier in the model registry
|
||||
- The name field is used for API calls (can be duplicated across different model instances)
|
||||
- If duplicate displayNames are detected during registration, an error will be raised
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
from modules.datamodels.datamodelAi import AiModel
|
||||
|
||||
|
||||
class BaseConnectorAi(ABC):
|
||||
"""
|
||||
Base class for all AI connectors.
|
||||
|
||||
IMPORTANT: Models returned by getModels() must have unique displayName values.
|
||||
The displayName serves as the unique identifier in the model registry.
|
||||
Duplicate displayNames will cause registration to fail with an error.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._models_cache: Optional[List[AiModel]] = None
|
||||
self._last_cache_update: Optional[float] = None
|
||||
self._cache_ttl: float = 300.0 # 5 minutes cache TTL
|
||||
|
||||
@abstractmethod
|
||||
def getModels(self) -> List[AiModel]:
|
||||
"""
|
||||
Get all available models for this connector.
|
||||
Should be implemented by each connector.
|
||||
|
||||
IMPORTANT: Each model's displayName must be unique across all connectors.
|
||||
If multiple models share the same API name (e.g., "gpt-4o"), they must have
|
||||
different displayNames (e.g., "OpenAI GPT-4o" vs "OpenAI GPT-4o Instance Vision").
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def getConnectorType(self) -> str:
|
||||
"""
|
||||
Get the connector type identifier.
|
||||
Should return one of: openai, anthropic, perplexity, tavily
|
||||
"""
|
||||
pass
|
||||
|
||||
def getCachedModels(self) -> List[AiModel]:
|
||||
"""
|
||||
Get cached models with TTL check.
|
||||
Returns cached models if still valid, otherwise refreshes cache.
|
||||
"""
|
||||
import time
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# Check if cache is valid
|
||||
if (self._models_cache is not None and
|
||||
self._last_cache_update is not None and
|
||||
current_time - self._last_cache_update < self._cache_ttl):
|
||||
return self._models_cache
|
||||
|
||||
# Refresh cache
|
||||
self._models_cache = self.getModels()
|
||||
self._last_cache_update = current_time
|
||||
|
||||
return self._models_cache
|
||||
|
||||
def clearCache(self):
|
||||
"""Clear the models cache."""
|
||||
self._models_cache = None
|
||||
self._last_cache_update = None
|
||||
|
||||
def getModelByDisplayName(self, displayName: str) -> Optional[AiModel]:
|
||||
"""Get a specific model by displayName (displayName must be unique)."""
|
||||
models = self.getCachedModels()
|
||||
for model in models:
|
||||
if model.displayName == displayName:
|
||||
return model
|
||||
return None
|
||||
|
||||
def getModelByName(self, name: str) -> Optional[AiModel]:
|
||||
"""Get a specific model by name (API name). Note: name can be duplicated, returns first match."""
|
||||
models = self.getCachedModels()
|
||||
for model in models:
|
||||
if model.name == name:
|
||||
return model
|
||||
return None
|
||||
|
||||
|
||||
def getModelsByPriority(self, priority: str) -> List[AiModel]:
|
||||
"""Get models that have a specific priority."""
|
||||
models = self.getCachedModels()
|
||||
return [model for model in models if model.priority == priority]
|
||||
|
||||
def getAvailableModels(self) -> List[AiModel]:
|
||||
"""Get only available models."""
|
||||
models = self.getCachedModels()
|
||||
return [model for model in models if model.isAvailable]
|
||||
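As a rough illustration of this contract, a hypothetical connector (not part of this change) would subclass BaseConnectorAi and implement the two abstract methods; whatever it returns from getModels() must use displayName values that are unique across every registered connector:

from typing import List
from modules.aicore.aicoreBase import BaseConnectorAi
from modules.datamodels.datamodelAi import AiModel

class AiExampleConnector(BaseConnectorAi):
    """Hypothetical connector used only to illustrate the base-class contract."""

    def getConnectorType(self) -> str:
        return "example"

    def getModels(self) -> List[AiModel]:
        # Each AiModel listed here needs a globally unique displayName;
        # the API-level name field may be shared between model entries.
        return []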
202
modules/aicore/aicoreModelRegistry.py
Normal file
@@ -0,0 +1,202 @@
"""
|
||||
Dynamic model registry that collects models from all AI connectors.
|
||||
Implements plugin-like architecture for connector discovery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import importlib
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
from modules.datamodels.datamodelAi import AiModel
|
||||
from modules.aicore.aicoreBase import BaseConnectorAi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""Dynamic registry for AI models from all connectors."""
|
||||
|
||||
def __init__(self):
|
||||
self._models: Dict[str, AiModel] = {}
|
||||
self._connectors: Dict[str, BaseConnectorAi] = {}
|
||||
self._lastRefresh: Optional[float] = None
|
||||
self._refreshInterval: float = 300.0 # 5 minutes
|
||||
|
||||
def registerConnector(self, connector: BaseConnectorAi):
|
||||
"""Register a connector and collect its models."""
|
||||
connectorType = connector.getConnectorType()
|
||||
|
||||
# If connector already registered, skip re-registration to avoid duplicate models
|
||||
if connectorType in self._connectors:
|
||||
logger.debug(f"Connector {connectorType} already registered, skipping re-registration")
|
||||
return
|
||||
|
||||
self._connectors[connectorType] = connector
|
||||
|
||||
# Collect models from this connector
|
||||
try:
|
||||
models = connector.getCachedModels()
|
||||
for model in models:
|
||||
# Validate displayName uniqueness
|
||||
if model.displayName in self._models:
|
||||
existingModel = self._models[model.displayName]
|
||||
errorMsg = f"Duplicate displayName '{model.displayName}' detected! Existing model: displayName='{existingModel.displayName}', name='{existingModel.name}' (connector: {existingModel.connectorType}), New model: displayName='{model.displayName}', name='{model.name}' (connector: {connectorType}). displayName must be unique."
|
||||
logger.error(errorMsg)
|
||||
raise ValueError(errorMsg)
|
||||
|
||||
# Use displayName as the key (must be unique)
|
||||
self._models[model.displayName] = model
|
||||
logger.debug(f"Registered model: {model.displayName} (name: {model.name}) from {connectorType}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register models from {connectorType}: {e}")
|
||||
raise
|
||||
|
||||
def discoverConnectors(self) -> List[BaseConnectorAi]:
|
||||
"""Auto-discover connectors by scanning aicorePlugin*.py files."""
|
||||
connectors = []
|
||||
connectorDir = os.path.dirname(__file__)
|
||||
|
||||
# Scan for connector files
|
||||
for filename in os.listdir(connectorDir):
|
||||
if filename.startswith('aicorePlugin') and filename.endswith('.py'):
|
||||
moduleName = filename[:-3] # Remove .py extension
|
||||
|
||||
try:
|
||||
# Import the module
|
||||
module = importlib.import_module(f'modules.aicore.{moduleName}')
|
||||
|
||||
# Find connector classes (classes that inherit from BaseConnectorAi)
|
||||
for attrName in dir(module):
|
||||
attr = getattr(module, attrName)
|
||||
if (isinstance(attr, type) and
|
||||
issubclass(attr, BaseConnectorAi) and
|
||||
attr != BaseConnectorAi):
|
||||
|
||||
# Instantiate the connector
|
||||
connector = attr()
|
||||
connectors.append(connector)
|
||||
logger.info(f"Discovered connector: {connector.getConnectorType()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to discover connector from {filename}: {e}")
|
||||
|
||||
return connectors
|
||||
|
||||
def refreshModels(self, force: bool = False):
|
||||
"""Refresh models from all registered connectors."""
|
||||
import time
|
||||
|
||||
currentTime = time.time()
|
||||
|
||||
# Check if refresh is needed
|
||||
if (not force and
|
||||
self._lastRefresh is not None and
|
||||
currentTime - self._lastRefresh < self._refreshInterval):
|
||||
return
|
||||
|
||||
logger.info("Refreshing model registry...")
|
||||
|
||||
# Clear existing models
|
||||
self._models.clear()
|
||||
|
||||
# Re-register all connectors
|
||||
for connector in self._connectors.values():
|
||||
try:
|
||||
connector.clearCache() # Clear connector cache
|
||||
models = connector.getCachedModels()
|
||||
for model in models:
|
||||
# Validate displayName uniqueness
|
||||
if model.displayName in self._models:
|
||||
existingModel = self._models[model.displayName]
|
||||
errorMsg = f"Duplicate displayName '{model.displayName}' detected! Existing model: displayName='{existingModel.displayName}', name='{existingModel.name}' (connector: {existingModel.connectorType}), New model: displayName='{model.displayName}', name='{model.name}' (connector: {connector.getConnectorType()}). displayName must be unique."
|
||||
logger.error(errorMsg)
|
||||
raise ValueError(errorMsg)
|
||||
|
||||
# Use displayName as the key (must be unique)
|
||||
self._models[model.displayName] = model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to refresh models from {connector.getConnectorType()}: {e}")
|
||||
raise
|
||||
|
||||
self._lastRefresh = currentTime
|
||||
logger.info(f"Model registry refreshed: {len(self._models)} models available")
|
||||
|
||||
def getModel(self, displayName: str) -> Optional[AiModel]:
|
||||
"""Get a specific model by displayName (displayName must be unique)."""
|
||||
self.refreshModels()
|
||||
return self._models.get(displayName)
|
||||
|
||||
def getModels(self) -> List[AiModel]:
|
||||
"""Get all available models."""
|
||||
self.refreshModels()
|
||||
return list(self._models.values())
|
||||
|
||||
def getModelsByConnector(self, connectorType: str) -> List[AiModel]:
|
||||
"""Get models from a specific connector."""
|
||||
self.refreshModels()
|
||||
return [model for model in self._models.values() if model.connectorType == connectorType]
|
||||
|
||||
|
||||
def getModelsByPriority(self, priority: str) -> List[AiModel]:
|
||||
"""Get models that have a specific priority."""
|
||||
self.refreshModels()
|
||||
return [model for model in self._models.values() if model.priority == priority]
|
||||
|
||||
def getAvailableModels(self) -> List[AiModel]:
|
||||
"""Get only available models."""
|
||||
self.refreshModels()
|
||||
allModels = list(self._models.values())
|
||||
availableModels = [model for model in allModels if model.isAvailable]
|
||||
unavailableCount = len(allModels) - len(availableModels)
|
||||
if unavailableCount > 0:
|
||||
unavailableModels = [m.name for m in allModels if not m.isAvailable]
|
||||
logger.debug(f"getAvailableModels: {len(availableModels)} available, {unavailableCount} unavailable. Unavailable: {unavailableModels}")
|
||||
logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
|
||||
return availableModels
|
||||
|
||||
def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
|
||||
"""Get the connector instance for a specific model by displayName."""
|
||||
model = self.getModel(displayName)
|
||||
if model:
|
||||
return self._connectors.get(model.connectorType)
|
||||
return None
|
||||
|
||||
def getModelStats(self) -> Dict[str, Any]:
|
||||
"""Get statistics about the model registry."""
|
||||
self.refreshModels()
|
||||
|
||||
stats = {
|
||||
"totalModels": len(self._models),
|
||||
"availableModels": len([m for m in self._models.values() if m.isAvailable]),
|
||||
"connectors": len(self._connectors),
|
||||
"byConnector": {},
|
||||
"byCapability": {},
|
||||
"byPriority": {}
|
||||
}
|
||||
|
||||
# Count by connector
|
||||
for model in self._models.values():
|
||||
connector = model.connectorType
|
||||
if connector not in stats["byConnector"]:
|
||||
stats["byConnector"][connector] = 0
|
||||
stats["byConnector"][connector] += 1
|
||||
|
||||
# Count by capability
|
||||
for model in self._models.values():
|
||||
for capability in model.capabilities:
|
||||
if capability not in stats["byCapability"]:
|
||||
stats["byCapability"][capability] = 0
|
||||
stats["byCapability"][capability] += 1
|
||||
|
||||
# Count by priority
|
||||
for model in self._models.values():
|
||||
priority = model.priority
|
||||
if priority not in stats["byPriority"]:
|
||||
stats["byPriority"][priority] = 0
|
||||
stats["byPriority"][priority] += 1
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Global registry instance
|
||||
modelRegistry = ModelRegistry()
|
||||
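The registry is exposed as the module-level singleton modelRegistry. A minimal usage sketch, assuming it is wired up once at application startup (that wiring is not part of this file):

from modules.aicore.aicoreModelRegistry import modelRegistry

# Discover aicorePlugin*.py connectors and register their models;
# registerConnector raises ValueError if two models share a displayName.
for connector in modelRegistry.discoverConnectors():
    modelRegistry.registerConnector(connector)

print(modelRegistry.getModelStats()["totalModels"])
for model in modelRegistry.getAvailableModels():
    print(model.displayName, model.connectorType)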
279
modules/aicore/aicoreModelSelector.py
Normal file
@@ -0,0 +1,279 @@
"""
|
||||
Simplified model selection based on model properties and priority-based sorting.
|
||||
No complex rules needed - just filter by properties and sort by priority!
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from modules.datamodels.datamodelAi import AiModel, AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelSelector:
|
||||
"""Simple model selector based on properties and priority-based sorting."""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("ModelSelector initialized with simplified approach")
|
||||
|
||||
def selectModel(self,
|
||||
prompt: str,
|
||||
context: str,
|
||||
options: AiCallOptions,
|
||||
availableModels: List[AiModel]) -> Optional[AiModel]:
|
||||
"""
|
||||
Select the best model using simple filtering and priority-based sorting.
|
||||
|
||||
Args:
|
||||
prompt: User prompt
|
||||
context: Context data
|
||||
options: AI call options
|
||||
availableModels: List of available models
|
||||
|
||||
Returns:
|
||||
Best model for the request, or None if no suitable model found
|
||||
"""
|
||||
try:
|
||||
# Get failover models (which includes all filtering and sorting)
|
||||
failoverModelList = self.getFailoverModelList(prompt, context, options, availableModels)
|
||||
|
||||
if not failoverModelList:
|
||||
logger.warning("No suitable models found for the request")
|
||||
return None
|
||||
|
||||
selectedModel = failoverModelList[0] # First model is the best one
|
||||
logger.info(f"Selected model: {selectedModel.name} (quality: {selectedModel.qualityRating}, cost: ${selectedModel.costPer1kTokensInput:.4f})")
|
||||
return selectedModel
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error selecting model: {str(e)}")
|
||||
return None
|
||||
|
||||
def getFailoverModelList(self,
|
||||
prompt: str,
|
||||
context: str,
|
||||
options: AiCallOptions,
|
||||
availableModels: List[AiModel]) -> List[AiModel]:
|
||||
"""
|
||||
Get prioritized list of models using scoring-based ranking.
|
||||
|
||||
Args:
|
||||
prompt: User prompt
|
||||
context: Context data
|
||||
options: AI call options
|
||||
availableModels: List of available models
|
||||
|
||||
Returns:
|
||||
List of models sorted by score (descending)
|
||||
"""
|
||||
try:
|
||||
promptSize = len(prompt.encode("utf-8"))
|
||||
contextSize = len(context.encode("utf-8"))
|
||||
totalSize = promptSize + contextSize
|
||||
# Convert bytes to approximate tokens (1 token ≈ 4 bytes)
|
||||
promptTokens = promptSize / 4
|
||||
contextTokens = contextSize / 4
|
||||
totalTokens = totalSize / 4
|
||||
|
||||
logger.debug(f"Request sizes - Prompt: {promptTokens:.0f} tokens ({promptSize} bytes), Context: {contextTokens:.0f} tokens ({contextSize} bytes), Total: {totalTokens:.0f} tokens ({totalSize} bytes)")
|
||||
|
||||
# Step 1: Filter by operation type (MUST match) - check if model has this operation type
|
||||
operationFiltered = []
|
||||
for model in availableModels:
|
||||
# Check if model has the required operation type
|
||||
hasOperationType = any(ot.operationType == options.operationType for ot in model.operationTypes)
|
||||
if hasOperationType:
|
||||
operationFiltered.append(model)
|
||||
logger.debug(f"After operation type filtering: {len(operationFiltered)} models")
|
||||
|
||||
if operationFiltered:
|
||||
logger.debug(f"Models with {options.operationType.value}: {[m.name for m in operationFiltered]}")
|
||||
|
||||
# Step 2: Filter by prompt size (MUST be <= 80% of context size)
|
||||
# Note: contextLength is in tokens, so we need to compare tokens with tokens
|
||||
promptFiltered = []
|
||||
for model in operationFiltered:
|
||||
if model.contextLength == 0:
|
||||
# No context length limit - always pass
|
||||
promptFiltered.append(model)
|
||||
else:
|
||||
maxAllowedTokens = model.contextLength * 0.8
|
||||
# Compare prompt tokens (not bytes) with model's token limit
|
||||
if promptTokens <= maxAllowedTokens:
|
||||
promptFiltered.append(model)
|
||||
else:
|
||||
logger.debug(f"Model {model.name} filtered out: promptSize={promptTokens:.0f} tokens > maxAllowed={maxAllowedTokens:.0f} tokens (80% of {model.contextLength} tokens)")
|
||||
|
||||
logger.debug(f"After prompt size filtering: {len(promptFiltered)} models")
|
||||
|
||||
if not promptFiltered and operationFiltered:
|
||||
logger.warning(f"All {len(operationFiltered)} models with {options.operationType.value} were filtered out due to prompt size. Prompt: {promptTokens:.0f} tokens. Available models:")
|
||||
for model in operationFiltered:
|
||||
maxAllowed = model.contextLength * 0.8 / 4 if model.contextLength > 0 else "unlimited"
|
||||
logger.warning(f" - {model.name}: contextLength={model.contextLength} tokens, maxAllowed={maxAllowed} tokens")
|
||||
|
||||
# Step 3: Calculate scores for each model
|
||||
scoredModels = []
|
||||
for model in promptFiltered:
|
||||
score = self._calculateModelScore(model, promptSize, contextSize, totalSize, options)
|
||||
scoredModels.append((model, score))
|
||||
logger.debug(f"Model {model.name}: score={score:.3f}")
|
||||
|
||||
# Step 4: Sort by score (descending)
|
||||
scoredModels.sort(key=lambda x: x[1], reverse=True)
|
||||
sortedModels = [model for model, score in scoredModels]
|
||||
|
||||
logger.debug(f"Final sorted models: {len(sortedModels)} models")
|
||||
return sortedModels
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting failover models: {str(e)}")
|
||||
return []
|
||||
|
||||
def _calculateModelScore(self, model: AiModel, promptSize: int, contextSize: int, totalSize: int, options: AiCallOptions) -> float:
|
||||
"""
|
||||
Calculate a score for a model based on how well it fulfills the criteria.
|
||||
Operation type rating is the PRIMARY sorting criteria (multiplied by 1000).
|
||||
|
||||
Args:
|
||||
model: The model to score
|
||||
promptSize: Size of the prompt in bytes
|
||||
contextSize: Size of the context in bytes
|
||||
totalSize: Total size (prompt + context) in bytes
|
||||
options: AI call options
|
||||
|
||||
Returns:
|
||||
Score for the model (higher is better)
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 1. PRIMARY: Operation Type Rating (multiplied by 1000 for primary sorting)
|
||||
operationTypeRating = self._getOperationTypeRating(model, options.operationType)
|
||||
score += operationTypeRating * 1000.0 # Primary sorting criteria
|
||||
|
||||
# 2. Prompt + Context size rating
|
||||
if model.contextLength > 0:
|
||||
modelMaxSize = model.contextLength * 0.8 # 80% of model context length
|
||||
if totalSize <= modelMaxSize:
|
||||
# Within limits: rating = (prompt+contextsize) / (80% modelsize)
|
||||
score += totalSize / modelMaxSize
|
||||
else:
|
||||
# Exceeds limits: rating = modelsize / (prompt+contextsize) (ensures minimum chunks)
|
||||
score += modelMaxSize / totalSize
|
||||
else:
|
||||
# No context length limit
|
||||
score += 1.0
|
||||
|
||||
# 3. Processing Mode rating
|
||||
if hasattr(options, 'processingMode') and options.processingMode:
|
||||
score += self._getProcessingModeRating(model.processingMode, options.processingMode)
|
||||
else:
|
||||
score += 1.0 # No preference
|
||||
|
||||
# 4. Priority rating
|
||||
if hasattr(options, 'priority') and options.priority:
|
||||
score += self._getPriorityRating(model, options.priority)
|
||||
else:
|
||||
score += 1.0 # No preference
|
||||
|
||||
return score
|
||||
|
||||
def _getOperationTypeRating(self, model: AiModel, operationType: OperationTypeEnum) -> float:
|
||||
"""
|
||||
Get the operation type rating for a model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
operationType: The operation type to get rating for
|
||||
|
||||
Returns:
|
||||
Rating (1-10) or 0 if model doesn't support this operation type
|
||||
"""
|
||||
for ot_rating in model.operationTypes:
|
||||
if ot_rating.operationType == operationType:
|
||||
return float(ot_rating.rating)
|
||||
return 0.0 # Model doesn't support this operation type
|
||||
|
||||
def _getProcessingModeRating(self, modelMode: ProcessingModeEnum, requestedMode: ProcessingModeEnum) -> float:
|
||||
"""Get processing mode rating based on compatibility."""
|
||||
if modelMode == requestedMode:
|
||||
return 1.0
|
||||
|
||||
# Compatibility matrix
|
||||
if requestedMode == ProcessingModeEnum.BASIC:
|
||||
if modelMode == ProcessingModeEnum.ADVANCED:
|
||||
return 0.5
|
||||
elif modelMode == ProcessingModeEnum.DETAILED:
|
||||
return 0.2
|
||||
|
||||
elif requestedMode == ProcessingModeEnum.ADVANCED:
|
||||
if modelMode == ProcessingModeEnum.BASIC:
|
||||
return 0.2
|
||||
elif modelMode == ProcessingModeEnum.DETAILED:
|
||||
return 0.5
|
||||
|
||||
elif requestedMode == ProcessingModeEnum.DETAILED:
|
||||
if modelMode == ProcessingModeEnum.BASIC:
|
||||
return 0.2
|
||||
elif modelMode == ProcessingModeEnum.ADVANCED:
|
||||
return 0.5
|
||||
|
||||
return 0.0 # No compatibility
|
||||
|
||||
def _getPriorityRating(self, model: AiModel, requestedPriority: PriorityEnum) -> float:
|
||||
"""Get priority rating based on model capabilities."""
|
||||
if requestedPriority == PriorityEnum.BALANCED:
|
||||
return 1.0
|
||||
|
||||
elif requestedPriority == PriorityEnum.SPEED:
|
||||
return model.speedRating / 10.0
|
||||
|
||||
elif requestedPriority == PriorityEnum.QUALITY:
|
||||
return model.qualityRating / 10.0
|
||||
|
||||
elif requestedPriority == PriorityEnum.COST:
|
||||
# Cost priority: cost gives 1, speed gives 0.5, quality gives 0.2
|
||||
# Lower cost is better, so we invert the cost rating
|
||||
costRating = 1.0 - (model.costPer1kTokensInput / 0.1) # Normalize to 0-1
|
||||
costRating = max(0, costRating) # Ensure non-negative
|
||||
|
||||
speedRating = model.speedRating / 10.0 * 0.5
|
||||
qualityRating = model.qualityRating / 10.0 * 0.2
|
||||
|
||||
return costRating + speedRating + qualityRating
|
||||
|
||||
return 1.0 # Default
|
||||
|
||||
def _getSizeRating(self, model: AiModel, totalSize: int) -> float:
|
||||
"""Get size rating for a model based on total input size."""
|
||||
if model.contextLength > 0:
|
||||
modelMaxSize = model.contextLength * 0.8 # 80% of model context length
|
||||
if totalSize <= modelMaxSize:
|
||||
# Within limits: rating = (prompt+contextsize) / (80% modelsize)
|
||||
return totalSize / modelMaxSize
|
||||
else:
|
||||
# Exceeds limits: rating = modelsize / (prompt+contextsize) (ensures minimum chunks)
|
||||
return modelMaxSize / totalSize
|
||||
else:
|
||||
# No context length limit
|
||||
return 1.0
|
||||
|
||||
|
||||
def _logModelDetails(self, model: AiModel):
|
||||
"""Log detailed information about a model."""
|
||||
logger.info(f"Model: {model.name}")
|
||||
logger.info(f" Display Name: {model.displayName}")
|
||||
logger.info(f" Connector: {model.connectorType}")
|
||||
logger.info(f" Context Length: {model.contextLength}")
|
||||
logger.info(f" Max Tokens: {model.maxTokens}")
|
||||
logger.info(f" Quality Rating: {model.qualityRating}/10")
|
||||
logger.info(f" Speed Rating: {model.speedRating}/10")
|
||||
logger.info(f" Cost: ${model.costPer1kTokensInput:.4f}/1k tokens")
|
||||
logger.info(f" Priority: {model.priority}")
|
||||
logger.info(f" Processing Mode: {model.processingMode}")
|
||||
operationTypesStr = ', '.join([f"{ot.operationType.value}({ot.rating})" for ot in model.operationTypes])
|
||||
logger.info(f" Operation Types: {operationTypesStr}")
|
||||
|
||||
|
||||
# Global model selector instance
|
||||
modelSelector = ModelSelector()
|
||||
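To make the weighting concrete, here is a worked example of the score computed by _calculateModelScore, using made-up request and model values; only the formula itself comes from the code above:

# Hypothetical numbers: a model rated 9 for the requested operation type,
# a 40,000-byte request against an 80% context budget of 102,400,
# an exact processingMode match and PriorityEnum.BALANCED.
operationTypeRating = 9.0
sizeRating = 40000 / 102400          # 0.390625
processingModeRating = 1.0
priorityRating = 1.0
score = operationTypeRating * 1000.0 + sizeRating + processingModeRating + priorityRating
print(score)  # 9002.390625 - the operation-type rating dominates the ordering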
376
modules/aicore/aicorePluginAnthropic.py
Normal file
376
modules/aicore/aicorePluginAnthropic.py
Normal file
|
|
@ -0,0 +1,376 @@
|
|||
import logging
|
||||
import httpx
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
from fastapi import HTTPException
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.aicore.aicoreBase import BaseConnectorAi
|
||||
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def loadConfigData():
|
||||
"""Load configuration data for Anthropic connector"""
|
||||
return {
|
||||
"apiKey": APP_CONFIG.get('Connector_AiAnthropic_API_SECRET'),
|
||||
}
|
||||
|
||||
class AiAnthropic(BaseConnectorAi):
|
||||
"""Connector for communication with the Anthropic API."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Load configuration
|
||||
self.config = loadConfigData()
|
||||
self.apiKey = self.config["apiKey"]
|
||||
|
||||
# HttpClient for API calls
|
||||
self.httpClient = httpx.AsyncClient(
|
||||
timeout=120.0, # Longer timeout for complex requests
|
||||
headers={
|
||||
"x-api-key": self.apiKey,
|
||||
"anthropic-version": "2023-06-01", # Anthropic API Version
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Anthropic Connector initialized")
|
||||
|
||||
def getConnectorType(self) -> str:
|
||||
"""Get the connector type identifier."""
|
||||
return "anthropic"
|
||||
|
||||
def getModels(self) -> List[AiModel]:
|
||||
"""Get all available Anthropic models."""
|
||||
return [
|
||||
AiModel(
|
||||
name="claude-sonnet-4-5-20250929",
|
||||
displayName="Anthropic Claude Sonnet 4.5",
|
||||
connectorType="anthropic",
|
||||
apiUrl="https://api.anthropic.com/v1/messages",
|
||||
temperature=0.2,
|
||||
maxTokens=8192,
|
||||
contextLength=200000,
|
||||
costPer1kTokensInput=0.015,
|
||||
costPer1kTokensOutput=0.075,
|
||||
speedRating=6, # Slower due to high-quality processing
|
||||
qualityRating=10, # Best quality available
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.PLAN, 9),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 10),
|
||||
(OperationTypeEnum.DATA_GENERATE, 9),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 8)
|
||||
),
|
||||
version="claude-sonnet-4-5-20250929",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075
|
||||
),
|
||||
AiModel(
|
||||
name="claude-sonnet-4-5-20250929",
|
||||
displayName="Anthropic Claude Sonnet 4.5 Vision",
|
||||
connectorType="anthropic",
|
||||
apiUrl="https://api.anthropic.com/v1/messages",
|
||||
temperature=0.2,
|
||||
maxTokens=8192,
|
||||
contextLength=200000,
|
||||
costPer1kTokensInput=0.015,
|
||||
costPer1kTokensOutput=0.075,
|
||||
speedRating=6,
|
||||
qualityRating=10,
|
||||
functionCall=self.callAiImage,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.IMAGE_ANALYSE, 10)
|
||||
),
|
||||
version="claude-sonnet-4-5-20250929",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.015 + (bytesReceived / 4 / 1000) * 0.075
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
async def callAiBasic(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Calls the Anthropic API with the given messages using standardized pattern.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages and options
|
||||
|
||||
Returns:
|
||||
AiModelResponse with content and metadata
|
||||
|
||||
Raises:
|
||||
HTTPException: For errors in API communication
|
||||
"""
|
||||
try:
|
||||
# Extract parameters from modelCall
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
options = modelCall.options
|
||||
temperature = getattr(options, "temperature", None)
|
||||
if temperature is None:
|
||||
temperature = model.temperature
|
||||
maxTokens = model.maxTokens
|
||||
|
||||
# Transform OpenAI-style messages to Anthropic format:
|
||||
# - Move any 'system' role content to top-level 'system'
|
||||
# - Keep only 'user'/'assistant' messages in the list
|
||||
system_contents: List[str] = []
|
||||
converted_messages: List[Dict[str, Any]] = []
|
||||
for m in messages:
|
||||
role = m.get("role")
|
||||
content = m.get("content", "")
|
||||
if role == "system":
|
||||
# Collect system content; Anthropic expects top-level 'system'
|
||||
if isinstance(content, list):
|
||||
# Join text parts if provided as blocks
|
||||
joined = "\n\n".join(
|
||||
[
|
||||
(part.get("text") if isinstance(part, dict) else str(part))
|
||||
for part in content
|
||||
]
|
||||
)
|
||||
system_contents.append(joined)
|
||||
else:
|
||||
system_contents.append(str(content))
|
||||
continue
|
||||
# For Anthropic, content can be a string; pass through strings, collapse blocks
|
||||
if isinstance(content, list):
|
||||
# Collapse to text if blocks are provided
|
||||
collapsed = "\n\n".join(
|
||||
[
|
||||
(part.get("text") if isinstance(part, dict) else str(part))
|
||||
for part in content
|
||||
]
|
||||
)
|
||||
converted_messages.append({"role": role, "content": collapsed})
|
||||
else:
|
||||
converted_messages.append({"role": role, "content": content})
|
||||
|
||||
system_prompt = "\n\n".join([s for s in system_contents if s]) if system_contents else None
|
||||
|
||||
# Create Anthropic API payload
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model.name,
|
||||
"messages": converted_messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
# Anthropic requires max_tokens - use provided value or throw error
|
||||
if maxTokens is None:
|
||||
raise ValueError("maxTokens must be provided for Anthropic API calls")
|
||||
payload["max_tokens"] = maxTokens
|
||||
if system_prompt:
|
||||
payload["system"] = system_prompt
|
||||
|
||||
response = await self.httpClient.post(
|
||||
model.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_detail = f"Anthropic API error: {response.status_code} - {response.text}"
|
||||
logger.error(error_detail)
|
||||
|
||||
# Provide more specific error messages based on status code
|
||||
if response.status_code == 529:
|
||||
error_message = "Anthropic API is currently overloaded. Please try again in a few minutes."
|
||||
elif response.status_code == 429:
|
||||
error_message = "Rate limit exceeded. Please wait before making another request."
|
||||
elif response.status_code == 401:
|
||||
error_message = "Invalid API key. Please check your Anthropic API configuration."
|
||||
elif response.status_code == 400:
|
||||
error_message = f"Invalid request to Anthropic API: {response.text}"
|
||||
else:
|
||||
error_message = f"Anthropic API error ({response.status_code}): {response.text}"
|
||||
|
||||
raise HTTPException(status_code=500, detail=error_message)
|
||||
|
||||
# Parse response
|
||||
anthropicResponse = response.json()
|
||||
|
||||
# Extract content from response
|
||||
content = ""
|
||||
if "content" in anthropicResponse:
|
||||
if isinstance(anthropicResponse["content"], list):
|
||||
# Content is a list of parts (in newer API versions)
|
||||
for part in anthropicResponse["content"]:
|
||||
if part.get("type") == "text":
|
||||
content += part.get("text", "")
|
||||
else:
|
||||
# Direct content as string (in older API versions)
|
||||
content = anthropicResponse["content"]
|
||||
|
||||
# Debug logging for empty responses
|
||||
if not content or content.strip() == "":
|
||||
logger.warning(f"Anthropic API returned empty content. Full response: {anthropicResponse}")
|
||||
content = "[Anthropic API returned empty response]"
|
||||
|
||||
# Return standardized response
|
||||
return AiModelResponse(
|
||||
content=content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
metadata={"response_id": anthropicResponse.get("id", "")}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e) if str(e) else f"{type(e).__name__}"
|
||||
error_detail = f"Error calling Anthropic API: {error_msg}"
|
||||
if hasattr(e, 'detail') and e.detail:
|
||||
error_detail += f" | Detail: {e.detail}"
|
||||
if hasattr(e, 'status_code'):
|
||||
error_detail += f" | Status: {e.status_code}"
|
||||
logger.error(error_detail, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
|
||||
async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Analyzes an image using Anthropic's vision capabilities using standardized pattern.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages and image data in options
|
||||
|
||||
Returns:
|
||||
AiModelResponse with analysis content
|
||||
"""
|
||||
try:
|
||||
# Extract parameters from messages for Anthropic Vision API
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
|
||||
# Verify messages contain image data
|
||||
if not messages or not messages[0].get("content"):
|
||||
raise ValueError("No messages provided for image analysis")
|
||||
|
||||
logger.info(f"callAiImage called with {len(messages)} message(s)...")
|
||||
|
||||
# Extract text prompt and image data from messages
|
||||
# Messages format: [{"role": "user", "content": [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "data:..."}}]}]
|
||||
userContent = messages[0]["content"]
|
||||
if not isinstance(userContent, list):
|
||||
raise ValueError("Expected content to be a list for vision")
|
||||
|
||||
textPrompt = ""
|
||||
imageUrl = None
|
||||
|
||||
for contentItem in userContent:
|
||||
if contentItem.get("type") == "text":
|
||||
textPrompt = contentItem.get("text", "") or ""
|
||||
elif contentItem.get("type") == "image_url":
|
||||
imageUrlDict = contentItem.get("image_url")
|
||||
if imageUrlDict and isinstance(imageUrlDict, dict):
|
||||
imageUrl = imageUrlDict.get("url", "") or ""
|
||||
else:
|
||||
imageUrl = None
|
||||
|
||||
if not imageUrl or not imageUrl.startswith("data:"):
|
||||
raise ValueError("No image data found in messages")
|
||||
|
||||
# Extract base64 data and mime type from data URL
|
||||
# Format: data:image/jpeg;base64,/9j/4AAQSkZ...
|
||||
parts = imageUrl.split(";base64,")
|
||||
if len(parts) != 2:
|
||||
raise ValueError("Invalid image data URL format")
|
||||
|
||||
mimeType = parts[0].replace("data:", "")
|
||||
base64Data = parts[1]
|
||||
|
||||
# Convert to Anthropic's vision format
|
||||
anthropicMessages = [{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": textPrompt},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mimeType,
|
||||
"data": base64Data
|
||||
}
|
||||
}
|
||||
]
|
||||
}]
|
||||
|
||||
# Call Anthropic API directly for vision
|
||||
import time
|
||||
import base64
|
||||
|
||||
startTime = time.time()
|
||||
|
||||
# Prepare system prompt if available
|
||||
systemPrompt = None
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
systemContent = msg.get("content")
|
||||
if isinstance(systemContent, list):
|
||||
textParts = []
|
||||
for item in systemContent:
|
||||
if item.get("type") == "text":
|
||||
textValue = item.get("text")
|
||||
if textValue is not None:
|
||||
textParts.append(str(textValue))
|
||||
if textParts:
|
||||
systemPrompt = "\n".join(textParts)
|
||||
elif systemContent is not None:
|
||||
systemPrompt = str(systemContent)
|
||||
break
|
||||
|
||||
# Get parameters from model (consistent with callAiBasic)
|
||||
maxTokens = model.maxTokens if hasattr(model, 'maxTokens') else 8192
|
||||
temperature = model.temperature if hasattr(model, 'temperature') else 0.2
|
||||
|
||||
# Prepare API payload
|
||||
payload = {
|
||||
"model": model.name, # Use standard model.name
|
||||
"max_tokens": maxTokens,
|
||||
"messages": anthropicMessages
|
||||
}
|
||||
|
||||
if systemPrompt:
|
||||
payload["system"] = systemPrompt
|
||||
|
||||
# Set temperature from model
|
||||
payload["temperature"] = temperature
|
||||
|
||||
# Make API call with headers from httpClient (which includes anthropic-version)
|
||||
response = await self.httpClient.post(
|
||||
"https://api.anthropic.com/v1/messages",
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
errorText = response.text
|
||||
logger.error(f"Anthropic API error: {response.status_code} - {errorText}")
|
||||
raise HTTPException(status_code=response.status_code, detail=f"Anthropic API error: {errorText}")
|
||||
|
||||
# Parse response
|
||||
result = response.json()
|
||||
content = result["content"][0]["text"] if result.get("content") else ""
|
||||
|
||||
endTime = time.time()
|
||||
processingTime = endTime - startTime
|
||||
|
||||
# Calculate cost
|
||||
inputTokens = result.get("usage", {}).get("input_tokens", 0)
|
||||
outputTokens = result.get("usage", {}).get("output_tokens", 0)
|
||||
|
||||
# Return standardized response
|
||||
return AiModelResponse(
|
||||
content=content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
processingTime=processingTime
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during image analysis: {str(e)}", exc_info=True)
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error=f"Error during image analysis: {str(e)}"
|
||||
)
|
||||
117
modules/aicore/aicorePluginInternal.py
Normal file
117
modules/aicore/aicorePluginInternal.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
import logging
|
||||
from typing import List
|
||||
from modules.aicore.aicoreBase import BaseConnectorAi
|
||||
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AiInternal(BaseConnectorAi):
|
||||
"""Internal connector for document processing, generation, and rendering."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
logger.info("Internal Connector initialized")
|
||||
|
||||
def getConnectorType(self) -> str:
|
||||
"""Get the connector type identifier."""
|
||||
return "internal"
|
||||
|
||||
def getModels(self) -> List[AiModel]:
|
||||
"""Get all available internal models."""
|
||||
return [
|
||||
AiModel(
|
||||
name="internal-extractor",
|
||||
displayName="Internal Document Extractor",
|
||||
connectorType="internal",
|
||||
apiUrl="internal://extract",
|
||||
temperature=0.0, # Not applicable for extraction
|
||||
maxTokens=0, # Not token-based
|
||||
contextLength=0,
|
||||
costPer1kTokensInput=0.0,
|
||||
costPer1kTokensOutput=0.0,
|
||||
speedRating=9, # Very fast for internal operations
|
||||
qualityRating=8, # Good quality
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.extractDocument,
|
||||
priority=PriorityEnum.COST,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(),
|
||||
version="internal-extractor-v1",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.001 + (bytesSent + bytesReceived) / (1024 * 1024) * 0.01
|
||||
),
|
||||
AiModel(
|
||||
name="internal-generator",
|
||||
displayName="Internal Document Generator",
|
||||
connectorType="internal",
|
||||
apiUrl="internal://generate",
|
||||
temperature=0.0, # Not applicable for generation
|
||||
maxTokens=0, # Not token-based
|
||||
contextLength=0,
|
||||
costPer1kTokensInput=0.0,
|
||||
costPer1kTokensOutput=0.0,
|
||||
speedRating=8, # Fast for generation
|
||||
qualityRating=8, # Good quality
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.generateDocument,
|
||||
priority=PriorityEnum.COST,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(),
|
||||
version="internal-generator-v1",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.002 + (bytesReceived / (1024 * 1024)) * 0.005
|
||||
),
|
||||
AiModel(
|
||||
name="internal-renderer",
|
||||
displayName="Internal Document Renderer",
|
||||
connectorType="internal",
|
||||
apiUrl="internal://render",
|
||||
temperature=0.0, # Not applicable for rendering
|
||||
maxTokens=0, # Not token-based
|
||||
contextLength=0,
|
||||
costPer1kTokensInput=0.0,
|
||||
costPer1kTokensOutput=0.0,
|
||||
speedRating=7, # Good for rendering
|
||||
qualityRating=9, # High quality rendering
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.renderDocument,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(),
|
||||
version="internal-renderer-v1",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.003 + (bytesReceived / (1024 * 1024)) * 0.008
|
||||
)
|
||||
]
|
||||
|
||||
async def extractDocument(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
NOP - we only need the model for price calculations
|
||||
"""
|
||||
logger.error(f"Document extraction not to call here")
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error="Internal connector should not be called directly"
|
||||
)
|
||||
|
||||
async def generateDocument(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
NOP - we only need the model for price calculations
|
||||
"""
|
||||
logger.error(f"Document generation not to call here")
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error="Internal connector should not be called directly"
|
||||
)
|
||||
|
||||
async def renderDocument(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
NOP - we only need the model for price calculations
|
||||
"""
|
||||
logger.error(f"Document rendering not to call here")
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error="Internal connector should not be called directly"
|
||||
)
|
||||
|
||||
388
modules/aicore/aicorePluginOpenai.py
Normal file
388
modules/aicore/aicorePluginOpenai.py
Normal file
|
|
@ -0,0 +1,388 @@
|
|||
import logging
|
||||
import httpx
|
||||
from typing import List
|
||||
from fastapi import HTTPException
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.aicore.aicoreBase import BaseConnectorAi
|
||||
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ContextLengthExceededException(Exception):
|
||||
"""Exception raised when the context length exceeds the model's limit"""
|
||||
pass
|
||||
|
||||
def loadConfigData():
|
||||
"""Load configuration data for OpenAI connector"""
|
||||
return {
|
||||
"apiKey": APP_CONFIG.get('Connector_AiOpenai_API_SECRET'),
|
||||
}
|
||||
|
||||
class AiOpenai(BaseConnectorAi):
|
||||
"""Connector for communication with the OpenAI API."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Load configuration
|
||||
self.config = loadConfigData()
|
||||
self.apiKey = self.config["apiKey"]
|
||||
|
||||
# HttpClient for API calls
|
||||
self.httpClient = httpx.AsyncClient(
|
||||
timeout=120.0, # Longer timeout for complex requests
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.apiKey}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
logger.info("OpenAI Connector initialized")
|
||||
|
||||
def getConnectorType(self) -> str:
|
||||
"""Get the connector type identifier."""
|
||||
return "openai"
|
||||
|
||||
def getModels(self) -> List[AiModel]:
|
||||
"""Get all available OpenAI models."""
|
||||
return [
|
||||
AiModel(
|
||||
name="gpt-4o",
|
||||
displayName="OpenAI GPT-4o",
|
||||
connectorType="openai",
|
||||
apiUrl="https://api.openai.com/v1/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=16384,
|
||||
contextLength=128000,
|
||||
costPer1kTokensInput=0.03,
|
||||
costPer1kTokensOutput=0.06,
|
||||
speedRating=7, # Good speed for complex tasks
|
||||
qualityRating=9, # High quality
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.PLAN, 8),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 9),
|
||||
(OperationTypeEnum.DATA_GENERATE, 9),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 7)
|
||||
),
|
||||
version="gpt-4o",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06
|
||||
),
|
||||
AiModel(
|
||||
name="gpt-3.5-turbo",
|
||||
displayName="OpenAI GPT-3.5 Turbo",
|
||||
connectorType="openai",
|
||||
apiUrl="https://api.openai.com/v1/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=4096,
|
||||
contextLength=16000,
|
||||
costPer1kTokensInput=0.0015,
|
||||
costPer1kTokensOutput=0.002,
|
||||
speedRating=9, # Very fast
|
||||
qualityRating=7, # Good but not premium
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.callAiBasic,
|
||||
priority=PriorityEnum.SPEED,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.PLAN, 7),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 8),
|
||||
(OperationTypeEnum.DATA_GENERATE, 8)
|
||||
# Note: GPT-3.5-turbo does NOT support vision/image operations
|
||||
),
|
||||
version="gpt-3.5-turbo",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.0015 + (bytesReceived / 4 / 1000) * 0.002
|
||||
),
|
||||
AiModel(
|
||||
name="gpt-4o",
|
||||
displayName="OpenAI GPT-4o Instance Vision",
|
||||
connectorType="openai",
|
||||
apiUrl="https://api.openai.com/v1/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=16384,
|
||||
contextLength=128000,
|
||||
costPer1kTokensInput=0.03,
|
||||
costPer1kTokensOutput=0.06,
|
||||
speedRating=6, # Slower for vision tasks
|
||||
qualityRating=9, # High quality vision
|
||||
functionCall=self.callAiImage,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.IMAGE_ANALYSE, 9)
|
||||
),
|
||||
version="gpt-4o",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.03 + (bytesReceived / 4 / 1000) * 0.06
|
||||
),
|
||||
AiModel(
|
||||
name="dall-e-3",
|
||||
displayName="OpenAI DALL-E 3",
|
||||
connectorType="openai",
|
||||
apiUrl="https://api.openai.com/v1/images/generations",
|
||||
temperature=0.0, # Image generation doesn't use temperature
|
||||
maxTokens=0, # Image generation doesn't use tokens
|
||||
contextLength=0,
|
||||
costPer1kTokensInput=0.04,
|
||||
costPer1kTokensOutput=0.0,
|
||||
speedRating=5, # Slow for image generation
|
||||
qualityRating=9, # High quality art generation
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self.generateImage,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.IMAGE_GENERATE, 10)
|
||||
),
|
||||
version="dall-e-3",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04
|
||||
)
|
||||
]
|
||||
|
||||
async def callAiBasic(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Calls the OpenAI API with the given messages using standardized pattern.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages and options
|
||||
|
||||
Returns:
|
||||
AiModelResponse with content and metadata
|
||||
|
||||
Raises:
|
||||
HTTPException: For errors in API communication
|
||||
"""
|
||||
try:
|
||||
# Extract parameters from modelCall
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
options = modelCall.options
|
||||
temperature = getattr(options, "temperature", None)
|
||||
if temperature is None:
|
||||
temperature = model.temperature
|
||||
maxTokens = model.maxTokens
|
||||
|
||||
payload = {
|
||||
"model": model.name,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": maxTokens
|
||||
}
|
||||
|
||||
response = await self.httpClient.post(
|
||||
model.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_message = f"OpenAI API error: {response.status_code} - {response.text}"
|
||||
logger.error(error_message)
|
||||
|
||||
# Check for context length exceeded error
|
||||
if response.status_code == 400:
|
||||
try:
|
||||
error_data = response.json()
|
||||
if (error_data.get("error", {}).get("code") == "context_length_exceeded" or
|
||||
"context length" in error_data.get("error", {}).get("message", "").lower()):
|
||||
# Raise a specific exception for context length issues
|
||||
raise ContextLengthExceededException(
|
||||
f"Context length exceeded: {error_data.get('error', {}).get('message', 'Unknown error')}"
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
pass # If we can't parse the error, fall through to generic error
|
||||
|
||||
# Include the actual error details in the exception
|
||||
raise HTTPException(status_code=500, detail=error_message)
|
||||
|
||||
responseJson = response.json()
|
||||
content = responseJson["choices"][0]["message"]["content"]
|
||||
|
||||
return AiModelResponse(
|
||||
content=content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
metadata={"response_id": responseJson.get("id", "")}
|
||||
)
|
||||
|
||||
except ContextLengthExceededException:
|
||||
# Re-raise context length exceptions without wrapping
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling OpenAI API: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}")
|
||||
|
||||
async def callAiImage(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Analyzes an image with the OpenAI Vision API using standardized pattern.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages and image data in options
|
||||
|
||||
Returns:
|
||||
AiModelResponse with analysis content
|
||||
"""
|
||||
try:
|
||||
# Extract parameters from modelCall
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
|
||||
# Messages should already be in the correct format with image data embedded
|
||||
# Just verify they contain image data
|
||||
if not messages or not messages[0].get("content"):
|
||||
raise ValueError("No messages provided for image analysis")
|
||||
|
||||
logger.debug(f"Starting image analysis with {len(messages)} message(s)...")
|
||||
|
||||
# Use the messages directly - they should already contain the image data
|
||||
# in the format: {"type": "image_url", "image_url": {"url": "data:...base64,..."}}
|
||||
|
||||
# Use parameters from model
|
||||
temperature = model.temperature
|
||||
# Don't set maxTokens - let the model use its full context length
|
||||
|
||||
payload = {
|
||||
"model": model.name,
|
||||
"messages": messages,
|
||||
"temperature": temperature
|
||||
}
|
||||
|
||||
response = await self.httpClient.post(
|
||||
model.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=500, detail="Error communicating with OpenAI API")
|
||||
|
||||
responseJson = response.json()
|
||||
content = responseJson["choices"][0]["message"]["content"]
|
||||
|
||||
return AiModelResponse(
|
||||
content=content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
metadata={"response_id": responseJson.get("id", "")}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during image analysis: {str(e)}", exc_info=True)
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error=f"Error during image analysis: {str(e)}"
|
||||
)
|
||||
|
||||
async def generateImage(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Generate an image using DALL-E 3 using standardized pattern.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages and generation options
|
||||
|
||||
Returns:
|
||||
AiModelResponse with generated image data
|
||||
"""
|
||||
try:
|
||||
# Extract parameters from modelCall
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
options = modelCall.options
|
||||
|
||||
# Get prompt from messages
|
||||
promptContent = messages[0]["content"] if messages else ""
|
||||
|
||||
# Parse prompt using AiCallPromptImage model
|
||||
from modules.datamodels.datamodelAi import AiCallPromptImage
|
||||
import json
|
||||
|
||||
try:
|
||||
# Try to parse as JSON
|
||||
promptData = json.loads(promptContent)
|
||||
promptModel = AiCallPromptImage(**promptData)
|
||||
except:
|
||||
# If not JSON, use plain text prompt
|
||||
promptModel = AiCallPromptImage(
|
||||
prompt=promptContent,
|
||||
size=options.size if options and hasattr(options, 'size') else "1024x1024",
|
||||
quality=options.quality if options and hasattr(options, 'quality') else "standard",
|
||||
style=options.style if options and hasattr(options, 'style') else "vivid"
|
||||
)
|
||||
|
||||
# Extract parameters from Pydantic model
|
||||
prompt = promptModel.prompt
|
||||
size = promptModel.size or "1024x1024"
|
||||
quality = promptModel.quality or "standard"
|
||||
style = promptModel.style or "vivid"
|
||||
|
||||
logger.debug(f"Starting image generation with prompt: '{prompt[:100]}...'")
|
||||
|
||||
# DALL-E 3 API endpoint
|
||||
dalle_url = "https://api.openai.com/v1/images/generations"
|
||||
|
||||
payload = {
|
||||
"model": "dall-e-3",
|
||||
"prompt": prompt,
|
||||
"size": size,
|
||||
"quality": quality,
|
||||
"style": style,
|
||||
"n": 1,
|
||||
"response_format": "b64_json" # Get base64 data directly instead of URLs
|
||||
}
|
||||
|
||||
# Create a separate client for DALL-E API calls
|
||||
dalle_client = httpx.AsyncClient(
|
||||
timeout=120.0,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.apiKey}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
response = await dalle_client.post(
|
||||
dalle_url,
|
||||
json=payload
|
||||
)
|
||||
|
||||
await dalle_client.aclose()
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"DALL-E API error: {response.status_code} - {response.text}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"DALL-E API error: {response.status_code} - {response.text}"
|
||||
}
|
||||
|
||||
responseJson = response.json()
|
||||
|
||||
if "data" in responseJson and len(responseJson["data"]) > 0:
|
||||
image_data = responseJson["data"][0]["b64_json"]
|
||||
|
||||
logger.info(f"Successfully generated image: {len(image_data)} characters")
|
||||
return AiModelResponse(
|
||||
content=image_data,
|
||||
success=True,
|
||||
modelId="dall-e-3",
|
||||
metadata={
|
||||
"size": size,
|
||||
"quality": quality,
|
||||
"style": style,
|
||||
"response_id": responseJson.get("id", "")
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.error("No image data in DALL-E response")
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error="No image data in DALL-E response"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during image generation: {str(e)}", exc_info=True)
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error=f"Error during image generation: {str(e)}"
|
||||
)
|
||||
471
modules/aicore/aicorePluginPerplexity.py
Normal file
471
modules/aicore/aicorePluginPerplexity.py
Normal file
|
|
@ -0,0 +1,471 @@
|
|||
import logging
|
||||
import httpx
|
||||
from typing import List
|
||||
from fastapi import HTTPException
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.aicore.aicoreBase import BaseConnectorAi
|
||||
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings, AiCallPromptWebSearch, AiCallPromptWebCrawl
|
||||
from modules.datamodels.datamodelTools import CountryCodes
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def loadConfigData():
|
||||
"""Load configuration data for Perplexity connector"""
|
||||
return {
|
||||
"apiKey": APP_CONFIG.get('Connector_AiPerplexity_API_SECRET'),
|
||||
}
|
||||
|
||||
class AiPerplexity(BaseConnectorAi):
|
||||
"""Connector for communication with the Perplexity API."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Load configuration
|
||||
self.config = loadConfigData()
|
||||
self.apiKey = self.config["apiKey"]
|
||||
|
||||
# HttpClient for API calls
|
||||
self.httpClient = httpx.AsyncClient(
|
||||
timeout=120.0, # Longer timeout for complex requests
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.apiKey}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Perplexity Connector initialized")
|
||||
|
||||
def getConnectorType(self) -> str:
|
||||
"""Get the connector type identifier."""
|
||||
return "perplexity"
|
||||
|
||||
def _convertIsoCodeToCountryName(self, isoCode: str) -> str:
|
||||
"""
|
||||
Convert ISO-2 country code to Perplexity country name.
|
||||
Uses centralized CountryCodes mapping.
|
||||
"""
|
||||
return CountryCodes.getForPerplexity(isoCode)
|
||||
|
||||
def getModels(self) -> List[AiModel]:
|
||||
"""Get all available Perplexity models."""
|
||||
return [
|
||||
AiModel(
|
||||
name="sonar",
|
||||
displayName="Perplexity Sonar",
|
||||
connectorType="perplexity",
|
||||
apiUrl="https://api.perplexity.ai/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=24000, # Increased for detailed web crawl responses (Perplexity supports up to 25k)
|
||||
contextLength=32000,
|
||||
costPer1kTokensInput=0.005,
|
||||
costPer1kTokensOutput=0.005,
|
||||
speedRating=8,
|
||||
qualityRating=8,
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self._routeWebOperation,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.WEB_SEARCH, 9),
|
||||
(OperationTypeEnum.WEB_CRAWL, 7)
|
||||
),
|
||||
version="sonar",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.005
|
||||
),
|
||||
AiModel(
|
||||
name="sonar-pro",
|
||||
displayName="Perplexity Sonar Pro",
|
||||
connectorType="perplexity",
|
||||
apiUrl="https://api.perplexity.ai/chat/completions",
|
||||
temperature=0.2,
|
||||
maxTokens=24000, # Increased for detailed web crawl responses (Perplexity supports up to 25k)
|
||||
contextLength=32000,
|
||||
costPer1kTokensInput=0.01,
|
||||
costPer1kTokensOutput=0.01,
|
||||
speedRating=6, # Slower due to AI analysis
|
||||
qualityRating=9, # Best AI analysis quality
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self._routeWebOperation,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.WEB_SEARCH, 9),
|
||||
(OperationTypeEnum.WEB_CRAWL, 8)
|
||||
),
|
||||
version="sonar-pro",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.01 + (bytesReceived / 4 / 1000) * 0.01
|
||||
)
|
||||
]
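
The calculatePriceUsd lambdas above approximate token counts as bytes divided by four. As a quick sanity check of the arithmetic (the byte counts below are made up for the example):

def estimateSonarCostUsd(bytesSent: int, bytesReceived: int) -> float:
    # Same formula as the "sonar" lambda above: ~4 bytes per token, 0.005 USD per 1k tokens
    return (bytesSent / 4 / 1000) * 0.005 + (bytesReceived / 4 / 1000) * 0.005

# 8 kB sent (~2000 tokens) and 16 kB received (~4000 tokens)
# -> 2 * 0.005 + 4 * 0.005 = 0.01 + 0.02 = 0.03 USD
assert abs(estimateSonarCostUsd(8000, 16000) - 0.03) < 1e-9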
|
||||
|
||||
async def callAiBasic(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Calls the Perplexity API with the given messages using standardized pattern.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages and options
|
||||
|
||||
Returns:
|
||||
AiModelResponse with content and metadata
|
||||
|
||||
Raises:
|
||||
HTTPException: For errors in API communication
|
||||
"""
|
||||
try:
|
||||
# Extract parameters from modelCall
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
options = modelCall.options
|
||||
temperature = getattr(options, "temperature", None)
|
||||
if temperature is None:
|
||||
temperature = model.temperature
|
||||
maxTokens = model.maxTokens
|
||||
|
||||
payload = {
|
||||
"model": model.name,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": maxTokens
|
||||
}
|
||||
|
||||
response = await self.httpClient.post(
|
||||
model.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
errorDetail = f"Perplexity API error: {response.status_code} - {response.text}"
|
||||
logger.error(errorDetail)
|
||||
|
||||
# Provide more specific error messages based on status code
|
||||
if response.status_code == 429:
|
||||
errorMessage = "Rate limit exceeded. Please wait before making another request."
|
||||
elif response.status_code == 401:
|
||||
errorMessage = "Invalid API key. Please check your Perplexity API configuration."
|
||||
elif response.status_code == 400:
|
||||
errorMessage = f"Invalid request to Perplexity API: {response.text}"
|
||||
else:
|
||||
errorMessage = f"Perplexity API error ({response.status_code}): {response.text}"
|
||||
|
||||
raise HTTPException(status_code=500, detail=errorMessage)
|
||||
|
||||
apiResponse = response.json()
|
||||
content = apiResponse["choices"][0]["message"]["content"]
|
||||
|
||||
return AiModelResponse(
|
||||
content=content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
metadata={"response_id": apiResponse.get("id", "")}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Perplexity API: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error calling Perplexity API: {str(e)}")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def _testConnection(self) -> bool:
|
||||
"""
|
||||
Tests the connection to the Perplexity API.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Try a simple test message
|
||||
testMessages = [
|
||||
{"role": "user", "content": "Hello, please respond with just 'OK' to confirm the connection works."}
|
||||
]
|
||||
|
||||
# Create a model call for testing
|
||||
from modules.datamodels.datamodelAi import AiCallOptions
|
||||
model = self.getModels()[0] # Get first model for testing
|
||||
testCall = AiModelCall(
|
||||
messages=testMessages,
|
||||
model=model,
|
||||
options=AiCallOptions()
|
||||
)
|
||||
|
||||
response = await self.callAiBasic(testCall)
|
||||
return response.success and len(response.content.strip()) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Perplexity connection test failed: {str(e)}")
|
||||
return False
|
||||
|
||||
async def _routeWebOperation(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
Route web operation based on operation type.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages and options
|
||||
|
||||
Returns:
|
||||
AiModelResponse based on operation type
|
||||
"""
|
||||
operationType = modelCall.options.operationType
|
||||
|
||||
if operationType == OperationTypeEnum.WEB_SEARCH:
|
||||
return await self.webSearch(modelCall)
|
||||
elif operationType == OperationTypeEnum.WEB_CRAWL:
|
||||
return await self.webCrawl(modelCall)
|
||||
else:
|
||||
# Fallback to basic call
|
||||
return await self.callAiBasic(modelCall)
|
||||
|
||||
def _getDepthInstructions(self, maxDepth: int) -> str:
|
||||
"""
|
||||
Map maxDepth (numeric) to instructional text for LLM.
|
||||
|
||||
Args:
|
||||
maxDepth: 1 (fast/overview), 2 (general/standard), 3 (deep/comprehensive)
|
||||
|
||||
Returns:
|
||||
Instructional text for the LLM
|
||||
"""
|
||||
depthMap = {
|
||||
1: "Basic overview - extract main content from the main page only",
|
||||
2: "Standard crawl - extract content from main page and linked pages (2 levels deep)",
|
||||
3: "Deep crawl - comprehensively extract content from main page and all accessible linked pages (3+ levels deep)"
|
||||
}
|
||||
return depthMap.get(maxDepth, depthMap[2])
|
||||
|
||||
def _getWidthInstructions(self, maxWidth: int) -> str:
|
||||
"""
|
||||
Map maxWidth (numeric) to instructional text for LLM.
|
||||
|
||||
Args:
|
||||
maxWidth: Number of pages to crawl at each level (default: 10)
|
||||
|
||||
Returns:
|
||||
Instructional text for the LLM
|
||||
"""
|
||||
if maxWidth <= 5:
|
||||
return f"Focused crawl - limit to {maxWidth} most relevant pages per level"
|
||||
elif maxWidth <= 15:
|
||||
return f"Standard breadth - crawl up to {maxWidth} pages per level"
|
||||
elif maxWidth <= 30:
|
||||
return f"Wide crawl - crawl up to {maxWidth} pages per level, prioritize quality"
|
||||
else:
|
||||
return f"Extensive crawl - crawl up to {maxWidth} pages per level, comprehensive coverage"
|
||||
|
||||
async def webSearch(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
WEB_SEARCH operation - returns list of URLs based on search query.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with AiCallPromptWebSearch as prompt
|
||||
|
||||
Returns:
|
||||
AiModelResponse with JSON list of URLs
|
||||
"""
|
||||
try:
|
||||
# Extract parameters
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
options = modelCall.options
|
||||
temperature = getattr(options, "temperature", None) or model.temperature
|
||||
maxTokens = model.maxTokens
|
||||
|
||||
# Parse prompt JSON - find user message (not system message)
|
||||
promptContent = ""
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
promptContent = msg.get("content", "")
|
||||
break
|
||||
# Fallback to first message if no user message found
|
||||
if not promptContent and len(messages) > 0:
|
||||
promptContent = messages[0].get("content", "")
|
||||
|
||||
import json
|
||||
promptData = json.loads(promptContent)
|
||||
|
||||
# Create Pydantic model
|
||||
webSearchPrompt = AiCallPromptWebSearch(**promptData)
|
||||
|
||||
# Convert ISO country code to country name
|
||||
countryName = webSearchPrompt.country
|
||||
if countryName:
|
||||
countryName = self._convertIsoCodeToCountryName(countryName)
|
||||
|
||||
# Build search request for Perplexity
|
||||
searchPrompt = f"""Search the web for: {webSearchPrompt.instruction}
|
||||
|
||||
Return a JSON array of {webSearchPrompt.maxNumberPages} most relevant URLs.
|
||||
{'' if not countryName else f'Focus on results from {countryName}.'}
|
||||
|
||||
Return ONLY a JSON array of URLs, no additional text:
|
||||
[
|
||||
"https://example1.com/page",
|
||||
"https://example2.com/article",
|
||||
"https://example3.com/resource"
|
||||
]"""
|
||||
|
||||
payload = {
|
||||
"model": model.name,
|
||||
"messages": [{"role": "user", "content": searchPrompt}],
|
||||
"temperature": temperature,
|
||||
"max_tokens": maxTokens
|
||||
}
|
||||
|
||||
response = await self.httpClient.post(model.apiUrl, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail=f"Perplexity Web Search API error: {response.text}")
|
||||
|
||||
# Check if response body is empty or invalid
|
||||
responseText = response.text
|
||||
if not responseText or not responseText.strip():
|
||||
raise HTTPException(status_code=500, detail="Perplexity Web Search API returned empty response")
|
||||
|
||||
try:
|
||||
apiResponse = response.json()
|
||||
except Exception as jsonError:
|
||||
logger.error(f"Failed to parse Perplexity response as JSON. Status: {response.status_code}, Response: {responseText[:500]}")
|
||||
raise HTTPException(status_code=500, detail=f"Perplexity Web Search API returned invalid JSON: {str(jsonError)}")
|
||||
|
||||
if "choices" not in apiResponse or not apiResponse["choices"]:
|
||||
raise HTTPException(status_code=500, detail="Perplexity Web Search API response missing 'choices' field")
|
||||
|
||||
content = apiResponse["choices"][0]["message"]["content"]
|
||||
|
||||
return AiModelResponse(
|
||||
content=content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
metadata={"response_id": apiResponse.get("id", ""), "operation": "WEB_SEARCH"}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Perplexity web search: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error in Perplexity web search: {str(e)}")
|
||||
|
||||
async def webCrawl(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||
"""
|
||||
WEB_CRAWL operation - crawls ONE URL and returns content.
|
||||
|
||||
Perplexity API Parameters Used:
|
||||
- messages: The prompt containing URL and instruction
|
||||
- max_tokens: Maximum response length
|
||||
- max_results: Number of search results (1-20, default: 10)
|
||||
- temperature: Response randomness (not web search specific)
|
||||
|
||||
Pagination: Perplexity does NOT return paginated responses.
|
||||
A single response contains all results within max_tokens limit.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with AiCallPromptWebCrawl as prompt
|
||||
|
||||
Returns:
|
||||
AiModelResponse with crawl results as JSON object
|
||||
"""
|
||||
try:
|
||||
# Extract parameters
|
||||
messages = modelCall.messages
|
||||
model = modelCall.model
|
||||
options = modelCall.options
|
||||
temperature = getattr(options, "temperature", None) or model.temperature
|
||||
maxTokens = model.maxTokens
|
||||
|
||||
# Parse prompt JSON - find user message (not system message)
|
||||
promptContent = ""
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
promptContent = msg.get("content", "")
|
||||
break
|
||||
# Fallback to first message if no user message found
|
||||
if not promptContent and len(messages) > 0:
|
||||
promptContent = messages[0].get("content", "")
|
||||
|
||||
import json
|
||||
promptData = json.loads(promptContent)
|
||||
|
||||
# Create Pydantic model
|
||||
webCrawlPrompt = AiCallPromptWebCrawl(**promptData)
|
||||
|
||||
# Build crawl request for Perplexity - ONE URL
|
||||
# Match playground prompt style: just URL + question
|
||||
# This allows Perplexity to return detailed multi-source results
|
||||
crawlPrompt = f"{webCrawlPrompt.url}: {webCrawlPrompt.instruction}"
|
||||
|
||||
# Build payload with optional Perplexity parameters
|
||||
# Note: max_tokens_per_page may not be supported by chat/completions endpoint
|
||||
# The playground Python SDK might use a different internal API
|
||||
maxResults = min(webCrawlPrompt.maxWidth or 10, 20) # Max 20 results
|
||||
|
||||
payload = {
|
||||
"model": model.name,
|
||||
"messages": [{"role": "user", "content": crawlPrompt}],
|
||||
"temperature": temperature,
|
||||
"max_tokens": maxTokens, # Use model's configured maxTokens (24000)
|
||||
"max_results": maxResults,
|
||||
"return_citations": True # Request citations explicitly
|
||||
}
|
||||
|
||||
logger.info(f"Perplexity crawl payload: model={model.name}, prompt_length={len(crawlPrompt)}, max_tokens={maxTokens}, max_results={maxResults}")
|
||||
|
||||
response = await self.httpClient.post(model.apiUrl, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(status_code=500, detail=f"Perplexity Web Crawl API error: {response.text}")
|
||||
|
||||
# Check if response body is empty or invalid
|
||||
responseText = response.text
|
||||
if not responseText or not responseText.strip():
|
||||
raise HTTPException(status_code=500, detail="Perplexity Web Crawl API returned empty response")
|
||||
|
||||
try:
|
||||
apiResponse = response.json()
|
||||
except Exception as jsonError:
|
||||
logger.error(f"Failed to parse Perplexity response as JSON. Status: {response.status_code}, Response: {responseText[:500]}")
|
||||
raise HTTPException(status_code=500, detail=f"Perplexity Web Crawl API returned invalid JSON: {str(jsonError)}")
|
||||
|
||||
if "choices" not in apiResponse or not apiResponse["choices"]:
|
||||
raise HTTPException(status_code=500, detail="Perplexity Web Crawl API response missing 'choices' field")
|
||||
|
||||
# Extract the main content
|
||||
content = apiResponse["choices"][0]["message"]["content"]
|
||||
|
||||
# Check for citations or search results in the response
|
||||
citations = apiResponse.get("citations", [])
|
||||
searchResults = apiResponse.get("search_results", [])
|
||||
|
||||
# Log what we found
|
||||
if citations:
|
||||
logger.info(f"Found {len(citations)} citations in response")
|
||||
if searchResults:
|
||||
logger.info(f"Found {len(searchResults)} search results in response")
|
||||
logger.debug(f"API response keys: {list(apiResponse.keys())}")
|
||||
|
||||
# Build comprehensive response with citations if available
|
||||
import json
|
||||
responseData = {
|
||||
"content": content,
|
||||
"citations": citations if citations else [],
|
||||
"search_results": searchResults if searchResults else []
|
||||
}
|
||||
|
||||
# Return comprehensive response
|
||||
return AiModelResponse(
|
||||
content=json.dumps(responseData, indent=2) if (citations or searchResults) else content,
|
||||
success=True,
|
||||
modelId=model.name,
|
||||
metadata={
|
||||
"response_id": apiResponse.get("id", ""),
|
||||
"operation": "WEB_CRAWL",
|
||||
"url": webCrawlPrompt.url,
|
||||
"actualPromptSent": crawlPrompt,
|
||||
"has_citations": len(citations) > 0,
|
||||
"has_search_results": len(searchResults) > 0
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Perplexity web crawl: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error in Perplexity web crawl: {str(e)}")
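
To make the prompt contract of the two web operations concrete, the sketch below shows how a caller might drive webSearch. It assumes AiCallPromptWebSearch accepts the instruction/maxNumberPages/country fields parsed above and that AiCallOptions exposes an assignable operationType field; it is illustrative only, not part of this commit.

import json
from modules.datamodels.datamodelAi import AiModelCall, AiCallOptions, OperationTypeEnum

async def runPerplexityWebSearch(connector: "AiPerplexity") -> list:
    # The dict mirrors the fields webSearch() parses into AiCallPromptWebSearch
    promptPayload = {
        "instruction": "Latest FastAPI release notes",  # example query
        "maxNumberPages": 5,
        "country": "CH"  # ISO-2 code; converted to a country name by the connector
    }
    model = connector.getModels()[0]  # "sonar"
    options = AiCallOptions()
    options.operationType = OperationTypeEnum.WEB_SEARCH  # assumed to be assignable like this
    call = AiModelCall(
        messages=[{"role": "user", "content": json.dumps(promptPayload)}],
        model=model,
        options=options,
    )
    response = await connector._routeWebOperation(call)
    return json.loads(response.content)  # ideally a JSON array of URLs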
610
modules/aicore/aicorePluginTavily.py
Normal file
@ -0,0 +1,610 @@
"""Tavily web search class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Dict
|
||||
from tavily import AsyncTavilyClient
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.aicore.aicoreBase import BaseConnectorAi
|
||||
from modules.datamodels.datamodelAi import AiModel, PriorityEnum, ProcessingModeEnum, OperationTypeEnum, AiModelCall, AiModelResponse, createOperationTypeRatings, AiCallPromptWebSearch, AiCallPromptWebCrawl
|
||||
from modules.datamodels.datamodelTools import CountryCodes
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class WebSearchResult:
|
||||
title: str
|
||||
url: str
|
||||
rawContent: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class WebCrawlResult:
|
||||
url: str
|
||||
content: str
|
||||
title: Optional[str] = None
|
||||
|
||||
|
||||
class AiTavily(BaseConnectorAi):
|
||||
"""Tavily web search connector."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.client: Optional[AsyncTavilyClient] = None
|
||||
# Cached settings loaded at initialization time
|
||||
self.crawlTimeout: int = 30
|
||||
self.crawlMaxRetries: int = 3
|
||||
self.crawlRetryDelay: int = 2
|
||||
# Cached web search constraints (camelCase per project style)
|
||||
self.webSearchMinResults: int = 1
|
||||
self.webSearchMaxResults: int = 20
|
||||
# Initialize client if API key is available
|
||||
self._initializeClient()
|
||||
|
||||
|
||||
def getModels(self) -> List[AiModel]:
|
||||
"""Get all available Tavily models."""
|
||||
return [
|
||||
AiModel(
|
||||
name="tavily-search",
|
||||
displayName="Tavily Search & Research",
|
||||
connectorType="tavily",
|
||||
apiUrl="https://api.tavily.com",
|
||||
temperature=0.0, # Web search doesn't use temperature
|
||||
maxTokens=0, # Web search doesn't use tokens
|
||||
contextLength=0,
|
||||
costPer1kTokensInput=0.0,
|
||||
costPer1kTokensOutput=0.0,
|
||||
speedRating=8, # Good speed for search and extract
|
||||
qualityRating=9, # Excellent quality for web research
|
||||
# capabilities removed (not used in business logic)
|
||||
functionCall=self._routeWebOperation,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.WEB_SEARCH, 9),
|
||||
(OperationTypeEnum.WEB_CRAWL, 10)
|
||||
),
|
||||
version="tavily-search",
|
||||
calculatePriceUsd=lambda processingTime, bytesSent, bytesReceived: 0.008 # Simple flat rate
|
||||
)
|
||||
]
|
||||
|
||||
def _initializeClient(self):
|
||||
"""Initialize the Tavily client if API key is available."""
|
||||
try:
|
||||
apiKey = APP_CONFIG.get("Connector_AiTavily_API_SECRET")
|
||||
if apiKey:
|
||||
self.client = AsyncTavilyClient(api_key=apiKey)
|
||||
logger.info("Tavily client initialized successfully")
|
||||
else:
|
||||
logger.warning("Tavily API key not found, client not initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize Tavily client: {str(e)}")
|
||||
|
||||
def getConnectorType(self) -> str:
|
||||
"""Get the connector type identifier."""
|
||||
return "tavily"
|
||||
|
||||
def _convertIsoCodeToCountryName(self, isoCode: str) -> str:
|
||||
"""
|
||||
Convert ISO-2 country code to Tavily country name.
|
||||
Uses centralized CountryCodes mapping.
|
||||
"""
|
||||
return CountryCodes.getForTavily(isoCode)
|
||||
|
||||
def _extractUrlsFromPrompt(self, prompt: str) -> List[str]:
|
||||
"""Extract URLs from a text prompt using regex."""
|
||||
if not prompt:
|
||||
return []
|
||||
|
||||
# URL regex pattern - matches http/https URLs
|
||||
urlPattern = r'https?://(?:[-\w.])+(?:[:\d]+)?(?:/(?:[\w/_.])*(?:\?(?:[\w&=%.])*)?(?:#(?:[\w.])*)?)?'
|
||||
urls = re.findall(urlPattern, prompt)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
uniqueUrls = []
|
||||
for url in urls:
|
||||
if url not in seen:
|
||||
seen.add(url)
|
||||
uniqueUrls.append(url)
|
||||
|
||||
return uniqueUrls
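
A quick illustration of the extraction and de-duplication above (example string invented):

# connector._extractUrlsFromPrompt("see https://example.com/a and https://example.com/a again")
# -> ["https://example.com/a"]   # second occurrence removed, order preserved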
|
||||
|
||||
def _normalizeUrl(self, url: str) -> str:
|
||||
"""
|
||||
Normalize URL for better deduplication.
|
||||
Removes common variations that represent the same content.
|
||||
"""
|
||||
if not url:
|
||||
return url
|
||||
|
||||
# Remove trailing slashes
|
||||
url = url.rstrip('/')
|
||||
|
||||
# Remove common query parameters that don't affect content
|
||||
import urllib.parse
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
|
||||
# Remove common tracking parameters
|
||||
queryParams = urllib.parse.parse_qs(parsed.query)
|
||||
filteredParams = {}
|
||||
|
||||
for key, values in queryParams.items():
|
||||
# Keep important parameters, remove tracking ones
|
||||
if key.lower() not in ['utm_source', 'utm_medium', 'utm_campaign', 'utm_term', 'utm_content',
|
||||
'fbclid', 'gclid', 'ref', 'source', 'campaign']:
|
||||
filteredParams[key] = values
|
||||
|
||||
# Rebuild query string
|
||||
filteredQuery = urllib.parse.urlencode(filteredParams, doseq=True)
|
||||
|
||||
# Reconstruct URL
|
||||
normalized = urllib.parse.urlunparse((
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
filteredQuery,
|
||||
parsed.fragment
|
||||
))
|
||||
|
||||
return normalized
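
Illustrative behaviour of the normalization above (URLs invented for the example):

# connector._normalizeUrl("https://example.com/page?utm_source=mail&id=5")
# -> "https://example.com/page?id=5"   # tracking parameter stripped, 'id' kept
# connector._normalizeUrl("https://example.com/docs/")
# -> "https://example.com/docs"        # trailing slash removed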
|
||||
|
||||
def _calculateRelevanceScore(self, result: WebSearchResult, queryWords: set) -> float:
|
||||
"""
|
||||
Calculate relevance score for a search result.
|
||||
Higher score means more relevant to the query.
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# Title relevance (most important)
|
||||
titleWords = set(result.title.lower().split())
|
||||
titleMatches = len(queryWords.intersection(titleWords))
|
||||
score += titleMatches * 3.0 # Weight title matches heavily
|
||||
|
||||
# URL relevance
|
||||
urlWords = set(result.url.lower().split('/'))
|
||||
urlMatches = len(queryWords.intersection(urlWords))
|
||||
score += urlMatches * 1.5
|
||||
|
||||
# Content relevance (if available)
|
||||
if hasattr(result, 'rawContent') and result.rawContent:
|
||||
contentWords = set(result.rawContent.lower().split())
|
||||
contentMatches = len(queryWords.intersection(contentWords))
|
||||
score += contentMatches * 0.1 # Lower weight for content matches
|
||||
|
||||
# Domain authority bonus (simple heuristic)
|
||||
domain = result.url.split('/')[2] if '/' in result.url else result.url
|
||||
if any(authDomain in domain.lower() for authDomain in
|
||||
['wikipedia.org', 'github.com', 'stackoverflow.com', 'reddit.com', 'medium.com']):
|
||||
score += 1.0
|
||||
|
||||
# Penalty for very long URLs (often less relevant)
|
||||
if len(result.url) > 100:
|
||||
score -= 0.5
|
||||
|
||||
return score
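
A worked example of the scoring above, with invented inputs:

# query = "python asyncio tutorial" -> queryWords = {"python", "asyncio", "tutorial"}
# result.title = "Python asyncio tutorial for beginners" -> 3 title matches * 3.0 = 9.0
# result.url   = "https://docs.python.org/3/library/asyncio.html"
#   -> splitting on "/" gives no exact word match ("asyncio.html" != "asyncio"): +0.0
#   -> not an authority domain from the list, URL shorter than 100 chars: no bonus/penalty
# => score = 9.0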
|
||||
|
||||
def _intelligentUrlFiltering(self, searchResults: List[WebSearchResult], query: str, maxResults: int) -> List[WebSearchResult]:
|
||||
"""
|
||||
Intelligent URL filtering with de-duplication and relevance scoring.
|
||||
|
||||
Args:
|
||||
searchResults: Raw search results from Tavily
|
||||
query: Original search query for relevance scoring
|
||||
maxResults: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
Filtered and deduplicated list of search results
|
||||
"""
|
||||
if not searchResults:
|
||||
return []
|
||||
|
||||
# Step 1: Basic de-duplication by URL
|
||||
seenUrls = set()
|
||||
uniqueResults = []
|
||||
|
||||
for result in searchResults:
|
||||
# Normalize URL for better deduplication
|
||||
normalizedUrl = self._normalizeUrl(result.url)
|
||||
if normalizedUrl not in seenUrls:
|
||||
seenUrls.add(normalizedUrl)
|
||||
uniqueResults.append(result)
|
||||
|
||||
logger.info(f"After basic deduplication: {len(uniqueResults)} unique URLs from {len(searchResults)} original")
|
||||
|
||||
# Step 2: Relevance scoring and filtering
|
||||
scoredResults = []
|
||||
queryWords = set(query.lower().split())
|
||||
|
||||
for result in uniqueResults:
|
||||
score = self._calculateRelevanceScore(result, queryWords)
|
||||
scoredResults.append((score, result))
|
||||
|
||||
# Step 3: Sort by relevance score (higher is better)
|
||||
scoredResults.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Step 4: Take top results
|
||||
filteredResults = [result for score, result in scoredResults[:maxResults]]
|
||||
|
||||
logger.info(f"After intelligent filtering: {len(filteredResults)} results selected from {len(uniqueResults)} unique")
|
||||
|
||||
return filteredResults
|
||||
|
||||
    @classmethod
    async def create(cls):
        apiKey = APP_CONFIG.get("Connector_AiTavily_API_SECRET")
        if not apiKey:
            raise ValueError("Tavily API key not configured. Please set Connector_AiTavily_API_SECRET in config.ini")
        # __init__ takes no arguments, so build the instance first and then
        # override the cached crawl/search settings with values from config.ini
        instance = cls()
        instance.client = AsyncTavilyClient(api_key=apiKey)
        instance.crawlTimeout = int(APP_CONFIG.get("Web_Crawl_TIMEOUT", "30"))
        instance.crawlMaxRetries = int(APP_CONFIG.get("Web_Crawl_MAX_RETRIES", "3"))
        instance.crawlRetryDelay = int(APP_CONFIG.get("Web_Crawl_RETRY_DELAY", "2"))
        instance.webSearchMinResults = int(APP_CONFIG.get("Web_Search_MIN_RESULTS", "1"))
        instance.webSearchMaxResults = int(APP_CONFIG.get("Web_Search_MAX_RESULTS", "20"))
        return instance
|
||||
|
||||
# Standardized method using AiModelCall/AiModelResponse pattern
|
||||
|
||||
|
||||
def _cleanUrl(self, url: str) -> str:
|
||||
"""Clean URL by removing extra text that might be appended."""
|
||||
import re
|
||||
# Extract just the URL part, removing any extra text after it
|
||||
urlMatch = re.match(r'(https?://[^\s,]+)', url)
|
||||
if urlMatch:
|
||||
return urlMatch.group(1)
|
||||
return url
|
||||
|
||||
async def _search(
|
||||
self,
|
||||
query: str,
|
||||
maxResults: int,
|
||||
searchDepth: str | None = None,
|
||||
timeRange: str | None = None,
|
||||
topic: str | None = None,
|
||||
includeDomains: list[str] | None = None,
|
||||
excludeDomains: list[str] | None = None,
|
||||
country: str | None = None,
|
||||
includeAnswer: str | None = None,
|
||||
includeRawContent: str | None = None,
|
||||
) -> list[WebSearchResult]:
|
||||
"""Calls the Tavily API to perform a web search."""
|
||||
# Make sure maxResults is within the allowed range (use cached values)
|
||||
minResults = self.webSearchMinResults
|
||||
maxAllowedResults = self.webSearchMaxResults
|
||||
if maxResults < minResults or maxResults > maxAllowedResults:
|
||||
raise ValueError(f"maxResults must be between {minResults} and {maxAllowedResults}")
|
||||
|
||||
# Perform actual API call
|
||||
# Build kwargs only for provided options to avoid API rejections
|
||||
kwargs: dict = {"query": query, "max_results": maxResults}
|
||||
if searchDepth is not None:
|
||||
kwargs["search_depth"] = searchDepth
|
||||
if timeRange is not None:
|
||||
kwargs["time_range"] = timeRange
|
||||
if topic is not None:
|
||||
kwargs["topic"] = topic
|
||||
if includeDomains is not None and len(includeDomains) > 0:
|
||||
kwargs["include_domains"] = includeDomains
|
||||
if excludeDomains is not None:
|
||||
kwargs["exclude_domains"] = excludeDomains
|
||||
if country is not None:
|
||||
kwargs["country"] = country
|
||||
if includeAnswer is not None:
|
||||
kwargs["include_answer"] = includeAnswer
|
||||
if includeRawContent is not None:
|
||||
kwargs["include_raw_content"] = includeRawContent
|
||||
|
||||
# Log the final API call parameters for comparison
|
||||
logger.info(f"Tavily API call parameters: {kwargs}")
|
||||
|
||||
# Ensure client is initialized
|
||||
if self.client is None:
|
||||
self._initializeClient()
|
||||
if self.client is None:
|
||||
raise ValueError("Tavily client not initialized. Please check API key configuration.")
|
||||
|
||||
response = await self.client.search(**kwargs)
|
||||
|
||||
# Return all results without score filtering
|
||||
# Tavily's scoring is already applied by the API
|
||||
logger.info(f"Tavily returned {len(response.get('results', []))} results")
|
||||
|
||||
return [
|
||||
WebSearchResult(
|
||||
title=result["title"],
|
||||
url=self._cleanUrl(result["url"]),
|
||||
rawContent=result.get("raw_content")
|
||||
)
|
||||
for result in response["results"]
|
||||
]
|
||||
|
||||
async def _crawl(
|
||||
self,
|
||||
url: str,
|
||||
instructions: str | None = None,
|
||||
limit: int = 20,
|
||||
maxDepth: int = 2,
|
||||
maxBreadth: int = 40,
|
||||
) -> list[WebCrawlResult]:
|
||||
"""Calls the Tavily API to crawl ONE URL with link following and retry logic."""
|
||||
maxRetries = self.crawlMaxRetries
|
||||
retryDelay = self.crawlRetryDelay
|
||||
timeout = self.crawlTimeout
|
||||
|
||||
logger.debug(f"Starting crawl of URL: {url}")
|
||||
logger.debug(f"Crawl settings: instructions={instructions}, limit={limit}, maxDepth={maxDepth}, maxBreadth={maxBreadth}, timeout={timeout}s")
|
||||
|
||||
for attempt in range(maxRetries + 1):
|
||||
try:
|
||||
logger.debug(f"Crawl attempt {attempt + 1}/{maxRetries + 1}")
|
||||
|
||||
# Ensure client is initialized
|
||||
if self.client is None:
|
||||
self._initializeClient()
|
||||
if self.client is None:
|
||||
raise ValueError("Tavily client not initialized. Please check API key configuration.")
|
||||
|
||||
logger.debug(f"Crawling URL: {url}")
|
||||
|
||||
# Build kwargs for crawl
|
||||
kwargsCrawl: dict = {"url": url}
|
||||
if instructions:
|
||||
kwargsCrawl["instructions"] = instructions
|
||||
if limit:
|
||||
kwargsCrawl["limit"] = limit
|
||||
if maxDepth:
|
||||
kwargsCrawl["max_depth"] = maxDepth
|
||||
if maxBreadth:
|
||||
kwargsCrawl["max_breadth"] = maxBreadth
|
||||
|
||||
logger.debug(f"Sending request to Tavily with kwargs: {kwargsCrawl}")
|
||||
|
||||
response = await asyncio.wait_for(
|
||||
self.client.crawl(**kwargsCrawl),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
logger.debug(f"Tavily response received: {type(response)}")
|
||||
|
||||
# Parse response - could be dict with results or list
|
||||
if isinstance(response, dict) and "results" in response:
|
||||
pageResults = response["results"]
|
||||
elif isinstance(response, list):
|
||||
pageResults = response
|
||||
else:
|
||||
logger.warning(f"Unexpected response format: {type(response)}")
|
||||
pageResults = []
|
||||
|
||||
logger.debug(f"Got {len(pageResults)} pages from crawl")
|
||||
|
||||
# Convert to WebCrawlResult format
|
||||
results = []
|
||||
for result in pageResults:
|
||||
results.append(WebCrawlResult(
|
||||
url=result.get("url", url),
|
||||
content=result.get("raw_content", result.get("content", "")),
|
||||
title=result.get("title", "")
|
||||
))
|
||||
|
||||
logger.debug(f"Crawl successful: extracted {len(results)} pages from URL")
|
||||
return results
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Crawl attempt {attempt + 1} timed out after {timeout} seconds for URL: {url}")
|
||||
if attempt < maxRetries:
|
||||
logger.info(f"Retrying in {retryDelay} seconds...")
|
||||
await asyncio.sleep(retryDelay)
|
||||
else:
|
||||
raise Exception(f"Crawl failed after {maxRetries + 1} attempts due to timeout")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Crawl attempt {attempt + 1} failed for URL {url}: {str(e)}")
|
||||
logger.debug(f"Full error details: {type(e).__name__}: {str(e)}")
|
||||
|
||||
# Check if it's a validation error and log more details
|
||||
if "validation" in str(e).lower():
|
||||
logger.debug(f"URL validation failed. Checking URL format:")
|
||||
logger.debug(f" URL: '{url}' (length: {len(url)})")
|
||||
# Check for common URL issues
|
||||
if ' ' in url:
|
||||
logger.debug(f" WARNING: URL contains spaces!")
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
logger.debug(f" WARNING: URL doesn't start with http/https!")
|
||||
if len(url) > 2000:
|
||||
logger.debug(f" WARNING: URL is very long ({len(url)} chars)")
|
||||
|
||||
if attempt < maxRetries:
|
||||
logger.info(f"Retrying in {retryDelay} seconds...")
|
||||
await asyncio.sleep(retryDelay)
|
||||
else:
|
||||
raise Exception(f"Crawl failed after {maxRetries + 1} attempts: {str(e)}")
|
||||
|
||||
async def _routeWebOperation(self, modelCall: AiModelCall) -> "AiModelResponse":
|
||||
"""
|
||||
Route web operation based on operation type.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with messages and options
|
||||
|
||||
Returns:
|
||||
AiModelResponse based on operation type
|
||||
"""
|
||||
operationType = modelCall.options.operationType
|
||||
|
||||
if operationType == OperationTypeEnum.WEB_SEARCH:
|
||||
return await self.webSearch(modelCall)
|
||||
elif operationType == OperationTypeEnum.WEB_CRAWL:
|
||||
return await self.webCrawl(modelCall)
|
||||
else:
|
||||
# Unsupported operation type
|
||||
return AiModelResponse(
|
||||
content="",
|
||||
success=False,
|
||||
error=f"Unsupported operation type: {operationType}"
|
||||
)
|
||||
|
||||
async def webSearch(self, modelCall: AiModelCall) -> "AiModelResponse":
|
||||
"""
|
||||
WEB_SEARCH operation - returns list of URLs using Tavily search.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with AiCallPromptWebSearch as prompt
|
||||
|
||||
Returns:
|
||||
AiModelResponse with JSON list of URLs
|
||||
"""
|
||||
try:
|
||||
# Extract parameters - find user message (not system message)
|
||||
promptContent = ""
|
||||
if modelCall.messages:
|
||||
for msg in modelCall.messages:
|
||||
if msg.get("role") == "user":
|
||||
promptContent = msg.get("content", "")
|
||||
break
|
||||
# Fallback to first message if no user message found
|
||||
if not promptContent and len(modelCall.messages) > 0:
|
||||
promptContent = modelCall.messages[0].get("content", "")
|
||||
|
||||
if not promptContent or not promptContent.strip():
|
||||
raise ValueError("Empty prompt content received for web search")
|
||||
|
||||
import json
|
||||
try:
|
||||
promptData = json.loads(promptContent)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse prompt content as JSON: {promptContent[:200]}")
|
||||
raise ValueError(f"Invalid JSON in prompt content: {str(e)}")
|
||||
|
||||
# Create Pydantic model
|
||||
webSearchPrompt = AiCallPromptWebSearch(**promptData)
|
||||
|
||||
# Convert ISO country code to country name for Tavily
|
||||
countryName = webSearchPrompt.country
|
||||
if countryName:
|
||||
countryName = self._convertIsoCodeToCountryName(countryName)
|
||||
|
||||
# Perform search - use exact parameters from prompt
|
||||
# NOTE: timeRange parameter causes generic results, so we don't use it
|
||||
searchResults = await self._search(
|
||||
query=webSearchPrompt.instruction,
|
||||
maxResults=webSearchPrompt.maxNumberPages,
|
||||
timeRange=None, # Not used - causes generic results
|
||||
country=countryName,
|
||||
includeAnswer="basic",
|
||||
includeRawContent="text"
|
||||
)
|
||||
|
||||
# Extract URLs from results
|
||||
urls = [result.url for result in searchResults]
|
||||
|
||||
# Return as JSON array
|
||||
import json
|
||||
return AiModelResponse(
|
||||
content=json.dumps(urls, indent=2),
|
||||
success=True,
|
||||
metadata={"total_urls": len(urls), "operation": "WEB_SEARCH"}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Tavily web search: {str(e)}")
|
||||
return AiModelResponse(
|
||||
content="[]",
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def webCrawl(self, modelCall: AiModelCall) -> "AiModelResponse":
|
||||
"""
|
||||
WEB_CRAWL operation - crawls one URL using Tavily with link following.
|
||||
|
||||
Args:
|
||||
modelCall: AiModelCall with AiCallPromptWebCrawl as prompt
|
||||
|
||||
Returns:
|
||||
AiModelResponse with crawl results as JSON (may include multiple pages)
|
||||
"""
|
||||
try:
|
||||
# Extract parameters - find user message (not system message)
|
||||
promptContent = ""
|
||||
if modelCall.messages:
|
||||
for msg in modelCall.messages:
|
||||
if msg.get("role") == "user":
|
||||
promptContent = msg.get("content", "")
|
||||
break
|
||||
# Fallback to first message if no user message found
|
||||
if not promptContent and len(modelCall.messages) > 0:
|
||||
promptContent = modelCall.messages[0].get("content", "")
|
||||
|
||||
if not promptContent or not promptContent.strip():
|
||||
raise ValueError("Empty prompt content received for web crawl")
|
||||
|
||||
import json
|
||||
try:
|
||||
promptData = json.loads(promptContent)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse prompt content as JSON: {promptContent[:200]}")
|
||||
raise ValueError(f"Invalid JSON in prompt content: {str(e)}")
|
||||
|
||||
# Create Pydantic model
|
||||
webCrawlPrompt = AiCallPromptWebCrawl(**promptData)
|
||||
|
||||
# Perform crawl for ONE URL with link following
|
||||
# Use maxWidth as limit, maxDepth as maxDepth, and calculate maxBreadth
|
||||
crawlResults = await self._crawl(
|
||||
url=webCrawlPrompt.url,
|
||||
instructions=webCrawlPrompt.instruction,
|
||||
limit=webCrawlPrompt.maxWidth or 20, # maxWidth controls number of pages
|
||||
maxDepth=webCrawlPrompt.maxDepth or 2,
|
||||
maxBreadth=webCrawlPrompt.maxWidth or 40 # Use same as limit for breadth
|
||||
)
|
||||
|
||||
# If we got multiple pages from the crawl, we need to format them differently
|
||||
# Return the first result for backwards compatibility, but include total page count
|
||||
if crawlResults and len(crawlResults) > 0:
|
||||
# Get all pages content
|
||||
allContent = ""
|
||||
for i, result in enumerate(crawlResults, 1):
|
||||
pageHeader = f"\n{'='*60}\nPAGE {i}: {result.url}\n{'='*60}\n"
|
||||
if result.title:
|
||||
allContent += f"{pageHeader}Title: {result.title}\n\n"
|
||||
allContent += f"{result.content}\n"
|
||||
|
||||
resultData = {
|
||||
"url": webCrawlPrompt.url,
|
||||
"title": crawlResults[0].title if crawlResults[0].title else "Content",
|
||||
"content": allContent,
|
||||
"pagesCrawled": len(crawlResults),
|
||||
"pageUrls": [result.url for result in crawlResults]
|
||||
}
|
||||
else:
|
||||
resultData = {"url": webCrawlPrompt.url, "title": "", "content": "", "error": "No content extracted", "pagesCrawled": 0}
|
||||
|
||||
# Return as JSON - same format as Perplexity but with multiple pages content
|
||||
import json
|
||||
return AiModelResponse(
|
||||
content=json.dumps(resultData, indent=2),
|
||||
success=True,
|
||||
metadata={"operation": "WEB_CRAWL", "url": webCrawlPrompt.url, "pagesCrawled": len(crawlResults) if crawlResults else 0}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Tavily web crawl: {str(e)}")
|
||||
import json
|
||||
errorResult = {"error": str(e), "url": webCrawlPrompt.url if 'webCrawlPrompt' in locals() else ""}
|
||||
return AiModelResponse(
|
||||
content=json.dumps(errorResult, indent=2),
|
||||
success=False,
|
||||
error=str(e)
|
||||
)
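
For reference, the JSON that webCrawl serializes into AiModelResponse.content has the following shape; the values are illustrative:

exampleCrawlPayload = {
    "url": "https://example.com",        # the requested URL
    "title": "Example Domain",           # title of the first crawled page
    "content": "...",                    # all pages concatenated with "PAGE i" headers
    "pagesCrawled": 3,
    "pageUrls": ["https://example.com", "https://example.com/a", "https://example.com/b"]
}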
File diff suppressed because it is too large
@ -1,157 +0,0 @@
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import datetime, UTC
|
||||
import re
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from .documentUtility import (
|
||||
getFileExtension,
|
||||
getMimeTypeFromExtension,
|
||||
detectMimeTypeFromContent,
|
||||
detectMimeTypeFromData,
|
||||
convertDocumentDataToString
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DocumentGenerator:
|
||||
def __init__(self, service):
|
||||
self.service = service
|
||||
|
||||
def processActionResultDocuments(self, action_result, action, workflow) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Process documents produced by AI actions and convert them to ChatDocument format.
|
||||
This function handles AI-generated document data, not document references.
|
||||
Returns a list of processed document dictionaries.
|
||||
"""
|
||||
try:
|
||||
# Read documents from the standard documents field (not data.documents)
|
||||
documents = action_result.documents if action_result and hasattr(action_result, 'documents') else []
|
||||
|
||||
if not documents:
|
||||
logger.info(f"No documents found in action_result.documents for {action.execMethod}.{action.execAction}")
|
||||
return []
|
||||
|
||||
logger.info(f"Processing {len(documents)} documents from action_result.documents")
|
||||
|
||||
# Process each document from the AI action result
|
||||
processed_documents = []
|
||||
for doc in documents:
|
||||
processed_doc = self.processSingleDocument(doc, action)
|
||||
if processed_doc:
|
||||
processed_documents.append(processed_doc)
|
||||
|
||||
logger.info(f"Successfully processed {len(processed_documents)} documents")
|
||||
return processed_documents
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing action result documents: {str(e)}")
|
||||
return []
|
||||
|
||||
def processSingleDocument(self, doc: Any, action) -> Optional[Dict[str, Any]]:
|
||||
"""Process a single document from action result with simplified logic"""
|
||||
try:
|
||||
# ActionDocument objects have documentName, documentData, and mimeType
|
||||
mime_type = doc.mimeType
|
||||
if mime_type == "application/octet-stream":
|
||||
content = doc.documentData
|
||||
mime_type = detectMimeTypeFromContent(content, doc.documentName, self.service)
|
||||
|
||||
return {
|
||||
'fileName': doc.documentName,
|
||||
'fileSize': len(str(doc.documentData)),
|
||||
'mimeType': mime_type,
|
||||
'content': doc.documentData,
|
||||
'document': doc
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing single document: {str(e)}")
|
||||
return None
|
||||
|
||||
def createDocumentsFromActionResult(self, action_result, action, workflow, message_id=None) -> List[Any]:
|
||||
"""
|
||||
Create actual document objects from action result and store them in the system.
|
||||
Returns a list of created document objects with proper workflow context.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Creating documents from action result for {action.execMethod}.{action.execAction}")
|
||||
logger.info(f"Action result documents count: {len(action_result.documents) if action_result.documents else 0}")
|
||||
|
||||
processed_docs = self.processActionResultDocuments(action_result, action, workflow)
|
||||
logger.info(f"Processed {len(processed_docs)} documents")
|
||||
|
||||
created_documents = []
|
||||
for i, doc_data in enumerate(processed_docs):
|
||||
try:
|
||||
document_name = doc_data['fileName']
|
||||
document_data = doc_data['content']
|
||||
mime_type = doc_data['mimeType']
|
||||
|
||||
logger.info(f"Creating document {i+1}: {document_name} (mime: {mime_type}, content length: {len(str(document_data))})")
|
||||
|
||||
# Convert document data to string content
|
||||
content = convertDocumentDataToString(document_data, getFileExtension(document_name))
|
||||
|
||||
# Skip empty or minimal content
|
||||
minimal_content_patterns = ['{}', '[]', 'null', '""', "''"]
|
||||
if not content or content.strip() == "" or content.strip() in minimal_content_patterns:
|
||||
logger.warning(f"Empty or minimal content for document {document_name}, skipping")
|
||||
continue
|
||||
|
||||
logger.info(f"Document {document_name} has content: {len(content)} characters")
|
||||
|
||||
# Create document with file in one step
|
||||
document = self.service.createDocument(
|
||||
fileName=document_name,
|
||||
mimeType=mime_type,
|
||||
content=content,
|
||||
base64encoded=False,
|
||||
messageId=message_id
|
||||
)
|
||||
if document:
|
||||
# Set workflow context on the document if possible
|
||||
self._setDocumentWorkflowContext(document, action, workflow)
|
||||
created_documents.append(document)
|
||||
logger.info(f"Successfully created ChatDocument: {document_name} (ID: {document.id if hasattr(document, 'id') else 'N/A'}, fileId: {document.fileId if hasattr(document, 'fileId') else 'N/A'})")
|
||||
else:
|
||||
logger.error(f"Failed to create ChatDocument object for {document_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating document {doc_data.get('fileName', 'unknown')}: {str(e)}")
|
||||
continue
|
||||
|
||||
logger.info(f"Successfully created {len(created_documents)} documents")
|
||||
return created_documents
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating documents from action result: {str(e)}")
|
||||
return []
|
||||
|
||||
def _setDocumentWorkflowContext(self, document, action, workflow):
|
||||
"""Set workflow context on a document for proper routing and labeling"""
|
||||
try:
|
||||
# Get current workflow context from service center
|
||||
workflow_context = self.service.getWorkflowContext()
|
||||
workflow_stats = self.service.getWorkflowStats()
|
||||
|
||||
current_round = workflow_context.get('currentRound', 0)
|
||||
current_task = workflow_context.get('currentTask', 0)
|
||||
current_action = workflow_context.get('currentAction', 0)
|
||||
|
||||
# Try to set workflow context attributes if they exist
|
||||
if hasattr(document, 'roundNumber'):
|
||||
document.roundNumber = current_round
|
||||
if hasattr(document, 'taskNumber'):
|
||||
document.taskNumber = current_task
|
||||
if hasattr(document, 'actionNumber'):
|
||||
document.actionNumber = current_action
|
||||
if hasattr(document, 'actionId'):
|
||||
document.actionId = action.id if hasattr(action, 'id') else None
|
||||
|
||||
# Set additional workflow metadata if available
|
||||
if hasattr(document, 'workflowId'):
|
||||
document.workflowId = workflow_stats.get('workflowId', workflow.id if hasattr(workflow, 'id') else None)
|
||||
if hasattr(document, 'workflowStatus'):
|
||||
document.workflowStatus = workflow_stats.get('workflowStatus', workflow.status if hasattr(workflow, 'status') else 'unknown')
|
||||
|
||||
logger.debug(f"Set workflow context on document: Round {current_round}, Task {current_task}, Action {current_action}")
|
||||
logger.debug(f"Document workflow metadata: ID={document.workflowId if hasattr(document, 'workflowId') else 'N/A'}, Status={document.workflowStatus if hasattr(document, 'workflowStatus') else 'N/A'}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not set workflow context on document: {str(e)}")
@ -1,55 +0,0 @@
# executionState.py
# Contains all execution state management logic extracted from managerChat.py

import logging
from typing import List
from datetime import datetime, UTC
from modules.interfaces.interfaceChatModel import TaskStep, ActionResult

logger = logging.getLogger(__name__)


class TaskExecutionState:
    """Manages execution state for a task with retry logic"""

    def __init__(self, task_step: TaskStep):
        self.task_step = task_step
        self.successful_actions: List[ActionResult] = []  # Preserved across retries
        self.failed_actions: List[ActionResult] = []  # For analysis
        self.current_action_index = 0
        self.retry_count = 0
        self.max_retries = 3

    def addSuccessfulAction(self, action_result: ActionResult):
        """Add a successful action to the state"""
        self.successful_actions.append(action_result)
        self.current_action_index += 1

    def addFailedAction(self, action_result: ActionResult):
        """Add a failed action to the state for analysis"""
        self.failed_actions.append(action_result)
        self.current_action_index += 1

    def canRetry(self) -> bool:
        """Check if task can be retried"""
        return self.retry_count < self.max_retries

    def incrementRetryCount(self):
        """Increment retry count"""
        self.retry_count += 1

    def getFailurePatterns(self) -> list:
        """Analyze failure patterns from failed actions"""
        patterns = []
        for action in self.failed_actions:
            error = action.error.lower() if action.error else ''
            if "timeout" in error:
                patterns.append("timeout_issues")
            elif "document_not_found" in error or "file not found" in error:
                patterns.append("document_reference_issues")
            elif "empty_result" in error or "no content" in error:
                patterns.append("content_extraction_issues")
            elif "invalid_format" in error or "wrong format" in error:
                patterns.append("format_issues")
            elif "permission" in error or "access denied" in error:
                patterns.append("permission_issues")
        return list(set(patterns))
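
For orientation, a minimal usage sketch of this (now deleted) class follows; the TaskStep and ActionResult arguments are placeholders, since their constructors are not shown in this diff.

def sketchRetryFlow(task_step: TaskStep, failed_action: ActionResult) -> list:
    # Illustrative only: assumes failed_action.error contains e.g. "Timeout while calling connector"
    state = TaskExecutionState(task_step)
    state.addFailedAction(failed_action)
    if state.canRetry():               # True until retry_count reaches max_retries (3)
        state.incrementRetryCount()
    return state.getFailurePatterns()  # -> ["timeout_issues"] for the error above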
File diff suppressed because it is too large
@ -1,770 +0,0 @@
# promptFactory.py
|
||||
# Contains all prompt creation functions extracted from managerChat.py
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
from modules.interfaces.interfaceChatModel import TaskContext, ReviewContext
|
||||
|
||||
# Set up logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Prompt creation helpers extracted from managerChat.py
|
||||
|
||||
def _getPreviousRoundContext(service, workflow) -> str:
|
||||
"""Get context from previous workflow rounds to help understand follow-up prompts"""
|
||||
try:
|
||||
if not workflow or not hasattr(workflow, 'messages') or not workflow.messages:
|
||||
return ""
|
||||
|
||||
# Get current round number
|
||||
current_round = getattr(workflow, 'currentRound', 0)
|
||||
|
||||
# If this is round 0 or 1, there's no previous context
|
||||
if current_round <= 1:
|
||||
return ""
|
||||
|
||||
# Find messages from previous rounds (rounds before current)
|
||||
previous_messages = []
|
||||
for message in workflow.messages:
|
||||
message_round = getattr(message, 'roundNumber', 0)
|
||||
if message_round > 0 and message_round < current_round:
|
||||
previous_messages.append(message)
|
||||
|
||||
if not previous_messages:
|
||||
return ""
|
||||
|
||||
# Sort by round number and sequence to get chronological order
|
||||
previous_messages.sort(key=lambda msg: (getattr(msg, 'roundNumber', 0), getattr(msg, 'sequenceNr', 0)))
|
||||
|
||||
# Build context summary
|
||||
context_parts = []
|
||||
current_round_context = {}
|
||||
|
||||
for message in previous_messages:
|
||||
round_num = getattr(message, 'roundNumber', 0)
|
||||
if round_num not in current_round_context:
|
||||
current_round_context[round_num] = {
|
||||
'user_inputs': [],
|
||||
'assistant_responses': [],
|
||||
'task_outcomes': [],
|
||||
'documents_processed': []
|
||||
}
|
||||
|
||||
# Categorize messages
|
||||
if message.role == 'user':
|
||||
current_round_context[round_num]['user_inputs'].append(message.message)
|
||||
elif message.role == 'assistant':
|
||||
# Check if it's a task completion or error message
|
||||
if 'task' in message.message.lower() and ('completed' in message.message.lower() or 'failed' in message.message.lower() or 'error' in message.message.lower()):
|
||||
current_round_context[round_num]['task_outcomes'].append(message.message)
|
||||
else:
|
||||
current_round_context[round_num]['assistant_responses'].append(message.message)
|
||||
|
||||
# Check for document processing
|
||||
if hasattr(message, 'documents') and message.documents:
|
||||
doc_names = [doc.fileName for doc in message.documents if hasattr(doc, 'fileName')]
|
||||
if doc_names:
|
||||
current_round_context[round_num]['documents_processed'].extend(doc_names)
|
||||
|
||||
# Build context summary
|
||||
for round_num in sorted(current_round_context.keys()):
|
||||
round_data = current_round_context[round_num]
|
||||
context_parts.append(f"ROUND {round_num} CONTEXT:")
|
||||
|
||||
if round_data['user_inputs']:
|
||||
context_parts.append(f" User requests: {'; '.join(round_data['user_inputs'])}")
|
||||
|
||||
if round_data['task_outcomes']:
|
||||
context_parts.append(f" Task outcomes: {'; '.join(round_data['task_outcomes'])}")
|
||||
|
||||
if round_data['documents_processed']:
|
||||
context_parts.append(f" Documents processed: {', '.join(set(round_data['documents_processed']))}")
|
||||
|
||||
if context_parts:
|
||||
return "\n".join(context_parts)
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting previous round context: {str(e)}")
|
||||
return ""
|
||||
|
||||
def createTaskPlanningPrompt(context: TaskContext, service) -> str:
|
||||
"""Create enhanced prompt for task planning with user-friendly message generation and language detection"""
|
||||
# Get user language directly from service.user.language
|
||||
user_language = service.user.language if service and service.user else 'en'
|
||||
|
||||
# Extract user request from context - use Pydantic model directly
|
||||
user_request = context.task_step.objective if context.task_step else 'No request specified'
|
||||
|
||||
# Extract available documents from context - use Pydantic model directly
|
||||
available_documents = context.available_documents or "No documents available"
|
||||
|
||||
# Get previous workflow round context for better understanding of follow-up prompts
|
||||
previous_round_context = _getPreviousRoundContext(service, context.workflow)
|
||||
|
||||
return f"""You are a task planning AI that analyzes user requests and creates structured task plans with user-friendly feedback messages.
|
||||
|
||||
USER REQUEST: {user_request}
|
||||
|
||||
AVAILABLE DOCUMENTS: {available_documents}
|
||||
|
||||
PREVIOUS WORKFLOW ROUNDS CONTEXT:
|
||||
{previous_round_context if previous_round_context else "No previous workflow rounds - this is the first round."}
|
||||
|
||||
INSTRUCTIONS:
|
||||
1. Analyze the user request, available documents, and previous workflow rounds context
|
||||
2. If the user request appears to be a follow-up (like "try again", "versuche es nochmals", "retry", etc.),
|
||||
use the PREVIOUS WORKFLOW ROUNDS CONTEXT to understand what the user wants to retry or continue
|
||||
3. Group related topics and sequential steps into single, comprehensive tasks
|
||||
4. Focus on business outcomes, not technical operations
|
||||
5. Each task should produce meaningful, usable outputs
|
||||
6. Ensure proper handover between tasks using result labels
|
||||
7. Detect the language of the user request and include it in languageUserDetected
|
||||
8. Generate user-friendly messages for each task in the user's request language
|
||||
9. Return a JSON object with the exact structure shown below
|
||||
|
||||
TASK GROUPING PRINCIPLES:
|
||||
- COMBINE RELATED TOPICS: Group related subjects, sequential steps, or workflow-structured activities into single tasks
|
||||
- SEQUENTIAL WORKFLOWS: If the user says "first do this, then that, then that" → create ONE task that handles the entire sequence
|
||||
- SIMILAR CONTENT: If multiple items deal with the same subject matter → combine into ONE comprehensive task
|
||||
- ONLY SPLIT WHEN DIFFERENT: Create separate tasks ONLY when the user explicitly wants different, independent things
|
||||
|
||||
EXAMPLES OF GOOD TASK GROUPING:
|
||||
|
||||
COMBINE INTO ONE TASK:
|
||||
- "Analyze the documents, extract key insights, and create a summary report" → ONE task: "Analyze documents and create comprehensive summary report"
|
||||
- "First check my emails, then respond to urgent ones, then organize my inbox" → ONE task: "Process and organize email inbox with priority responses"
|
||||
- "Review the budget, analyze spending patterns, and suggest cost-cutting measures" → ONE task: "Comprehensive budget analysis with optimization recommendations"
|
||||
- "Create a business strategy, develop marketing plan, and prepare presentation" → ONE task: "Develop complete business strategy with marketing plan and presentation"
|
||||
|
||||
SPLIT INTO MULTIPLE TASKS:
|
||||
- "Create a business strategy for Q4" AND "Check my emails for messages from my assistant" → TWO separate tasks (different subjects)
|
||||
- "Analyze customer feedback" AND "Prepare quarterly financial report" → TWO separate tasks (different business areas)
|
||||
- "Review project timeline" AND "Update employee handbook" → TWO separate tasks (unrelated activities)
|
||||
|
||||
TASK PLANNING PRINCIPLES:
|
||||
- Break down complex requests into logical, sequential steps
|
||||
- Focus on business value and outcomes
|
||||
- Keep tasks at a meaningful level of abstraction
|
||||
- Each task should produce results that can be used by subsequent tasks
|
||||
- Ensure clear dependencies and handovers between tasks
|
||||
- Provide clear, actionable user messages in the user's request language
|
||||
- Group related activities to minimize task fragmentation
|
||||
- Only create multiple tasks when dealing with truly different, independent objectives
|
||||
|
||||
FOLLOW-UP PROMPT HANDLING:
|
||||
- If the user request is a follow-up (e.g., "try again", "versuche es nochmals", "retry", "continue", "proceed"),
|
||||
analyze the PREVIOUS WORKFLOW ROUNDS CONTEXT to understand what failed or was incomplete
|
||||
- Use the previous round's user requests and task outcomes to determine what the user wants to retry
|
||||
- If previous rounds failed due to missing documents, and documents are now available,
|
||||
create tasks that use the newly available documents to accomplish the original request
|
||||
- Maintain the same business objective from previous rounds but adapt to current available resources
|
||||
|
||||
SPECIFIC SCENARIO HANDLING:
|
||||
- If previous round failed with "documents missing" error and current round has documents available,
|
||||
the user likely wants to retry the same operation with the newly provided documents
|
||||
- Example: Previous round "speichere mir die 3 dokumente im sharepoint unter xxx" failed due to missing documents,
|
||||
current round "versuche es nochmals" with documents should retry the SharePoint save operation
|
||||
- Always check if the current request is a retry by looking for retry keywords and previous round context
|
||||
|
||||
REQUIRED JSON STRUCTURE:
|
||||
{{
|
||||
"overview": "Brief description of the overall plan",
|
||||
"languageUserDetected": "en", // Language code detected from user request (en, de, fr, it, es, etc.)
|
||||
"userMessage": "User-friendly message explaining the task plan in user's request language",
|
||||
"tasks": [
|
||||
{{
|
||||
"id": "task_1",
|
||||
"objective": "Clear business objective this task accomplishes (combining related activities)",
|
||||
"dependencies": ["task_0"], // IDs of tasks that must complete first
|
||||
"success_criteria": ["criteria1", "criteria2"],
|
||||
"estimated_complexity": "low|medium|high",
|
||||
"userMessage": "User-friendly message explaining what this task will accomplish in user's request language"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
EXAMPLES OF GOOD TASK OBJECTIVES (COMBINING RELATED ACTIVITIES):
|
||||
- "Analyze documents and extract key insights for business communication"
|
||||
- "Create professional business communication incorporating analyzed information"
|
||||
- "Execute business communication using specified channels and document outcomes"
|
||||
- "Develop comprehensive business strategy with implementation roadmap and success metrics"
|
||||
|
||||
EXAMPLES OF GOOD SUCCESS CRITERIA:
|
||||
- "Key insights extracted and ready for business use"
|
||||
- "Professional communication created with clear business value"
|
||||
- "Business communication successfully delivered and documented"
|
||||
- "All outcomes properly documented and accessible"
|
||||
|
||||
EXAMPLES OF BAD TASK OBJECTIVES:
|
||||
- "Read the PDF file" (too granular - should be "Analyze document content")
|
||||
- "Convert data to CSV" (implementation detail - should be "Structure data for analysis")
|
||||
- "Send email" (too specific - should be "Deliver business communication")
|
||||
|
||||
LANGUAGE DETECTION:
|
||||
- Analyze the user request text to identify the language
|
||||
- Use standard language codes: en (English), de (German), fr (French), it (Italian), es (Spanish), etc.
|
||||
- If the language cannot be determined, use "en" as default
|
||||
- Include the detected language in the languageUserDetected field
|
||||
|
||||
NOTE: Respond with ONLY the JSON object. Do not include any explanatory text."""
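# A minimal sketch of how the task-plan JSON requested above could be validated when it
# comes back from the model. The classes below are illustrative only (they are not part
# of this module); the field names mirror the REQUIRED JSON STRUCTURE in the prompt.
from typing import List
from pydantic import BaseModel

class PlannedTask(BaseModel):
    id: str
    objective: str
    dependencies: List[str] = []
    success_criteria: List[str] = []
    estimated_complexity: str = "medium"
    userMessage: str = ""

class PlannedTaskPlan(BaseModel):
    overview: str
    languageUserDetected: str = "en"
    userMessage: str = ""
    tasks: List[PlannedTask]

# Example: PlannedTaskPlan.model_validate_json(llm_reply) raises a ValidationError if the
# reply deviates from the structure the prompt demands.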
|
||||
|
||||
async def createActionDefinitionPrompt(context: TaskContext, service) -> str:
|
||||
"""Create enhanced prompt for action generation with user-friendly messages and enhanced document context"""
|
||||
methodList = service.getMethodsList()
|
||||
method_actions = {}
|
||||
for sig in methodList:
|
||||
if '.' in sig:
|
||||
method, rest = sig.split('.', 1)
|
||||
action = rest.split('(')[0]
|
||||
method_actions.setdefault(method, []).append((action, sig))
|
||||
|
||||
messageSummary = await service.summarizeChat(context.workflow.messages) if context.workflow else ""
|
||||
|
||||
# Get enhanced document context using the new method
|
||||
available_documents_str = service.getEnhancedDocumentContext()
|
||||
|
||||
connRefs = service.getConnectionReferenceList()
|
||||
|
||||
# Create a structured JSON format for better AI parsing
|
||||
# This replaces the old hard-to-read format with a clean JSON structure
|
||||
# that the AI can easily parse and understand
|
||||
available_methods_json = {}
|
||||
for method, actions in method_actions.items():
|
||||
available_methods_json[method] = {}
|
||||
# Get the method instance for accessing docstrings
|
||||
method_instance = service.methods.get(method, {}).get('instance') if hasattr(service, 'methods') else None
|
||||
|
||||
for action, sig in actions:
|
||||
# Parse the signature to extract parameters
|
||||
if '(' in sig and ')' in sig:
|
||||
# Extract parameters from signature
|
||||
params_start = sig.find('(')
|
||||
params_end = sig.find(')')
|
||||
params_str = sig[params_start+1:params_end]
|
||||
|
||||
# Parse parameters directly from the docstring - much simpler and more reliable!
|
||||
parameters = []
|
||||
|
||||
# Get the actual function's docstring
|
||||
if method_instance and hasattr(method_instance, action):
|
||||
func = getattr(method_instance, action)
|
||||
if hasattr(func, '__doc__') and func.__doc__:
|
||||
docstring = func.__doc__
|
||||
|
||||
# Parse Parameters section from docstring
|
||||
lines = docstring.split('\n')
|
||||
in_parameters = False
|
||||
for i, line in enumerate(lines):
|
||||
original_line = line
|
||||
line = line.strip()
|
||||
|
||||
if line == 'Parameters:':
|
||||
in_parameters = True
|
||||
continue
|
||||
elif in_parameters and line and not original_line.startswith(' ') and not original_line.startswith('\t'):
|
||||
# End of parameters section
|
||||
break
|
||||
elif in_parameters and (original_line.startswith(' ') or original_line.startswith('\t')):
|
||||
# This is a parameter line - already stripped
|
||||
# Format: "paramName (type): description"
|
||||
if ':' in line:
|
||||
# Find the colon that separates param from description
|
||||
colon_pos = line.find(':')
|
||||
param_part = line[:colon_pos].strip()
|
||||
description = line[colon_pos+1:].strip()
|
||||
|
||||
# Parse parameter name and type
|
||||
if '(' in param_part and ')' in param_part:
|
||||
param_name = param_part.split('(')[0].strip()
|
||||
type_part = param_part[param_part.find('(')+1:param_part.find(')')].strip()
|
||||
|
||||
# Check if optional
|
||||
is_optional = 'optional' in type_part
|
||||
param_type = type_part.replace('optional', '').strip().rstrip(',').strip()
|
||||
|
||||
parameters.append({
|
||||
"name": param_name,
|
||||
"type": param_type,
|
||||
"description": description,
|
||||
"required": not is_optional
|
||||
})
|
||||
|
||||
available_methods_json[method][action] = {
|
||||
"signature": sig,
|
||||
"parameters": parameters,
|
||||
"description": f"{method}.{action} action"
|
||||
}
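# Illustration of the docstring shape the parser above expects (the method below is a
# made-up example, not an actual action in this codebase):
#
#     def extract(self, documentList, aiPrompt):
#         """Extract structured content from the given documents.
#
#         Parameters:
#             documentList (list): document references to process
#             aiPrompt (str, optional): extraction instructions for the AI
#         """
#
# Parsing that docstring would yield:
#     [{"name": "documentList", "type": "list", "description": "document references to process", "required": True},
#      {"name": "aiPrompt", "type": "str", "description": "extraction instructions for the AI", "required": False}]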
|
||||
|
||||
# Convert to a compact, AI-friendly format
|
||||
available_methods_str = f"""
|
||||
AVAILABLE ACTIONS (JSON format for better AI parsing):
|
||||
{json.dumps(available_methods_json, indent=1, separators=(',', ':'))}
|
||||
"""
|
||||
retry_context = ""
|
||||
if context.retry_count and context.retry_count > 0:
|
||||
retry_context = f"""
|
||||
RETRY CONTEXT (Attempt {context.retry_count}):
|
||||
Previous action results that failed or were incomplete:
|
||||
"""
|
||||
for i, result in enumerate(context.previous_action_results or []):
|
||||
retry_context += f"- Action {i+1}: ActionResult\n"
|
||||
retry_context += f"  Status: {'success' if result.success else 'failed'}\n"
|
||||
retry_context += f" Error: {result.error or 'None'}\n"
|
||||
# Check if result has documents and show document info
|
||||
if result.documents:
|
||||
doc_info = f"Documents: {len(result.documents)} document(s)"
|
||||
if result.documents[0].documentName:
|
||||
doc_info += f" - {result.documents[0].documentName}"
|
||||
retry_context += f" {doc_info}\n"
|
||||
else:
|
||||
retry_context += f" Documents: None\n"
|
||||
|
||||
if context.previous_review_result:
|
||||
retry_context += f"""
|
||||
Previous review feedback:
|
||||
- Status: {context.previous_review_result.status or 'unknown'}
|
||||
- Reason: {context.previous_review_result.reason or 'No reason provided'}
|
||||
- Quality Score: {context.previous_review_result.quality_score or 0}/10
|
||||
- Unmet Criteria: {', '.join(context.previous_review_result.unmet_criteria or [])}
|
||||
"""
|
||||
|
||||
# Use Pydantic model directly - no need for getattr
|
||||
success_criteria_str = ', '.join(context.task_step.success_criteria) if context.task_step and context.task_step.success_criteria else 'No criteria specified'
|
||||
previous_results_str = ', '.join(context.previous_results) if context.previous_results else 'None'
|
||||
improvements_str = str(context.improvements) if context.improvements else 'None'
|
||||
available_connections_str = '\n'.join(f"- {conn}" for conn in connRefs)
|
||||
|
||||
# Get user language from service - this is the correct way
|
||||
user_language = service.user.language if service and service.user else 'en'
|
||||
|
||||
# Get current workflow context for dynamic examples
|
||||
workflow_context = service.getWorkflowContext()
|
||||
current_round = workflow_context.get('currentRound', 0)
|
||||
current_task = workflow_context.get('currentTask', 1)
|
||||
|
||||
prompt = f"""
|
||||
You are an action generation AI that creates specific actions to accomplish a task step with user-friendly messages.
|
||||
|
||||
DOCUMENT REFERENCE TYPES:
|
||||
- docItem: Reference to a single document
|
||||
- docList: Reference to a group of documents
|
||||
- round{{round_number}}_task{{task_number}}_action{{action_number}}_{{context}}: Reference to resulting document list from previous action
|
||||
|
||||
USAGE GUIDE:
|
||||
- Use docItem when you need a specific document: "docItem:doc_123:component_diagram.pdf"
|
||||
- Use docList when you need all documents in a group: "docList:msg_456:AnalysisResults"
|
||||
- Use round/task/action format when referencing outputs from previous actions: "round{current_round}_task{current_task}_action2_AnalysisResults"
|
||||
|
||||
CRITICAL DOCUMENT REFERENCE RULES:
|
||||
- ONLY use the exact labels listed in AVAILABLE DOCUMENTS below, or result labels from previous actions
|
||||
- When generating multiple actions, you may only use as input documents those that are already present in AVAILABLE DOCUMENTS or produced by actions that come earlier in the list. Do NOT use as input any document label that will be produced by a later action.
|
||||
- If AVAILABLE DOCUMENTS shows "NO DOCUMENTS AVAILABLE", you CANNOT create document extraction actions. Instead, create actions that generate new content or that inform the user which documents are still needed.
|
||||
|
||||
CURRENT WORKFLOW CONTEXT:
|
||||
- Current Round: {current_round}
|
||||
- Current Task: {current_task}
|
||||
- Use these values when creating resultLabel references
|
||||
|
||||
TASK STEP: {context.task_step.objective if context.task_step else 'No task step specified'} (ID: {context.task_step.id if context.task_step else 'unknown'})
|
||||
|
||||
SUCCESS CRITERIA: {success_criteria_str}
|
||||
|
||||
CONTEXT - Chat History:
|
||||
{messageSummary}
|
||||
|
||||
WORKFLOW CONTEXT - Previous Messages Summary:
|
||||
The following summarizes key information from previous workflow interactions to provide context for continued workflows:
|
||||
- Previous user inputs and their outcomes
|
||||
- Key decisions and findings from earlier tasks
|
||||
- Document processing results and insights
|
||||
- User preferences and requirements established
|
||||
|
||||
This context helps ensure your actions build upon previous work and maintain consistency with the overall workflow objectives.
|
||||
|
||||
AVAILABLE METHODS AND ACTIONS (with signatures):
|
||||
{available_methods_str}
|
||||
|
||||
AVAILABLE CONNECTIONS:
|
||||
{available_connections_str}
|
||||
|
||||
AVAILABLE DOCUMENTS:
|
||||
{available_documents_str}
|
||||
|
||||
DOCUMENT REFERENCE EXAMPLES:
|
||||
✅ CORRECT: Use exact references from AVAILABLE DOCUMENTS above or result labels from previous actions
|
||||
- "docList:msg_456:diagram_analysis_results" (access all documents in a list)
|
||||
- "docItem:doc_123:component_diagram.pdf" (access specific document)
|
||||
- "round{current_round}_task{current_task}_action3_contextinfo" (document list from previous action)
|
||||
|
||||
❌ INCORRECT: These will cause errors
|
||||
- "msg_xxx:documents" (invalid format - missing docList/docItem prefix)
|
||||
- "task_2_results" (not a valid reference - use exact references from AVAILABLE DOCUMENTS)
|
||||
- Inventing document IDs that were not produced by a preceding action
|
||||
|
||||
PREVIOUS RESULTS: {previous_results_str}
|
||||
IMPROVEMENTS NEEDED: {improvements_str}
|
||||
|
||||
PREVIOUS TASK HANDOVER CONTEXT:
|
||||
{context.previous_handover.workflowSummary if context.previous_handover and context.previous_handover.workflowSummary else 'No previous task handover available'}
|
||||
|
||||
{retry_context}
|
||||
|
||||
ACTION GENERATION PRINCIPLES:
|
||||
- Create meaningful actions per task step
|
||||
- Use comprehensive AI prompts for document processing
|
||||
- Focus on business outcomes, not technical operations
|
||||
- Combine related operations into single actions when possible
|
||||
- Use the task's AI prompt if provided, or create a comprehensive one
|
||||
- Each action should produce meaningful, usable outputs
|
||||
- For document extraction, ensure prompts are specific and detailed
|
||||
- Include validation steps in extraction prompts
|
||||
- If this is a retry, learn from previous failures and improve the approach
|
||||
- Address specific issues mentioned in previous review feedback
|
||||
- When specifying expectedDocumentFormats, ensure AI prompts explicitly request pure data without markdown formatting
|
||||
- Generate user-friendly messages for each action in the user's language ({user_language})
|
||||
|
||||
USER LANGUAGE: {user_language} - All user messages must be generated in this language.
|
||||
|
||||
DOCUMENT ROUTING GUIDANCE:
|
||||
- Each action should produce documents with a clear resultLabel for routing
|
||||
- Use consistent naming: "round{current_round}_task{{task_id}}_action{{action_number}}_{{descriptive_label}}"
|
||||
- Ensure document flow: Action A produces documents that Action B can consume
|
||||
- Document labels should be descriptive of content, not just "results" or "output"
|
||||
- Consider what subsequent actions will need and structure outputs accordingly
|
||||
|
||||
INSTRUCTIONS:
|
||||
- Generate actions to accomplish this task step using available documents, connections, and previous results
|
||||
- Use docItem for single documents and docList for groups of documents as shown in AVAILABLE DOCUMENTS
|
||||
- If AVAILABLE DOCUMENTS shows "NO DOCUMENTS AVAILABLE", you cannot create document extraction actions. Instead, create actions that generate new content or inform the user that documents are needed.
|
||||
- Always pass documentList as a LIST of references (docItem and/or docList) - this list CANNOT be empty for document extraction actions
|
||||
- For referencing documents from previous actions, use the format "round{{round_number}}_task{{task_number}}_action{{action_number}}_{{context}}"
|
||||
- For resultLabel, use the format: "round{current_round}_task{{task_id}}_action{{action_number}}_{{short_label}}" where:
|
||||
- {{round_number}} = the current round number ({current_round})
|
||||
- {{task_id}} = the current task's id ({current_task})
|
||||
- {{action_number}} = the sequence number of the action within the task (e.g., 1, 2, 3)
|
||||
- {{short_label}} = a short, descriptive label for the output (e.g., "AnalysisResults")
|
||||
Example: "round{current_round}_task{current_task}_action1_AnalysisResults"
|
||||
- If this is a retry, ensure the new actions address the specific issues from previous attempts
|
||||
- Follow the JSON structure below. All fields are required.
|
||||
|
||||
REQUIRED JSON STRUCTURE:
|
||||
{{
|
||||
"actions": [
|
||||
{{
|
||||
"method": "method_name", // Use only the method name (e.g., "document")
|
||||
"action": "action_name", // Use only the action name (e.g., "extract")
|
||||
"parameters": {{
|
||||
"documentList": ["docItem:doc_abc:round{current_round}_task{current_task}_action1_AnalysisResults", "round{current_round}_task{current_task}_action1_input"],
|
||||
"aiPrompt": "Comprehensive AI prompt describing what to accomplish"
|
||||
}},
|
||||
"resultLabel": "round{current_round}_task{current_task}_action2_AnalysisResults",
|
||||
"expectedDocumentFormats": [ // OPTIONAL: Specify expected document formats when needed
|
||||
{{
|
||||
"extension": ".txt",
|
||||
"mimeType": "text/plain",
|
||||
"description": "Structured data output"
|
||||
}}
|
||||
],
|
||||
"description": "What this action accomplishes (business outcome)",
|
||||
"userMessage": "User-friendly message explaining what this action will do in the user's language"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
FIELD REQUIREMENTS:
|
||||
- "method": Must be from AVAILABLE METHODS
|
||||
- "action": Must be valid for the method
|
||||
- "parameters": Method-specific, must include documentList as a list if required by the signature
|
||||
- "resultLabel": Must follow the format above (e.g., "round{current_round}_task{current_task}_action3_AnalysisResults")
|
||||
- "expectedDocumentFormats": OPTIONAL - Only specify when you need to control output format
|
||||
- Use when you need specific file types (e.g., CSV for data, JSON for structured output)
|
||||
- Omit when format is flexible (e.g., folder queries with mixed file types)
|
||||
- Each format should specify: extension, mimeType, description
|
||||
- When using expectedDocumentFormats, ensure the aiPrompt explicitly requests pure data without markdown formatting
|
||||
- "description": Clear summary of the business outcome
|
||||
- "userMessage": User-friendly message explaining what the action will accomplish in the user's language
|
||||
|
||||
EXAMPLES OF GOOD ACTIONS:
|
||||
|
||||
1. Document analysis with specific output format and user message:
|
||||
{{
|
||||
"method": "document",
|
||||
"action": "extract",
|
||||
"parameters": {{
|
||||
"documentList": ["docItem:doc_57520394-6b6d-41c2-b641-bab3fc6d7f4b:candidate_profile.txt"],
|
||||
"aiPrompt": "Extract and analyze the candidate's qualifications, experience, skills, and suitability for the product designer position. Identify key strengths, relevant experience, technical skills, and any areas of concern. Provide a comprehensive assessment that can be used for evaluation."
|
||||
}},
|
||||
"resultLabel": "round{current_round}_task{current_task}_action2_candidate_analysis",
|
||||
"expectedDocumentFormats": [
|
||||
{{
|
||||
"extension": ".json",
|
||||
"mimeType": "application/json",
|
||||
"description": "Structured candidate analysis data"
|
||||
}}
|
||||
],
|
||||
"description": "Comprehensive analysis of candidate profile for evaluation",
|
||||
"userMessage": "Ich analysiere das Kandidatenprofil und extrahiere alle wichtigen Informationen für die Bewertung."
|
||||
}}
|
||||
|
||||
2. Multi-document processing with user message:
|
||||
{{
|
||||
"method": "document",
|
||||
"action": "extract",
|
||||
"parameters": {{
|
||||
"documentList": ["docList:msg_456:candidate_analysis_results"],
|
||||
"aiPrompt": "Compare all candidate profiles and create an evaluation matrix. Rate each candidate on technical skills, experience level, cultural fit, portfolio quality, and communication skills. Provide clear rankings and recommendations for the product designer position."
|
||||
}},
|
||||
"resultLabel": "round{current_round}_task{current_task}_action5_evaluation_matrix",
|
||||
"description": "Create comprehensive evaluation matrix comparing all candidates",
|
||||
"userMessage": "Ich vergleiche alle Kandidatenprofile und erstelle eine umfassende Bewertungsmatrix mit klaren Empfehlungen."
|
||||
}}
|
||||
|
||||
3. Data extraction with specific CSV format and user message:
|
||||
{{
|
||||
"method": "document",
|
||||
"action": "extract",
|
||||
"parameters": {{
|
||||
"documentList": ["docItem:doc_abc:table_data.pdf"],
|
||||
"aiPrompt": "Extract all table data and convert to structured CSV format with proper headers and data types. IMPORTANT: Deliver pure CSV data without any markdown formatting, code blocks, or additional text. Output only the CSV content with proper headers and data rows."
|
||||
}},
|
||||
"resultLabel": "round{current_round}_task{current_task}_action2_structured_data",
|
||||
"expectedDocumentFormats": [
|
||||
{{
|
||||
"extension": ".csv",
|
||||
"mimeType": "text/csv",
|
||||
"description": "Structured table data in CSV format"
|
||||
}}
|
||||
],
|
||||
"description": "Extract and structure table data for analysis",
|
||||
"userMessage": "Ich extrahiere alle Tabellendaten und konvertiere sie in ein strukturiertes CSV-Format für die weitere Analyse."
|
||||
}}
|
||||
|
||||
4. Comprehensive summary report with user message:
|
||||
{{
|
||||
"method": "document",
|
||||
"action": "generateReport",
|
||||
"parameters": {{
|
||||
"documentList": ["docList:msg_456:candidate_analysis_results"],
|
||||
"title": "Comprehensive Candidate Evaluation Report"
|
||||
}},
|
||||
"resultLabel": "round{current_round}_task{current_task}_action6_summary_report",
|
||||
"description": "Generate a comprehensive, professional HTML report consolidating all candidate analyses and findings",
|
||||
"userMessage": "Ich erstelle einen umfassenden, professionellen Bericht, der alle Kandidatenanalysen und Erkenntnisse zusammenfasst."
|
||||
}}
|
||||
|
||||
5. Correct chaining of actions within a task:
|
||||
{{
|
||||
"actions": [
|
||||
{{
|
||||
"method": "document",
|
||||
"action": "extract",
|
||||
"parameters": {{
|
||||
"documentList": ["docItem:doc_abc:round{current_round}_task{current_task}_action1_file1.txt"],
|
||||
"aiPrompt": "Extract data from file1."
|
||||
}},
|
||||
"resultLabel": "round{current_round}_task{current_task}_action1_extracted_data",
|
||||
"description": "Extract data from file1.",
|
||||
"userMessage": "Ich extrahiere die Daten aus der Datei."
|
||||
}},
|
||||
{{
|
||||
"method": "document",
|
||||
"action": "generateReport",
|
||||
"parameters": {{
|
||||
"documentList": ["round{current_round}_task{current_task}_action1_extracted_data"],
|
||||
"title": "Report"
|
||||
}},
|
||||
"resultLabel": "round{current_round}_task{current_task}_action2_report",
|
||||
"description": "Generate report from extracted data.",
|
||||
"userMessage": "Ich erstelle einen Bericht basierend auf den extrahierten Daten."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
6. When no documents are available (NO DOCUMENTS AVAILABLE scenario):
|
||||
{{
|
||||
"method": "document",
|
||||
"action": "generateReport",
|
||||
"parameters": {{
|
||||
"documentList": [],
|
||||
"title": "Workflow Status Report"
|
||||
}},
|
||||
"resultLabel": "round{current_round}_task{current_task}_action1_status_report",
|
||||
"description": "Generate a status report informing the user that no documents are available for processing and requesting document upload or alternative input.",
|
||||
"userMessage": "Ich erstelle einen Statusbericht, der Sie darüber informiert, dass keine Dokumente zur Verarbeitung verfügbar sind und um Dokumente oder alternative Eingaben bittet."
|
||||
}}
|
||||
|
||||
IMPORTANT NOTES:
|
||||
- Respond with ONLY the JSON object. Do not include any explanatory text.
|
||||
- Before creating any document extraction action, verify that AVAILABLE DOCUMENTS contains actual document references.
|
||||
- If AVAILABLE DOCUMENTS shows "NO DOCUMENTS AVAILABLE", use example 6 above to create a status report action instead of document extraction.
|
||||
- Always include a user-friendly userMessage for each action in the user's language ({user_language}).
|
||||
- The examples above show German user messages as reference - adapt the language to match the USER LANGUAGE specified above."""
|
||||
|
||||
logging.debug(f"[ACTION PLAN PROMPT] Enhanced Document Context:\n{available_documents_str}\nUser Connections Section:\n{available_connections_str}\nAvailable Methods (detailed):\n{available_methods_str}")
|
||||
|
||||
return prompt
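# A small helper sketch (the function name is illustrative, not part of this module) for the
# resultLabel convention the prompt above asks the model to follow:
def buildResultLabel(roundNumber: int, taskId: int, actionNumber: int, shortLabel: str) -> str:
    """Compose a label such as 'round1_task2_action3_AnalysisResults'."""
    return f"round{roundNumber}_task{taskId}_action{actionNumber}_{shortLabel}"

# buildResultLabel(1, 2, 3, "AnalysisResults") == "round1_task2_action3_AnalysisResults"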
|
||||
|
||||
def createResultReviewPrompt(context: ReviewContext, service) -> str:
|
||||
"""Create enhanced prompt for result review with user-friendly messages and document context"""
|
||||
# Build comprehensive action and result summary
|
||||
action_summary = ""
|
||||
for i, action in enumerate(context.task_actions or []):
|
||||
action_summary += f"\nACTION {i+1}: {action.execMethod}.{action.execAction}\n"
|
||||
action_summary += f" Status: {action.status}\n"
|
||||
if action.error:
|
||||
action_summary += f" Error: {action.error}\n"
|
||||
if action.resultDocuments:
|
||||
action_summary += f" Documents: {len(action.resultDocuments)} document(s)\n"
|
||||
for doc in action.resultDocuments:
|
||||
# Use Pydantic model properties directly
|
||||
fileName = doc.fileName
|
||||
fileSize = doc.fileSize
|
||||
mimeType = doc.mimeType
|
||||
|
||||
action_summary += f" - {fileName} ({fileSize} bytes, {mimeType})\n"
|
||||
else:
|
||||
action_summary += f" Documents: None\n"
|
||||
|
||||
# Build result summary with SIMPLE DOCUMENT VALIDATION
|
||||
result_summary = ""
|
||||
document_validation_summary = ""
|
||||
document_access_warnings = []
|
||||
|
||||
if context.action_results:
|
||||
for i, result in enumerate(context.action_results):
|
||||
result_summary += f"\nRESULT {i+1}:\n"
|
||||
result_summary += f" Success: {result.success}\n"
|
||||
if result.error:
|
||||
result_summary += f" Error: {result.error}\n"
|
||||
|
||||
if result.documents:
|
||||
result_summary += f" Documents: {len(result.documents)} document(s)\n"
|
||||
for doc in result.documents:
|
||||
# Use correct ActionDocument attributes
|
||||
doc_name = getattr(doc, 'documentName', 'Unknown')
|
||||
doc_mime = getattr(doc, 'mimeType', 'Unknown')
|
||||
doc_data = getattr(doc, 'documentData', None)
|
||||
|
||||
result_summary += f" - {doc_name} ({doc_mime})\n"
|
||||
|
||||
# SIMPLE VALIDATION: Check if documents exist and have basic properties
|
||||
validation_status = "✅ Valid"
|
||||
if not doc_name or str(doc_name).strip() == "":
|
||||
validation_status = "❌ Missing document name"
|
||||
elif not doc_mime or str(doc_mime).strip() == "":
|
||||
validation_status = "❌ Missing MIME type"
|
||||
elif doc_data is None:
|
||||
validation_status = "⚠️ No document data"
|
||||
elif hasattr(doc_data, '__len__') and len(doc_data) == 0:
|
||||
validation_status = "⚠️ Empty document data"
|
||||
|
||||
document_validation_summary += f" - {doc_name}: {validation_status}\n"
|
||||
else:
|
||||
result_summary += f" Documents: None\n"
|
||||
document_validation_summary += f" - No documents produced\n"
|
||||
|
||||
# Get enhanced document context using the new method
|
||||
document_context = service.getEnhancedDocumentContext()
|
||||
|
||||
# Get user language from service
|
||||
user_language = service.user.language if service and service.user else 'en'
|
||||
|
||||
# Build warnings section (only for critical issues)
|
||||
warnings_section = ""
|
||||
if document_access_warnings:
|
||||
warnings_section = f"""
|
||||
⚠️ DOCUMENT VALIDATION ISSUES:
|
||||
{chr(10).join(f"- {warning}" for warning in document_access_warnings)}
|
||||
"""
|
||||
|
||||
prompt = f"""
|
||||
You are a result review AI that evaluates task execution results and provides feedback with user-friendly messages.
|
||||
|
||||
TASK OBJECTIVE: {context.task_step.objective if context.task_step else 'No task objective specified'}
|
||||
SUCCESS CRITERIA: {', '.join(context.task_step.success_criteria) if context.task_step and context.task_step.success_criteria else 'No success criteria specified'}
|
||||
|
||||
EXECUTION SUMMARY:
|
||||
{action_summary}
|
||||
|
||||
RESULT SUMMARY:
|
||||
{result_summary}
|
||||
|
||||
{warnings_section}
|
||||
|
||||
DOCUMENT VALIDATION SUMMARY:
|
||||
{document_validation_summary if document_validation_summary else "No documents to validate"}
|
||||
|
||||
DOCUMENT CONTEXT (Available Documents):
|
||||
{document_context}
|
||||
|
||||
PREVIOUS RESULTS: {', '.join(context.previous_results) if context.previous_results else 'None'}
|
||||
|
||||
REVIEW INSTRUCTIONS:
|
||||
1. Evaluate if the task step was completed successfully
|
||||
2. Check if all success criteria were met
|
||||
3. Assess the quality and completeness of outputs
|
||||
4. Identify any missing or incomplete results
|
||||
5. Provide specific improvement suggestions
|
||||
6. Generate user-friendly messages explaining the results
|
||||
7. Return a JSON object with the exact structure shown below
|
||||
|
||||
DOCUMENT VALIDATION FOCUS:
|
||||
- Check that the agreed result document labels are correct (match the expected format)
|
||||
- Verify that documents are actually present and have basic properties
|
||||
- Do NOT attempt to analyze document content deeply
|
||||
- Focus on document existence and basic metadata validation
|
||||
|
||||
REQUIRED JSON STRUCTURE:
|
||||
{{
|
||||
"status": "success|retry|failed",
|
||||
"reason": "Brief explanation of the status",
|
||||
"improvements": ["improvement1", "improvement2"],
|
||||
"quality_score": 8, // 1-10 scale
|
||||
"missing_outputs": ["missing_output1", "missing_output2"],
|
||||
"met_criteria": ["criteria1", "criteria2"],
|
||||
"unmet_criteria": ["criteria3", "criteria4"],
|
||||
"confidence": 0.85, // 0.0-1.0 confidence level in this assessment
|
||||
"userMessage": "User-friendly message explaining the review results in the user's language"
|
||||
}}
|
||||
|
||||
FIELD REQUIREMENTS:
|
||||
- "status": Overall task completion status
|
||||
- "success": All criteria met, high-quality outputs
|
||||
- "retry": Some criteria met, outputs need improvement and retry
|
||||
- "failed": Most criteria unmet, significant issues
|
||||
- "reason": Clear explanation of why this status was assigned
|
||||
- "improvements": List of specific, actionable improvements
|
||||
- "quality_score": 1-10 rating of output quality
|
||||
- "missing_outputs": List of expected outputs that were not produced
|
||||
- "met_criteria": List of success criteria that were fully met
|
||||
- "unmet_criteria": List of success criteria that were not met
|
||||
- "confidence": 0.0-1.0 confidence level in this assessment
|
||||
- "userMessage": User-friendly explanation of results in the user's language
|
||||
|
||||
EXAMPLES OF GOOD IMPROVEMENTS:
|
||||
- "Increase AI prompt specificity for better data extraction"
|
||||
- "Add validation steps to ensure output completeness"
|
||||
- "Improve error handling for failed document processing"
|
||||
- "Enhance document format specifications for better output quality"
|
||||
|
||||
EXAMPLES OF GOOD MISSING OUTPUTS:
|
||||
- "Structured analysis report in JSON format"
|
||||
- "Comparison matrix of candidate profiles"
|
||||
- "Data validation summary with quality metrics"
|
||||
- "Professional business communication document"
|
||||
|
||||
QUALITY SCORE GUIDELINES:
|
||||
- 9-10: Exceptional quality, exceeds expectations
|
||||
- 7-8: Good quality, meets all requirements
|
||||
- 5-6: Acceptable quality, minor issues
|
||||
- 3-4: Poor quality, significant issues
|
||||
- 1-2: Very poor quality, major problems
|
||||
|
||||
USER LANGUAGE: {user_language} - All user messages must be generated in this language.
|
||||
|
||||
NOTE: Respond with ONLY the JSON object. Do not include any explanatory text."""
|
||||
|
||||
return prompt
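# Minimal sketch (illustrative helper, not part of this module) of how the review JSON
# described above could drive the retry decision downstream:
import json

def parseReviewDecision(reviewText: str) -> tuple[bool, list]:
    """Return (retry_needed, improvements) extracted from the review JSON string."""
    review = json.loads(reviewText)
    return review.get("status") == "retry", review.get("improvements", [])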
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
import logging
|
||||
from typing import Dict, Any, List
|
||||
from modules.interfaces.interfaceAppModel import User
|
||||
from modules.interfaces.interfaceChatModel import ChatWorkflow, UserInputRequest, TaskStep, TaskAction, ActionResult, ReviewResult, TaskPlan, WorkflowResult, TaskContext
|
||||
from modules.chat.serviceCenter import ServiceCenter
|
||||
from modules.interfaces.interfaceChatObjects import ChatObjects
|
||||
from .handling.handlingTasks import HandlingTasks, WorkflowStoppedException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ===== STATE MANAGEMENT AND VALIDATION CLASSES =====
|
||||
|
||||
class ChatManager:
|
||||
"""Chat manager with improved AI integration and method handling"""
|
||||
|
||||
def __init__(self, currentUser: User, chatInterface: ChatObjects):
|
||||
self.currentUser = currentUser
|
||||
self.chatInterface = chatInterface
|
||||
self.service: ServiceCenter = None
|
||||
self.workflow: ChatWorkflow = None
|
||||
self.handlingTasks: HandlingTasks = None
|
||||
|
||||
async def initialize(self, workflow: ChatWorkflow) -> None:
|
||||
"""Initialize chat manager with workflow"""
|
||||
self.workflow = workflow
|
||||
self.service = ServiceCenter(self.currentUser, self.workflow)
|
||||
self.handlingTasks = HandlingTasks(self.chatInterface, self.service, self.workflow)
|
||||
|
||||
async def executeUnifiedWorkflow(self, userInput: UserInputRequest, workflow: ChatWorkflow) -> WorkflowResult:
|
||||
"""Unified Workflow Execution"""
|
||||
try:
|
||||
logger.info(f"Starting unified workflow execution for workflow {workflow.id}")
|
||||
logger.debug(f"User request: {userInput.prompt}")
|
||||
|
||||
# Phase 1: High-Level Task Planning
|
||||
logger.info("Phase 1: Generating task plan")
|
||||
task_plan = await self.handlingTasks.generateTaskPlan(userInput.prompt, workflow)
|
||||
if not task_plan or not task_plan.tasks:
|
||||
raise Exception("No tasks generated in task plan.")
|
||||
|
||||
# Phase 2-5: For each task, execute and get results
|
||||
total_tasks = len(task_plan.tasks)
|
||||
logger.info(f"Phase 2: Executing {total_tasks} tasks")
|
||||
all_task_results = []
|
||||
previous_results = []
|
||||
for idx, task_step in enumerate(task_plan.tasks):
|
||||
# Pass task index to executeTask method
|
||||
current_task_index = idx + 1
|
||||
|
||||
logger.info(f"Task {idx+1}/{total_tasks}: {task_step.objective}")
|
||||
|
||||
# Create proper context object for this task
|
||||
task_context = TaskContext(
|
||||
task_step=task_step,
|
||||
workflow=workflow,
|
||||
workflow_id=workflow.id,
|
||||
available_documents=self.service.getAvailableDocuments(workflow),
|
||||
available_connections=self.service.getConnectionReferenceList(),
|
||||
previous_results=previous_results,
|
||||
previous_handover=None,
|
||||
improvements=[],
|
||||
retry_count=0,
|
||||
previous_action_results=[],
|
||||
previous_review_result=None,
|
||||
is_regeneration=False,
|
||||
failure_patterns=[],
|
||||
failed_actions=[],
|
||||
successful_actions=[],
|
||||
criteria_progress={
|
||||
'met_criteria': set(),
|
||||
'unmet_criteria': set(),
|
||||
'attempt_history': []
|
||||
}
|
||||
)
|
||||
|
||||
# Execute task (this handles action generation, execution, and review internally)
|
||||
task_result = await self.handlingTasks.executeTask(task_step, workflow, task_context, current_task_index, total_tasks)
|
||||
# Handover
|
||||
handover_data = await self.handlingTasks.prepareTaskHandover(task_step, [], task_result, workflow)
|
||||
# Collect results
|
||||
all_task_results.append({
|
||||
'task_step': task_step,
|
||||
'task_result': task_result,
|
||||
'handover_data': handover_data
|
||||
})
|
||||
# Update previous results for next task
|
||||
if task_result.success and task_result.feedback:
|
||||
previous_results.append(task_result.feedback)
|
||||
|
||||
# Final workflow result
|
||||
workflow_result = WorkflowResult(
|
||||
status="completed",
|
||||
completed_tasks=len(all_task_results),
|
||||
total_tasks=len(task_plan.tasks),
|
||||
execution_time=0.0, # TODO: Calculate actual execution time
|
||||
final_results_count=len(all_task_results)
|
||||
)
|
||||
logger.info(f"Unified workflow execution completed successfully for workflow {workflow.id}")
|
||||
return workflow_result
|
||||
except WorkflowStoppedException:
|
||||
logger.info(f"Workflow {workflow.id} was stopped by user")
|
||||
return WorkflowResult(
|
||||
status="stopped",
|
||||
completed_tasks=0,
|
||||
total_tasks=0,
|
||||
execution_time=0.0,
|
||||
final_results_count=0
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in executeUnifiedWorkflow: {str(e)}")
|
||||
return WorkflowResult(
|
||||
status="failed",
|
||||
completed_tasks=0,
|
||||
total_tasks=0,
|
||||
execution_time=0.0,
|
||||
final_results_count=0,
|
||||
error=str(e)
|
||||
)
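# Illustrative usage of this manager (object names are placeholders, not real fixtures):
#
#     manager = ChatManager(currentUser, chatInterface)
#     await manager.initialize(workflow)
#     result = await manager.executeUnifiedWorkflow(userInput, workflow)
#     logger.info(f"{result.status}: {result.completed_tasks}/{result.total_tasks} tasks")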
|
||||
File diff suppressed because it is too large
|
|
@ -1,185 +0,0 @@
|
|||
import logging
|
||||
import httpx
|
||||
from typing import Dict, Any, List, Union
|
||||
from fastapi import HTTPException
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def loadConfigData():
|
||||
"""Load configuration data for Anthropic connector"""
|
||||
return {
|
||||
"apiKey": APP_CONFIG.get('Connector_AiAnthropic_API_SECRET'),
|
||||
"apiUrl": APP_CONFIG.get('Connector_AiAnthropic_API_URL'),
|
||||
"modelName": APP_CONFIG.get('Connector_AiAnthropic_MODEL_NAME'),
|
||||
"temperature": float(APP_CONFIG.get('Connector_AiAnthropic_TEMPERATURE')),
|
||||
"maxTokens": int(APP_CONFIG.get('Connector_AiAnthropic_MAX_TOKENS'))
|
||||
}
|
||||
|
||||
class AiAnthropic:
|
||||
"""Connector for communication with the Anthropic API."""
|
||||
|
||||
def __init__(self):
|
||||
# Load configuration
|
||||
self.config = loadConfigData()
|
||||
self.apiKey = self.config["apiKey"]
|
||||
self.apiUrl = self.config["apiUrl"]
|
||||
self.modelName = self.config["modelName"]
|
||||
|
||||
# HttpClient for API calls
|
||||
self.httpClient = httpx.AsyncClient(
|
||||
timeout=120.0, # Longer timeout for complex requests
|
||||
headers={
|
||||
"x-api-key": self.apiKey,
|
||||
"anthropic-version": "2023-06-01", # Anthropic API Version
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"Anthropic Connector initialized with model: {self.modelName}")
|
||||
|
||||
async def callAiBasic(self, messages: List[Dict[str, Any]], temperature: float = None, maxTokens: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Calls the Anthropic API with the given messages.
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format (role, content)
|
||||
temperature: Temperature for response generation (0.0-1.0)
|
||||
maxTokens: Maximum number of tokens in the response
|
||||
|
||||
Returns:
|
||||
The response in OpenAI format
|
||||
|
||||
Raises:
|
||||
HTTPException: For errors in API communication
|
||||
"""
|
||||
try:
|
||||
# Use parameters from configuration if none were overridden
|
||||
if temperature is None:
|
||||
temperature = self.config.get("temperature", 0.2)
|
||||
|
||||
if maxTokens is None:
|
||||
maxTokens = self.config.get("maxTokens", 2000)
|
||||
|
||||
# Create Anthropic API payload
|
||||
payload = {
|
||||
"model": self.modelName,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": maxTokens
|
||||
}
|
||||
|
||||
response = await self.httpClient.post(
|
||||
self.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_detail = f"Anthropic API error: {response.status_code} - {response.text}"
|
||||
logger.error(error_detail)
|
||||
|
||||
# Provide more specific error messages based on status code
|
||||
if response.status_code == 529:
|
||||
error_message = "Anthropic API is currently overloaded. Please try again in a few minutes."
|
||||
elif response.status_code == 429:
|
||||
error_message = "Rate limit exceeded. Please wait before making another request."
|
||||
elif response.status_code == 401:
|
||||
error_message = "Invalid API key. Please check your Anthropic API configuration."
|
||||
elif response.status_code == 400:
|
||||
error_message = f"Invalid request to Anthropic API: {response.text}"
|
||||
else:
|
||||
error_message = f"Anthropic API error ({response.status_code}): {response.text}"
|
||||
|
||||
raise HTTPException(status_code=500, detail=error_message)
|
||||
|
||||
# Parse response
|
||||
anthropicResponse = response.json()
|
||||
|
||||
# Extract content from response
|
||||
content = ""
|
||||
if "content" in anthropicResponse:
|
||||
if isinstance(anthropicResponse["content"], list):
|
||||
# Content is a list of parts (in newer API versions)
|
||||
for part in anthropicResponse["content"]:
|
||||
if part.get("type") == "text":
|
||||
content += part.get("text", "")
|
||||
else:
|
||||
# Direct content as string (in older API versions)
|
||||
content = anthropicResponse["content"]
|
||||
|
||||
# Return in OpenAI format
|
||||
return {
|
||||
"id": anthropicResponse.get("id", ""),
|
||||
"object": "chat.completion",
|
||||
"created": anthropicResponse.get("created", 0),
|
||||
"model": anthropicResponse.get("model", self.modelName),
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
},
|
||||
"index": 0,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling Anthropic API: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error calling Anthropic API: {str(e)}")
|
||||
|
||||
async def callAiImage(self, prompt: str, imageData: Union[str, bytes], mimeType: str = None) -> str:
|
||||
"""
|
||||
Analyzes an image using Anthropic's vision capabilities.
|
||||
|
||||
Args:
|
||||
imageData: Either a file path (str) or image data (bytes)
|
||||
mimeType: The MIME type of the image (optional, only for binary data)
|
||||
prompt: The prompt for analysis
|
||||
|
||||
Returns:
|
||||
The analysis response as text
|
||||
"""
|
||||
try:
|
||||
# Distinguish between file path and binary data
|
||||
if isinstance(imageData, str):
|
||||
# It's a file path - import filehandling only when needed
|
||||
from modules import agentserviceFilemanager as fileHandler
|
||||
base64Data, autoMimeType = fileHandler.encodeFileToBase64(imageData)
|
||||
mimeType = mimeType or autoMimeType
|
||||
else:
|
||||
# It's binary data
|
||||
import base64
|
||||
base64Data = base64.b64encode(imageData).decode('utf-8')
|
||||
# MIME type must be specified for binary data
|
||||
if not mimeType:
|
||||
# Fallback to generic image type
|
||||
mimeType = "image/png"
|
||||
|
||||
# Prepare the payload for the Vision API
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mimeType,
|
||||
"data": base64Data
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Use the existing callAiBasic method, which already returns an OpenAI-style response
|
||||
response = await self.callAiBasic(messages)
|
||||
|
||||
# Extract and return content
|
||||
return response["choices"][0]["message"]["content"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during image analysis: {str(e)}", exc_info=True)
|
||||
return f"[Error during image analysis: {str(e)}]"
|
||||
|
|
@ -1,191 +0,0 @@
|
|||
import logging
|
||||
import base64
|
||||
import httpx
|
||||
from typing import Dict, Any, List, Union
|
||||
from fastapi import HTTPException
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ContextLengthExceededException(Exception):
|
||||
"""Exception raised when the context length exceeds the model's limit"""
|
||||
pass
|
||||
|
||||
def loadConfigData():
|
||||
"""Load configuration data for OpenAI connector"""
|
||||
return {
|
||||
"apiKey": APP_CONFIG.get('Connector_AiOpenai_API_SECRET'),
|
||||
"apiUrl": APP_CONFIG.get('Connector_AiOpenai_API_URL'),
|
||||
"modelName": APP_CONFIG.get('Connector_AiOpenai_MODEL_NAME'),
|
||||
"temperature": float(APP_CONFIG.get('Connector_AiOpenai_TEMPERATURE')),
|
||||
"maxTokens": int(APP_CONFIG.get('Connector_AiOpenai_MAX_TOKENS'))
|
||||
}
|
||||
|
||||
class AiOpenai:
|
||||
"""Connector for communication with the OpenAI API."""
|
||||
|
||||
def __init__(self):
|
||||
# Load configuration
|
||||
self.config = loadConfigData()
|
||||
self.apiKey = self.config["apiKey"]
|
||||
self.apiUrl = self.config["apiUrl"]
|
||||
self.modelName = self.config["modelName"]
|
||||
|
||||
# HttpClient for API calls
|
||||
self.httpClient = httpx.AsyncClient(
|
||||
timeout=120.0, # Longer timeout for complex requests
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.apiKey}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
logger.info(f"OpenAI Connector initialized with model: {self.modelName}")
|
||||
|
||||
async def callAiBasic(self, messages: List[Dict[str, Any]], temperature: float = None, maxTokens: int = None) -> str:
|
||||
"""
|
||||
Calls the OpenAI API with the given messages.
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format (role, content)
|
||||
temperature: Temperature for response generation (0.0-1.0)
|
||||
maxTokens: Maximum number of tokens in the response
|
||||
|
||||
Returns:
|
||||
The response from the OpenAI API
|
||||
|
||||
Raises:
|
||||
HTTPException: For errors in API communication
|
||||
"""
|
||||
try:
|
||||
# Use parameters from configuration if none were overridden
|
||||
if temperature is None:
|
||||
temperature = self.config.get("temperature", 0.2)
|
||||
|
||||
if maxTokens is None:
|
||||
maxTokens = self.config.get("maxTokens", 2000)
|
||||
|
||||
payload = {
|
||||
"model": self.modelName,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": maxTokens
|
||||
}
|
||||
|
||||
response = await self.httpClient.post(
|
||||
self.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
|
||||
|
||||
# Check for context length exceeded error
|
||||
if response.status_code == 400:
|
||||
try:
|
||||
error_data = response.json()
|
||||
if (error_data.get("error", {}).get("code") == "context_length_exceeded" or
|
||||
"context length" in error_data.get("error", {}).get("message", "").lower()):
|
||||
# Raise a specific exception for context length issues
|
||||
raise ContextLengthExceededException(
|
||||
f"Context length exceeded: {error_data.get('error', {}).get('message', 'Unknown error')}"
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
pass # If we can't parse the error, fall through to generic error
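# For reference, the error body this branch looks for has roughly this shape (illustrative):
# {"error": {"message": "This model's maximum context length is ...",
#            "type": "invalid_request_error", "code": "context_length_exceeded"}}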
|
||||
|
||||
raise HTTPException(status_code=500, detail="Error communicating with OpenAI API")
|
||||
|
||||
responseJson = response.json()
|
||||
content = responseJson["choices"][0]["message"]["content"]
|
||||
return content
|
||||
|
||||
except ContextLengthExceededException:
|
||||
# Re-raise context length exceptions without wrapping
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling OpenAI API: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error calling OpenAI API: {str(e)}")
|
||||
|
||||
async def callAiImage(self, prompt: str, imageData: Union[str, bytes], mimeType: str = None) -> str:
|
||||
"""
|
||||
Analyzes an image with the OpenAI Vision API.
|
||||
|
||||
Args:
|
||||
imageData: base64encoded data
|
||||
mimeType: The MIME type of the image (optional, only for binary data)
|
||||
prompt: The prompt for analysis
|
||||
|
||||
Returns:
|
||||
The response from the OpenAI Vision API as text
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Starting image analysis with query '{prompt}' for size {len(imageData)}B...")
|
||||
|
||||
# Ensure imageData is a string (base64 encoded)
|
||||
if not isinstance(imageData, str):
|
||||
raise ValueError("imageData must be a string (base64 encoded)")
|
||||
|
||||
# Fix base64 padding if needed
|
||||
padding_needed = len(imageData) % 4
|
||||
if padding_needed:
|
||||
imageData += '=' * (4 - padding_needed)
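# e.g. a base64 string of length 10 (10 % 4 == 2) gets "==" appended to reach a multiple of 4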
|
||||
|
||||
# Use default MIME type if not provided
|
||||
if not mimeType:
|
||||
mimeType = "image/jpeg"
|
||||
|
||||
logger.debug(f"Using MIME type: {mimeType}")
|
||||
logger.debug(f"Base64 data length: {len(imageData)} characters")
|
||||
|
||||
# Create the data URL format as required by OpenAI Vision API
|
||||
data_url = f"data:{mimeType};base64,{imageData}"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": data_url
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Use a vision-capable model for image analysis
|
||||
# Override the model for vision tasks
|
||||
visionModel = "gpt-4o" # or "gpt-4-vision-preview" depending on availability
|
||||
|
||||
# Use parameters from configuration
|
||||
temperature = self.config.get("temperature", 0.2)
|
||||
maxTokens = self.config.get("maxTokens", 2000)
|
||||
|
||||
payload = {
|
||||
"model": visionModel,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": maxTokens
|
||||
}
|
||||
|
||||
response = await self.httpClient.post(
|
||||
self.apiUrl,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OpenAI API error: {response.status_code} - {response.text}")
|
||||
raise HTTPException(status_code=500, detail="Error communicating with OpenAI API")
|
||||
|
||||
responseJson = response.json()
|
||||
content = responseJson["choices"][0]["message"]["content"]
|
||||
return content
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during image analysis: {str(e)}", exc_info=True)
|
||||
return f"[Error during image analysis: {str(e)}]"
|
||||
|
|
@ -1,15 +1,13 @@
|
|||
import json
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional, Union, TypedDict
|
||||
from typing import List, Dict, Any, Optional, TypedDict
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
import threading
|
||||
import time
|
||||
|
||||
from modules.shared.attributeUtils import to_dict
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -234,7 +232,7 @@ class DatabaseConnector:
|
|||
raise ValueError(f"Record ID mismatch: file name ID ({recordId}) does not match record ID ({record['id']})")
|
||||
|
||||
# Add metadata
|
||||
currentTime = get_utc_timestamp()
|
||||
currentTime = getUtcTimestamp()
|
||||
if "_createdAt" not in record:
|
||||
record["_createdAt"] = currentTime
|
||||
record["_createdBy"] = self.userId
|
||||
|
|
@ -567,7 +565,7 @@ class DatabaseConnector:
|
|||
|
||||
# If record is a Pydantic model, convert to dict
|
||||
if isinstance(record, BaseModel):
|
||||
record = to_dict(record)
|
||||
record = record.model_dump()
|
||||
|
||||
# Save record
|
||||
self._saveRecord(table, record["id"], record)
|
||||
|
|
@ -582,7 +580,7 @@ class DatabaseConnector:
|
|||
|
||||
# If record is a Pydantic model, convert to dict
|
||||
if isinstance(record, BaseModel):
|
||||
record = to_dict(record)
|
||||
record = record.model_dump()
|
||||
|
||||
# CRITICAL: Ensure we never modify the ID
|
||||
if "id" in record and str(record["id"]) != recordId:
|
||||
|
|
|
|||
File diff suppressed because it is too large
141
modules/connectors/connectorTicketsClickup.py
Normal file
141
modules/connectors/connectorTicketsClickup.py
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
"""ClickUp connector for CRUD operations (compatible with TicketInterface).
|
||||
|
||||
This module defines its own minimal abstractions to avoid coupling.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import logging
|
||||
import aiohttp
|
||||
from modules.datamodels.datamodelTickets import TicketBase, TicketFieldAttribute
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectorTicketClickup(TicketBase):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
apiToken: str,
|
||||
teamId: str,
|
||||
listId: Optional[str] = None,
|
||||
apiUrl: str = "https://api.clickup.com/api/v2",
|
||||
) -> None:
|
||||
self.apiToken = apiToken
|
||||
self.teamId = teamId
|
||||
self.listId = listId
|
||||
self.apiUrl = apiUrl
|
||||
|
||||
def _headers(self) -> dict:
|
||||
return {
|
||||
"Authorization": self.apiToken,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def readAttributes(self) -> list[TicketFieldAttribute]:
|
||||
"""Fetch field attributes. Uses list custom fields if listId provided; else basic fields."""
|
||||
attributes: list[TicketFieldAttribute] = []
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if self.listId:
|
||||
url = f"{self.apiUrl}/list/{self.listId}/field"
|
||||
async with session.get(url, headers=self._headers()) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"ClickUp fields fetch status: {response.status}")
|
||||
else:
|
||||
data = await response.json()
|
||||
for field in data.get("fields", []):
|
||||
fieldId = field.get("id")
|
||||
fieldName = field.get("name", fieldId)
|
||||
if fieldId:
|
||||
attributes.append(TicketFieldAttribute(fieldName=fieldName, field=fieldId))
|
||||
|
||||
# Add common top-level fields
|
||||
core_fields = [
|
||||
("ID", "id"),
|
||||
("Name", "name"),
|
||||
("Status", "status.status"),
|
||||
("Assignees", "assignees"),
|
||||
("DateCreated", "date_created"),
|
||||
("DueDate", "due_date"),
|
||||
]
|
||||
for name, fid in core_fields:
|
||||
attributes.append(TicketFieldAttribute(fieldName=name, field=fid))
|
||||
except Exception as e:
|
||||
logger.error(f"ClickUp readAttributes error: {e}")
|
||||
return attributes
|
||||
|
||||
async def readTasks(self, *, limit: int = 0) -> list[dict]:
|
||||
"""Read tasks from ClickUp, always returning full task records.
|
||||
If listId is set, read from that list; otherwise read from the team.
|
||||
"""
|
||||
tasks: list[dict] = []
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
page = 0
|
||||
pageSize = 100
|
||||
while True:
|
||||
if self.listId:
|
||||
url = f"{self.apiUrl}/list/{self.listId}/task?subtasks=true&page={page}&order_by=created&reverse=true"
|
||||
else:
|
||||
# Team-level search for open tasks
|
||||
url = f"{self.apiUrl}/team/{self.teamId}/task?subtasks=true&page={page}&order_by=created&reverse=true"
|
||||
|
||||
# Request with parameters to include all fields where possible
|
||||
async with session.get(url, headers=self._headers()) as response:
|
||||
if response.status != 200:
|
||||
errorText = await response.text()
|
||||
logger.error(f"ClickUp readTasks failed: {response.status} {errorText}")
|
||||
break
|
||||
|
||||
data = await response.json()
|
||||
items = data.get("tasks", [])
|
||||
for item in items:
|
||||
tasks.append(item)
|
||||
if limit and len(tasks) >= limit:
|
||||
return tasks
|
||||
|
||||
if len(items) < pageSize:
|
||||
break
|
||||
page += 1
|
||||
except Exception as e:
|
||||
logger.error(f"ClickUp readTasks error: {e}")
|
||||
return tasks
|
||||
|
||||
async def writeTasks(self, tasklist: list[dict]) -> None:
|
||||
"""Update tasks in ClickUp. Expects each item to contain {'ID' or 'id' or 'task_id', 'fields': {...}}"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for data in tasklist:
|
||||
taskId = data.get("ID") or data.get("id") or data.get("task_id")
|
||||
fields = data.get("fields", {})
|
||||
if not taskId or not isinstance(fields, dict) or not fields:
|
||||
continue
|
||||
|
||||
# Map generic fields to ClickUp payload
|
||||
payload: dict = {}
|
||||
for fieldId, value in fields.items():
|
||||
# Heuristics: map common field ids
|
||||
if fieldId in ("name", "summary"):
|
||||
payload["name"] = value
|
||||
elif fieldId in ("status",):
|
||||
payload["status"] = value
|
||||
elif fieldId.startswith("customfield_") or fieldId.startswith("cf_"):
|
||||
# ClickUp custom fields need separate endpoint; attempt inline update if supported
|
||||
if "custom_fields" not in payload:
|
||||
payload["custom_fields"] = []
|
||||
payload["custom_fields"].append({"id": fieldId, "value": value})
|
||||
else:
|
||||
# Best-effort assign to description for unknown text fields
|
||||
if isinstance(value, str) and value:
|
||||
payload.setdefault("description", value)
|
||||
|
||||
url = f"{self.apiUrl}/task/{taskId}"
|
||||
async with session.put(url, headers=self._headers(), json=payload) as response:
|
||||
if response.status not in (200, 204):
|
||||
err = await response.text()
|
||||
logger.error(f"ClickUp update failed for {taskId}: {response.status} {err}")
|
||||
except Exception as e:
|
||||
logger.error(f"ClickUp write_tasks error: {e}")
|
||||
|
||||
|
||||
|
|
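For reference, a minimal usage sketch of the connector above; the token, team and list ids are placeholders, and in the real application these values would come from configuration rather than literals.

import asyncio

async def main() -> None:
    connector = ConnectorTicketClickup(
        apiToken="pk_xxxxxxxx",      # placeholder personal API token
        teamId="1234567",            # placeholder workspace/team id
        listId="901100000000",       # optional; omit to read team-wide
    )
    attributes = await connector.readAttributes()
    tasks = await connector.readTasks(limit=10)
    print(len(attributes), "attributes,", len(tasks), "tasks")

# asyncio.run(main())
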
@@ -1,49 +1,35 @@
"""Jira connector for CRUD operations."""
"""Jira connector for CRUD operations (neutralized to generic ticket interface).

This module defines its own minimal abstractions to avoid coupling.
"""

from dataclasses import dataclass
import os
import logging
import aiohttp
import asyncio
import json

from modules.interfaces.interfaceTicketModel import (
    TicketBase,
    TicketFieldAttribute,
    Task,
)

from modules.datamodels.datamodelTickets import TicketBase, TicketFieldAttribute

logger = logging.getLogger(__name__)


@dataclass
class ConnectorTicketJira(TicketBase):
    jira_username: str
    jira_api_token: str
    jira_url: str
    project_code: str
    issue_type: str

    @classmethod
    async def create(
        cls,
    def __init__(
        self,
        *,
        jira_username: str,
        jira_api_token: str,
        jira_url: str,
        project_code: str,
        issue_type: str,
    ):
        return ConnectorTicketJira(
            jira_username=jira_username,
            jira_api_token=jira_api_token,
            jira_url=jira_url,
            project_code=project_code,
            issue_type=issue_type,
        )
        apiUsername: str,
        apiToken: str,
        apiUrl: str,
        projectCode: str,
        ticketType: str,
    ) -> None:
        self.apiUsername = apiUsername
        self.apiToken = apiToken
        self.apiUrl = apiUrl
        self.projectCode = projectCode
        self.ticketType = ticketType

    async def read_attributes(self) -> list[TicketFieldAttribute]:

    async def readAttributes(self) -> list[TicketFieldAttribute]:
        """
        Read field attributes from Jira by querying for a single issue
        and extracting the field mappings.

@@ -52,22 +38,22 @@ class ConnectorTicketJira(TicketBase):
            list[TicketFieldAttribute]: List of field attributes with names and IDs
        """
        # Build JQL dynamically; allow empty or '*' issue_type to mean "all types"
        if self.issue_type and self.issue_type != "*":
            jql_query = f"project={self.project_code} AND issuetype={self.issue_type}"
        if self.ticketType and self.ticketType != "*":
            jql_query = f"project={self.projectCode} AND issuetype={self.ticketType}"
        else:
            jql_query = f"project={self.project_code}"
            jql_query = f"project={self.projectCode}"

        # Prepare the request URL (use JQL search endpoint)
        url = f"{self.jira_url}/rest/api/3/search/jql"
        url = f"{self.apiUrl}/rest/api/3/search/jql"

        # Prepare authentication
        auth = aiohttp.BasicAuth(self.jira_username, self.jira_api_token)
        auth = aiohttp.BasicAuth(self.apiUsername, self.apiToken)

        try:
            async with aiohttp.ClientSession() as session:
                headers = {"Content-Type": "application/json"}
                payload = {
                    "jql": jql_query,
                    "jql": jql_query,
                    "maxResults": 1
                    # Don't specify fields to get all available fields
                }

@@ -100,14 +86,10 @@ class ConnectorTicketJira(TicketBase):
                    fields = issue.get("fields", {})

                    for field_id, value in fields.items():
                        field_name = field_names.get(field_id, field_id)
                        fieldName = field_names.get(field_id, field_id)
                        attributes.append(
                            TicketFieldAttribute(field_name=field_name, field=field_id)
                            TicketFieldAttribute(fieldName=fieldName, field=field_id)
                        )

                    logger.info(
                        f"Successfully retrieved {len(attributes)} field attributes from Jira"
                    )
                    return attributes

        except aiohttp.ClientError as e:

@@ -122,8 +104,8 @@ class ConnectorTicketJira(TicketBase):

    async def _read_all_fields_via_fields_api(self) -> list[TicketFieldAttribute]:
        """Fallback: use Jira fields API to list all fields with id->name mapping."""
        auth = aiohttp.BasicAuth(self.jira_username, self.jira_api_token)
        url = f"{self.jira_url}/rest/api/3/field"
        auth = aiohttp.BasicAuth(self.apiUsername, self.apiToken)
        url = f"{self.apiUrl}/rest/api/3/field"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, auth=auth) as response:

@@ -138,20 +120,17 @@ class ConnectorTicketJira(TicketBase):
                    attributes: list[TicketFieldAttribute] = []
                    for field in data:
                        field_id = field.get("id")
                        field_name = field.get("name", field_id)
                        fieldName = field.get("name", field_id)
                        if field_id:
                            attributes.append(
                                TicketFieldAttribute(field_name=field_name, field=field_id)
                                TicketFieldAttribute(fieldName=fieldName, field=field_id)
                            )
                    logger.info(
                        f"Successfully retrieved {len(attributes)} field attributes via fields API"
                    )
                    return attributes
        except Exception as e:
            logger.error(f"Error while calling fields API: {str(e)}")
            return []

    async def read_tasks(self, *, limit: int = 0) -> list[Task]:
    async def readTasks(self, *, limit: int = 0) -> list[dict]:
        """
        Read tasks from Jira with pagination support.

@@ -159,25 +138,25 @@ class ConnectorTicketJira(TicketBase):
            limit: Maximum number of tasks to retrieve. 0 means no limit.

        Returns:
            list[Task]: List of tasks with their data
            list[dict]: List of tasks with their data
        """
        # Build JQL dynamically; allow empty or '*' issue_type to mean "all types"
        if self.issue_type and self.issue_type != "*":
            jql_query = f"project={self.project_code} AND issuetype={self.issue_type}"
        if self.ticketType and self.ticketType != "*":
            jql_query = f"project={self.projectCode} AND issuetype={self.ticketType}"
        else:
            jql_query = f"project={self.project_code}"
            jql_query = f"project={self.projectCode}"

        # Initialize variables for pagination (cursor-based /search/jql)
        max_results = 100
        next_page_token: str | None = None
        tasks = []
        tasks: list[dict] = []
        page_counter = 0
        max_pages_safety_cap = 1000
        seen_issue_ids: set[str] = set()

        # Prepare authentication
        auth = aiohttp.BasicAuth(self.jira_username, self.jira_api_token)
        url = f"{self.jira_url}/rest/api/3/search/jql"
        auth = aiohttp.BasicAuth(self.apiUsername, self.apiToken)
        url = f"{self.apiUrl}/rest/api/3/search/jql"

        try:
            async with aiohttp.ClientSession() as session:

@@ -195,9 +174,6 @@ class ConnectorTicketJira(TicketBase):

                headers = {"Content-Type": "application/json"}

                # Debug: log the payload being sent
                logger.debug(f"JIRA request payload: {json.dumps(payload, indent=2)}")

                async with session.post(
                    url, json=payload, auth=auth, headers=headers
                ) as response:

@@ -227,8 +203,7 @@ class ConnectorTicketJira(TicketBase):
                            continue
                        if issue_id:
                            seen_issue_ids.add(issue_id)
                        task = Task(data=issue)
                        tasks.append(task)
                        tasks.append(issue)
                        new_items_added += 1

                        # Check limit

@@ -278,20 +253,19 @@ class ConnectorTicketJira(TicketBase):
            logger.error(f"Unexpected error while fetching Jira tasks: {str(e)}")
            raise

    async def write_tasks(self, tasklist: list[Task]) -> None:
    async def writeTasks(self, tasklist: list[dict]) -> None:
        """
        Write/update tasks to Jira.

        Args:
            tasklist: List of Task objects containing task data to update
            tasklist: List of dicts containing task data to update
        """
        headers = {"Accept": "application/json", "Content-Type": "application/json"}
        auth = aiohttp.BasicAuth(self.jira_username, self.jira_api_token)
        auth = aiohttp.BasicAuth(self.apiUsername, self.apiToken)

        try:
            async with aiohttp.ClientSession() as session:
                for task in tasklist:
                    task_data = task.data
                for task_data in tasklist:
                    task_id = (
                        task_data.get("ID")
                        or task_data.get("id")

@@ -299,7 +273,7 @@ class ConnectorTicketJira(TicketBase):
                    )

                    if not task_id:
                        logger.warning("Task missing ID or key, skipping update")
                        logger.warning("Ticket update missing ID or key, skipping")
                        continue

                    # Extract fields to update from task data

@@ -333,7 +307,6 @@ class ConnectorTicketJira(TicketBase):
                            }
                        else:
                            # Skip empty ADF fields
                            logger.debug(f"Skipping empty ADF field {field_id} for task {task_id}")
                            continue
                    else:
                        processed_fields[field_id] = field_value

@@ -346,13 +319,13 @@ class ConnectorTicketJira(TicketBase):
                    update_data = {"fields": processed_fields}

                    # Make the update request
                    url = f"{self.jira_url}/rest/api/3/issue/{task_id}"
                    url = f"{self.apiUrl}/rest/api/3/issue/{task_id}"

                    async with session.put(
                        url, json=update_data, headers=headers, auth=auth
                    ) as response:
                        if response.status == 204:
                            logger.info(f"JIRA task {task_id} updated successfully.")
                            pass
                        else:
                            error_text = await response.text()
                            logger.error(

@@ -365,3 +338,5 @@ class ConnectorTicketJira(TicketBase):
        except Exception as e:
            logger.error(f"Unexpected error while updating Jira tasks: {str(e)}")
            raise

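The readTasks method above pages through results with a cursor. A reduced sketch of that loop, assuming the response shape of Atlassian's enhanced JQL search endpoint (an "issues" list plus an optional "nextPageToken"); the helper itself is illustrative and not part of the connector.

import aiohttp

async def fetchAllIssues(apiUrl: str, auth: aiohttp.BasicAuth, jql: str, maxResults: int = 100) -> list[dict]:
    issues: list[dict] = []
    nextPageToken: str | None = None
    async with aiohttp.ClientSession() as session:
        while True:
            payload: dict = {"jql": jql, "maxResults": maxResults}
            if nextPageToken:
                payload["nextPageToken"] = nextPageToken
            async with session.post(f"{apiUrl}/rest/api/3/search/jql", json=payload, auth=auth) as response:
                response.raise_for_status()
                data = await response.json()
            issues.extend(data.get("issues", []))
            nextPageToken = data.get("nextPageToken")
            if not nextPageToken:
                break
    return issues
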
@@ -3,12 +3,9 @@ Google Cloud Speech-to-Text and Translation Connector
Replaces Azure Speech Services with Google Cloud APIs
"""

import os
import io
import json
import html
import logging
import asyncio
from typing import Dict, Optional, Any
from google.cloud import speech
from google.cloud import translate_v2 as translate

@@ -29,18 +26,18 @@ class ConnectorGoogleSpeech:
        """
        try:
            # Get JSON key from config.ini
            api_key = APP_CONFIG.get("Connector_GoogleSpeech_API_KEY")
            apiKey = APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET")

            if not api_key or api_key == "YOUR_GOOGLE_SERVICE_ACCOUNT_JSON_KEY_HERE":
                raise ValueError("Google Speech API key not configured. Please set Connector_GoogleSpeech_API_KEY in config.ini with the full service account JSON key")
            if not apiKey or apiKey == "YOUR_GOOGLE_SERVICE_ACCOUNT_JSON_KEY_HERE":
                raise ValueError("Google Speech API key not configured. Please set Connector_GoogleSpeech_API_KEY_SECRET in config.ini with the full service account JSON key")

            # Parse the JSON key and set up authentication
            try:
                credentials_info = json.loads(api_key)
                credentialsInfo = json.loads(apiKey)

                # Create credentials object directly (no file needed!)
                from google.oauth2 import service_account
                credentials = service_account.Credentials.from_service_account_info(credentials_info)
                credentials = service_account.Credentials.from_service_account_info(credentialsInfo)

                logger.info("✅ Using Google Speech credentials from config.ini")

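A condensed sketch of the credential bootstrap shown above: the full service-account JSON key is stored as a config string, parsed in memory, and handed to the Google clients without ever touching disk. The client construction calls are standard google-cloud APIs; the helper itself is illustrative, not part of the connector.

import json
from google.oauth2 import service_account
from google.cloud import speech
from google.cloud import translate_v2 as translate

def buildClients(serviceAccountJson: str):
    # Parse the inline JSON key and build credentials directly (no key file needed)
    credentialsInfo = json.loads(serviceAccountJson)
    credentials = service_account.Credentials.from_service_account_info(credentialsInfo)
    speechClient = speech.SpeechClient(credentials=credentials)
    translateClient = translate.Client(credentials=credentials)
    return speechClient, translateClient
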
@@ -58,13 +55,13 @@ class ConnectorGoogleSpeech:
            logger.error(f"❌ Failed to initialize Google Cloud clients: {e}")
            raise

    async def speech_to_text(self, audio_content: bytes, language: str = "de-DE",
                             sample_rate: int = None, channels: int = None) -> Dict:
    async def speechToText(self, audioContent: bytes, language: str = "de-DE",
                           sampleRate: int = None, channels: int = None) -> Dict:
        """
        Convert speech to text using Google Cloud Speech-to-Text API.

        Args:
            audio_content: Raw audio data (various formats supported)
            audioContent: Raw audio data (various formats supported)
            language: Language code (e.g., 'de-DE', 'en-US')
            sample_rate: Audio sample rate (auto-detected if None)
            channels: Number of audio channels (auto-detected if None)

@@ -74,8 +71,8 @@ class ConnectorGoogleSpeech:
        """
        try:
            # Auto-detect audio format if not provided
            if sample_rate is None or channels is None:
                validation = self.validate_audio_format(audio_content)
            if sampleRate is None or channels is None:
                validation = self.validateAudioFormat(audioContent)
                if not validation["valid"]:
                    return {
                        "success": False,
@ -83,59 +80,59 @@ class ConnectorGoogleSpeech:
|
|||
"confidence": 0.0,
|
||||
"error": f"Invalid audio format: {validation.get('error', 'Unknown error')}"
|
||||
}
|
||||
sample_rate = validation["sample_rate"]
|
||||
sampleRate = validation["sample_rate"]
|
||||
channels = validation["channels"]
|
||||
audio_format = validation["format"]
|
||||
logger.info(f"Auto-detected audio: {audio_format}, {sample_rate}Hz, {channels}ch")
|
||||
audioFormat = validation["format"]
|
||||
logger.info(f"Auto-detected audio: {audioFormat}, {sampleRate}Hz, {channels}ch")
|
||||
|
||||
logger.info(f"Processing audio with Google Cloud Speech-to-Text")
|
||||
logger.info(f"Audio: {len(audio_content)} bytes, {sample_rate}Hz, {channels}ch")
|
||||
logger.info(f"Audio: {len(audioContent)} bytes, {sampleRate}Hz, {channels}ch")
|
||||
|
||||
# Configure audio settings
|
||||
audio = speech.RecognitionAudio(content=audio_content)
|
||||
audio = speech.RecognitionAudio(content=audioContent)
|
||||
|
||||
# Determine encoding based on detected format
|
||||
# Google Cloud Speech API has specific requirements for different formats
|
||||
if audio_format == "webm_opus":
|
||||
if audioFormat == "webm_opus":
|
||||
# For WEBM OPUS, we need to ensure proper format
|
||||
encoding = speech.RecognitionConfig.AudioEncoding.WEBM_OPUS
|
||||
# WEBM_OPUS requires specific sample rate handling - must match header
|
||||
if sample_rate != 48000:
|
||||
logger.warning(f"WEBM_OPUS detected but sample rate is {sample_rate}, adjusting to 48000")
|
||||
sample_rate = 48000
|
||||
if sampleRate != 48000:
|
||||
logger.warning(f"WEBM_OPUS detected but sample rate is {sampleRate}, adjusting to 48000")
|
||||
sampleRate = 48000
|
||||
# For WEBM_OPUS, don't specify sample_rate_hertz in config
|
||||
# Google Cloud will read it from the WEBM header
|
||||
use_sample_rate = False
|
||||
elif audio_format == "linear16":
|
||||
useSampleRate = False
|
||||
elif audioFormat == "linear16":
|
||||
# For LINEAR16 format (PCM)
|
||||
encoding = speech.RecognitionConfig.AudioEncoding.LINEAR16
|
||||
# Ensure sample rate is reasonable
|
||||
if sample_rate not in [8000, 16000, 22050, 24000, 32000, 44100, 48000]:
|
||||
logger.warning(f"Unusual sample rate {sample_rate}, adjusting to 16000")
|
||||
sample_rate = 16000
|
||||
use_sample_rate = True
|
||||
elif audio_format == "mp3":
|
||||
if sampleRate not in [8000, 16000, 22050, 24000, 32000, 44100, 48000]:
|
||||
logger.warning(f"Unusual sample rate {sampleRate}, adjusting to 16000")
|
||||
sampleRate = 16000
|
||||
useSampleRate = True
|
||||
elif audioFormat == "mp3":
|
||||
# For MP3 format
|
||||
encoding = speech.RecognitionConfig.AudioEncoding.MP3
|
||||
use_sample_rate = True
|
||||
elif audio_format == "flac":
|
||||
useSampleRate = True
|
||||
elif audioFormat == "flac":
|
||||
# For FLAC format
|
||||
encoding = speech.RecognitionConfig.AudioEncoding.FLAC
|
||||
use_sample_rate = True
|
||||
elif audio_format == "wav":
|
||||
useSampleRate = True
|
||||
elif audioFormat == "wav":
|
||||
# For WAV format
|
||||
encoding = speech.RecognitionConfig.AudioEncoding.LINEAR16
|
||||
use_sample_rate = True
|
||||
useSampleRate = True
|
||||
else:
|
||||
# For unknown formats, try LINEAR16 as fallback
|
||||
encoding = speech.RecognitionConfig.AudioEncoding.LINEAR16
|
||||
sample_rate = 16000 # Use standard sample rate
|
||||
sampleRate = 16000 # Use standard sample rate
|
||||
channels = 1 # Use mono
|
||||
use_sample_rate = True
|
||||
logger.warning(f"Unknown audio format '{audio_format}', using LINEAR16 encoding with 16000Hz")
|
||||
useSampleRate = True
|
||||
logger.warning(f"Unknown audio format '{audioFormat}', using LINEAR16 encoding with 16000Hz")
|
||||
|
||||
# Build config based on format requirements
|
||||
config_params = {
|
||||
configParams = {
|
||||
"encoding": encoding,
|
||||
"audio_channel_count": channels,
|
||||
"language_code": language,
|
||||
|
|
@ -148,13 +145,13 @@ class ConnectorGoogleSpeech:
|
|||
}
|
||||
|
||||
# Only add sample_rate_hertz if needed (not for WEBM_OPUS)
|
||||
if use_sample_rate:
|
||||
config_params["sample_rate_hertz"] = sample_rate
|
||||
logger.debug(f"Recognition config: encoding={encoding}, sample_rate={sample_rate}, channels={channels}, language={language}")
|
||||
if useSampleRate:
|
||||
configParams["sample_rate_hertz"] = sampleRate
|
||||
logger.debug(f"Recognition config: encoding={encoding}, sample_rate={sampleRate}, channels={channels}, language={language}")
|
||||
else:
|
||||
logger.debug(f"Recognition config: encoding={encoding}, sample_rate=auto (from header), channels={channels}, language={language}")
|
||||
|
||||
config = speech.RecognitionConfig(**config_params)
|
||||
config = speech.RecognitionConfig(**configParams)
|
||||
|
||||
# Perform speech recognition
|
||||
logger.info("Sending audio to Google Cloud Speech-to-Text...")
|
||||
|
|
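The encoding selection above reduces to two main cases; a hedged sketch covering only WEBM_OPUS and LINEAR16 (the other formats handled by the connector are omitted here for brevity).

from google.cloud import speech

def buildRecognitionConfig(audioFormat: str, sampleRate: int, channels: int, language: str) -> speech.RecognitionConfig:
    params = {
        "audio_channel_count": channels,
        "language_code": language,
    }
    if audioFormat == "webm_opus":
        # WEBM_OPUS: omit sample_rate_hertz so Google reads it from the container header
        params["encoding"] = speech.RecognitionConfig.AudioEncoding.WEBM_OPUS
    else:
        # PCM/WAV fallback: specify the sample rate explicitly
        params["encoding"] = speech.RecognitionConfig.AudioEncoding.LINEAR16
        params["sample_rate_hertz"] = sampleRate
    return speech.RecognitionConfig(**params)
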
@ -165,12 +162,12 @@ class ConnectorGoogleSpeech:
|
|||
response = self.speech_client.recognize(config=config, audio=audio)
|
||||
logger.debug(f"Google Cloud response: {response}")
|
||||
|
||||
except Exception as api_error:
|
||||
logger.error(f"Google Cloud API error: {api_error}")
|
||||
except Exception as apiError:
|
||||
logger.error(f"Google Cloud API error: {apiError}")
|
||||
# Try with different encoding as fallback
|
||||
if encoding != speech.RecognitionConfig.AudioEncoding.LINEAR16:
|
||||
logger.info("Trying fallback with LINEAR16 encoding...")
|
||||
fallback_config = speech.RecognitionConfig(
|
||||
fallbackConfig = speech.RecognitionConfig(
|
||||
encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=16000, # Use standard sample rate
|
||||
audio_channel_count=1,
|
||||
|
|
@ -180,13 +177,13 @@ class ConnectorGoogleSpeech:
|
|||
)
|
||||
|
||||
try:
|
||||
response = self.speech_client.recognize(config=fallback_config, audio=audio)
|
||||
response = self.speech_client.recognize(config=fallbackConfig, audio=audio)
|
||||
logger.debug(f"Google Cloud fallback response: {response}")
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Google Cloud fallback error: {fallback_error}")
|
||||
raise api_error
|
||||
except Exception as fallbackError:
|
||||
logger.error(f"Google Cloud fallback error: {fallbackError}")
|
||||
raise apiError
|
||||
else:
|
||||
raise api_error
|
||||
raise apiError
|
||||
|
||||
# Process results
|
||||
if response.results:
|
||||
|
|
@ -237,18 +234,18 @@ class ConnectorGoogleSpeech:
|
|||
|
||||
if encoding != speech.RecognitionConfig.AudioEncoding.LINEAR16:
|
||||
# For WEBM_OPUS, don't try LINEAR16 with detected sample rate as it causes conflicts
|
||||
if audio_format != "webm_opus":
|
||||
if audioFormat != "webm_opus":
|
||||
# Try LINEAR16 with detected sample rate for non-WEBM formats
|
||||
fallback_configs.append({
|
||||
"encoding": speech.RecognitionConfig.AudioEncoding.LINEAR16,
|
||||
"sample_rate": sample_rate,
|
||||
"sample_rate": sampleRate,
|
||||
"channels": channels,
|
||||
"use_sample_rate": True,
|
||||
"description": f"LINEAR16 with {sample_rate}Hz"
|
||||
"description": f"LINEAR16 with {sampleRate}Hz"
|
||||
})
|
||||
|
||||
# For WEBM_OPUS, only try compatible sample rates or skip sample rate specification
|
||||
if audio_format == "webm_opus":
|
||||
if audioFormat == "webm_opus":
|
||||
# Try WEBM_OPUS without sample rate specification (let Google read from header)
|
||||
fallback_configs.append({
|
||||
"encoding": speech.RecognitionConfig.AudioEncoding.WEBM_OPUS,
|
||||
|
|
@ -276,7 +273,7 @@ class ConnectorGoogleSpeech:
|
|||
else:
|
||||
# For other formats, try standard sample rates
|
||||
for std_rate in [16000, 8000, 22050, 44100]:
|
||||
if std_rate != sample_rate:
|
||||
if std_rate != sampleRate:
|
||||
fallback_configs.append({
|
||||
"encoding": speech.RecognitionConfig.AudioEncoding.LINEAR16,
|
||||
"sample_rate": std_rate,
|
||||
|
|
@ -351,8 +348,8 @@ class ConnectorGoogleSpeech:
|
|||
"error": str(e)
|
||||
}
|
||||
|
||||
async def translate_text(self, text: str, target_language: str = "en",
|
||||
source_language: str = "de") -> Dict:
|
||||
async def translateText(self, text: str, targetLanguage: str = "en",
|
||||
sourceLanguage: str = "de") -> Dict:
|
||||
"""
|
||||
Translate text using Google Cloud Translation API.
|
||||
|
||||
|
|
@ -373,28 +370,28 @@ class ConnectorGoogleSpeech:
|
|||
"error": "Empty text provided"
|
||||
}
|
||||
|
||||
logger.info(f"🌐 Translating: '{text}' ({source_language} -> {target_language})")
|
||||
logger.info(f"🌐 Translating: '{text}' ({sourceLanguage} -> {targetLanguage})")
|
||||
|
||||
# Perform translation
|
||||
result = self.translate_client.translate(
|
||||
text,
|
||||
source_language=source_language,
|
||||
target_language=target_language
|
||||
source_language=sourceLanguage,
|
||||
target_language=targetLanguage
|
||||
)
|
||||
|
||||
translated_text = result['translatedText']
|
||||
detected_language = result.get('detectedSourceLanguage', source_language)
|
||||
translatedText = result['translatedText']
|
||||
detectedLanguage = result.get('detectedSourceLanguage', sourceLanguage)
|
||||
|
||||
# Decode HTML entities in translated text
|
||||
translated_text = html.unescape(translated_text)
|
||||
translatedText = html.unescape(translatedText)
|
||||
|
||||
logger.info(f"✅ Translation successful: '{translated_text}'")
|
||||
logger.info(f"✅ Translation successful: '{translatedText}'")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"translated_text": translated_text,
|
||||
"source_language": detected_language,
|
||||
"target_language": target_language,
|
||||
"translated_text": translatedText,
|
||||
"source_language": detectedLanguage,
|
||||
"target_language": targetLanguage,
|
||||
"original_text": text
|
||||
}
|
||||
|
||||
|
|
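A standalone sketch of the translation call used above, assuming an already-authenticated translate_v2 client; the connector itself reuses self.translate_client, and the HTML unescape mirrors what it does with the API response.

from google.cloud import translate_v2 as translate
import html

def translateOnce(client: "translate.Client", text: str, sourceLanguage: str = "de", targetLanguage: str = "en") -> str:
    result = client.translate(text, source_language=sourceLanguage, target_language=targetLanguage)
    # The API returns HTML-escaped text, so unescape it before returning
    return html.unescape(result["translatedText"])
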
@ -406,14 +403,14 @@ class ConnectorGoogleSpeech:
|
|||
"error": str(e)
|
||||
}
|
||||
|
||||
async def speech_to_translated_text(self, audio_content: bytes,
|
||||
from_language: str = "de-DE",
|
||||
to_language: str = "en") -> Dict:
|
||||
async def speechToTranslatedText(self, audioContent: bytes,
|
||||
fromLanguage: str = "de-DE",
|
||||
toLanguage: str = "en") -> Dict:
|
||||
"""
|
||||
Complete pipeline: Speech-to-Text + Translation.
|
||||
|
||||
Args:
|
||||
audio_content: Raw audio data
|
||||
audioContent: Raw audio data
|
||||
from_language: Source language for speech recognition
|
||||
to_language: Target language for translation
|
||||
|
||||
|
|
@ -421,52 +418,52 @@ class ConnectorGoogleSpeech:
|
|||
Dict containing original text, translated text, and metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🔄 Starting speech-to-translation pipeline: {from_language} -> {to_language}")
|
||||
logger.info(f"🔄 Starting speech-to-translation pipeline: {fromLanguage} -> {toLanguage}")
|
||||
|
||||
# Step 1: Speech-to-Text
|
||||
speech_result = await self.speech_to_text(
|
||||
audio_content=audio_content,
|
||||
language=from_language
|
||||
speechResult = await self.speechToText(
|
||||
audioContent=audioContent,
|
||||
language=fromLanguage
|
||||
)
|
||||
|
||||
if not speech_result["success"]:
|
||||
if not speechResult["success"]:
|
||||
return {
|
||||
"success": False,
|
||||
"original_text": "",
|
||||
"translated_text": "",
|
||||
"error": f"Speech recognition failed: {speech_result.get('error', 'Unknown error')}"
|
||||
"error": f"Speech recognition failed: {speechResult.get('error', 'Unknown error')}"
|
||||
}
|
||||
|
||||
original_text = speech_result["text"]
|
||||
originalText = speechResult["text"]
|
||||
|
||||
# Step 2: Translation
|
||||
translation_result = await self.translate_text(
|
||||
text=original_text,
|
||||
source_language=from_language.split('-')[0], # Convert 'de-DE' to 'de'
|
||||
target_language=to_language.split('-')[0] # Convert 'en-US' to 'en'
|
||||
translationResult = await self.translateText(
|
||||
text=originalText,
|
||||
sourceLanguage=fromLanguage.split('-')[0], # Convert 'de-DE' to 'de'
|
||||
targetLanguage=toLanguage.split('-')[0] # Convert 'en-US' to 'en'
|
||||
)
|
||||
|
||||
if not translation_result["success"]:
|
||||
if not translationResult["success"]:
|
||||
return {
|
||||
"success": False,
|
||||
"original_text": original_text,
|
||||
"original_text": originalText,
|
||||
"translated_text": "",
|
||||
"error": f"Translation failed: {translation_result.get('error', 'Unknown error')}"
|
||||
"error": f"Translation failed: {translationResult.get('error', 'Unknown error')}"
|
||||
}
|
||||
|
||||
translated_text = translation_result["translated_text"]
|
||||
translatedText = translationResult["translated_text"]
|
||||
|
||||
logger.info(f"✅ Complete pipeline successful:")
|
||||
logger.info(f" Original: '{original_text}'")
|
||||
logger.info(f" Translated: '{translated_text}'")
|
||||
logger.info(f" Original: '{originalText}'")
|
||||
logger.info(f" Translated: '{translatedText}'")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"original_text": original_text,
|
||||
"translated_text": translated_text,
|
||||
"confidence": speech_result["confidence"],
|
||||
"source_language": from_language,
|
||||
"target_language": to_language
|
||||
"original_text": originalText,
|
||||
"translated_text": translatedText,
|
||||
"confidence": speechResult["confidence"],
|
||||
"source_language": fromLanguage,
|
||||
"target_language": toLanguage
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -478,19 +475,19 @@ class ConnectorGoogleSpeech:
|
|||
"error": str(e)
|
||||
}
|
||||
|
||||
def validate_audio_format(self, audio_content: bytes) -> Dict:
|
||||
def validateAudioFormat(self, audioContent: bytes) -> Dict:
|
||||
"""
|
||||
Validate audio format for Google Cloud Speech-to-Text.
|
||||
|
||||
Args:
|
||||
audio_content: Raw audio data
|
||||
audioContent: Raw audio data
|
||||
|
||||
Returns:
|
||||
Dict containing validation results
|
||||
"""
|
||||
try:
|
||||
# Basic validation
|
||||
if len(audio_content) < 100:
|
||||
if len(audioContent) < 100:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": "Audio too short (less than 100 bytes)"
|
||||
|
|
@ -502,11 +499,11 @@ class ConnectorGoogleSpeech:
|
|||
channels = 1 # Default fallback
|
||||
|
||||
# Debug: Log first few bytes for format detection
|
||||
logger.debug(f"Audio header bytes: {audio_content[:20].hex()}")
|
||||
logger.debug(f"Audio content length: {len(audio_content)} bytes")
|
||||
logger.debug(f"Audio header bytes: {audioContent[:20].hex()}")
|
||||
logger.debug(f"Audio content length: {len(audioContent)} bytes")
|
||||
|
||||
# Check for WEBM/OPUS format (common from web recordings)
|
||||
if audio_content.startswith(b'\x1a\x45\xdf\xa3'):
|
||||
if audioContent.startswith(b'\x1a\x45\xdf\xa3'):
|
||||
audio_format = "webm_opus"
|
||||
sample_rate = 48000 # WEBM OPUS typically uses 48kHz
|
||||
channels = 1
|
||||
|
|
@ -514,10 +511,10 @@ class ConnectorGoogleSpeech:
|
|||
|
||||
# Check for specific header patterns seen in logs (43c381...)
|
||||
# This appears to be a different audio format or corrupted WEBM
|
||||
elif audio_content.startswith(b'\x43\xc3\x81') and len(audio_content) > 1000:
|
||||
elif audioContent.startswith(b'\x43\xc3\x81') and len(audioContent) > 1000:
|
||||
# This might be a different format or corrupted audio
|
||||
# Try to detect if it's actually WEBM by looking deeper
|
||||
if b'webm' in audio_content[:200] or b'opus' in audio_content[:200]:
|
||||
if b'webm' in audioContent[:200] or b'opus' in audioContent[:200]:
|
||||
audio_format = "webm_opus"
|
||||
sample_rate = 48000
|
||||
channels = 1
|
||||
|
|
@ -527,58 +524,58 @@ class ConnectorGoogleSpeech:
|
|||
audio_format = "linear16"
|
||||
sample_rate = 16000
|
||||
channels = 1
|
||||
logger.warning(f"Unknown audio format with header {audio_content[:8].hex()}, trying LINEAR16")
|
||||
logger.warning(f"Unknown audio format with header {audioContent[:8].hex()}, trying LINEAR16")
|
||||
|
||||
# Check for WEBM format (alternative detection)
|
||||
elif b'webm' in audio_content[:100].lower() or b'opus' in audio_content[:100].lower():
|
||||
elif b'webm' in audioContent[:100].lower() or b'opus' in audioContent[:100].lower():
|
||||
audio_format = "webm_opus"
|
||||
sample_rate = 48000 # WEBM OPUS typically uses 48kHz
|
||||
channels = 1
|
||||
logger.info(f"Detected WEBM format: {sample_rate}Hz, {channels}ch")
|
||||
|
||||
# Check for MediaRecorder WEBM chunks (common in browser recordings)
|
||||
elif audio_content.startswith(b'\x1a\x45\xdf\xa3') and len(audio_content) > 1000:
|
||||
elif audioContent.startswith(b'\x1a\x45\xdf\xa3') and len(audioContent) > 1000:
|
||||
audio_format = "webm_opus"
|
||||
sample_rate = 48000 # Browser MediaRecorder typically uses 48kHz
|
||||
channels = 1
|
||||
logger.info(f"Detected MediaRecorder WEBM: {sample_rate}Hz, {channels}ch")
|
||||
|
||||
# Check for OPUS format by looking for OPUS magic bytes
|
||||
elif audio_content.startswith(b'OpusHead') or b'OpusHead' in audio_content[:50]:
|
||||
elif audioContent.startswith(b'OpusHead') or b'OpusHead' in audioContent[:50]:
|
||||
audio_format = "webm_opus"
|
||||
sample_rate = 48000 # OPUS typically uses 48kHz
|
||||
channels = 1
|
||||
logger.info(f"Detected OPUS format: {sample_rate}Hz, {channels}ch")
|
||||
|
||||
# Check for OGG format (often contains OPUS)
|
||||
elif audio_content.startswith(b'OggS'):
|
||||
elif audioContent.startswith(b'OggS'):
|
||||
audio_format = "webm_opus"
|
||||
sample_rate = 48000 # OGG OPUS typically uses 48kHz
|
||||
channels = 1
|
||||
logger.info(f"Detected OGG format: {sample_rate}Hz, {channels}ch")
|
||||
|
||||
# Check for WAV format
|
||||
elif audio_content.startswith(b'RIFF') and b'WAVE' in audio_content[:12]:
|
||||
elif audioContent.startswith(b'RIFF') and b'WAVE' in audioContent[:12]:
|
||||
audio_format = "wav"
|
||||
# Try to extract sample rate from WAV header
|
||||
try:
|
||||
# WAV header sample rate is at offset 24-27 (little endian)
|
||||
sample_rate = int.from_bytes(audio_content[24:28], 'little')
|
||||
channels = int.from_bytes(audio_content[22:24], 'little')
|
||||
sample_rate = int.from_bytes(audioContent[24:28], 'little')
|
||||
channels = int.from_bytes(audioContent[22:24], 'little')
|
||||
logger.info(f"Detected WAV format: {sample_rate}Hz, {channels}ch")
|
||||
except:
|
||||
sample_rate = 16000 # Fallback
|
||||
channels = 1
|
||||
|
||||
# Check for MP3 format
|
||||
elif audio_content.startswith(b'\xff\xfb') or audio_content.startswith(b'ID3'):
|
||||
elif audioContent.startswith(b'\xff\xfb') or audioContent.startswith(b'ID3'):
|
||||
audio_format = "mp3"
|
||||
sample_rate = 44100 # MP3 typically uses 44.1kHz
|
||||
channels = 2 # Usually stereo
|
||||
logger.info(f"Detected MP3 format: {sample_rate}Hz, {channels}ch")
|
||||
|
||||
# Check for FLAC format
|
||||
elif audio_content.startswith(b'fLaC'):
|
||||
elif audioContent.startswith(b'fLaC'):
|
||||
audio_format = "flac"
|
||||
sample_rate = 44100 # Common FLAC sample rate
|
||||
channels = 2
|
||||
|
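The detection chain above boils down to sniffing container magic bytes. A compact, non-exhaustive sketch of that logic (return values follow the format labels used by the connector):

def sniffAudioFormat(data: bytes) -> str:
    if data.startswith(b"\x1a\x45\xdf\xa3"):                 # EBML header -> WEBM (usually browser Opus)
        return "webm_opus"
    if data.startswith(b"OggS") or b"OpusHead" in data[:50]:  # OGG/OPUS, treated like Opus above
        return "webm_opus"
    if data.startswith(b"RIFF") and b"WAVE" in data[:12]:     # WAV/RIFF container
        return "wav"
    if data.startswith(b"\xff\xfb") or data.startswith(b"ID3"):  # MP3 frame sync or ID3 tag
        return "mp3"
    if data.startswith(b"fLaC"):                              # FLAC stream marker
        return "flac"
    return "linear16"                                         # unknown: fall back to raw PCM
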
|
@ -597,31 +594,31 @@ class ConnectorGoogleSpeech:
|
|||
estimated_duration = 3.0 # Assume 3 seconds for web recordings
|
||||
else:
|
||||
# Rough estimate for uncompressed audio
|
||||
estimated_duration = len(audio_content) / (sample_rate * channels * 2) # 16-bit = 2 bytes per sample
|
||||
estimated_duration = len(audioContent) / (sample_rate * channels * 2) # 16-bit = 2 bytes per sample
|
||||
|
||||
# Check if audio is too short (less than 0.5 seconds)
|
||||
if estimated_duration < 0.5:
|
||||
logger.warning(f"Audio too short: {estimated_duration:.2f}s, may not be recognized")
|
||||
|
||||
# Log audio details for debugging
|
||||
logger.info(f"Audio analysis: {len(audio_content)} bytes, {estimated_duration:.2f}s, {sample_rate}Hz, {channels}ch, format={audio_format}")
|
||||
logger.info(f"Audio analysis: {len(audioContent)} bytes, {estimated_duration:.2f}s, {sample_rate}Hz, {channels}ch, format={audio_format}")
|
||||
|
||||
# Check audio levels (simple check for silence)
|
||||
if audio_format == "webm_opus":
|
||||
# For WEBM, we can't easily check levels, but log the first few bytes
|
||||
logger.debug(f"Audio sample bytes: {audio_content[:20].hex()}")
|
||||
logger.debug(f"Audio sample bytes: {audioContent[:20].hex()}")
|
||||
# Check if audio has some variation (not all same bytes)
|
||||
if len(audio_content) > 100:
|
||||
sample_bytes = audio_content[100:200] # Skip header
|
||||
if len(audioContent) > 100:
|
||||
sample_bytes = audioContent[100:200] # Skip header
|
||||
if len(set(sample_bytes)) < 5: # Less than 5 different byte values
|
||||
logger.warning("Audio may be silent or very quiet (low byte variation)")
|
||||
else:
|
||||
logger.debug(f"Audio has good byte variation: {len(set(sample_bytes))} unique values")
|
||||
else:
|
||||
# For PCM audio, check for silence
|
||||
if len(audio_content) > 100:
|
||||
if len(audioContent) > 100:
|
||||
# Convert first 100 bytes to check for silence
|
||||
sample_bytes = audio_content[:100]
|
||||
sample_bytes = audioContent[:100]
|
||||
if all(b == 0 for b in sample_bytes):
|
||||
logger.warning("Audio appears to be silent (all zeros)")
|
||||
else:
|
||||
|
|
@ -635,7 +632,7 @@ class ConnectorGoogleSpeech:
|
|||
"format": audio_format,
|
||||
"sample_rate": sample_rate,
|
||||
"channels": channels,
|
||||
"size": len(audio_content),
|
||||
"size": len(audioContent),
|
||||
"estimated_duration": estimated_duration
|
||||
}
|
||||
|
||||
|
|
@ -645,7 +642,7 @@ class ConnectorGoogleSpeech:
|
|||
"error": f"Validation error: {e}"
|
||||
}
|
||||
|
||||
async def text_to_speech(self, text: str, language_code: str = "de-DE", voice_name: str = None) -> Dict[str, Any]:
|
||||
async def textToSpeech(self, text: str, languageCode: str = "de-DE", voiceName: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert text to speech using Google Cloud Text-to-Speech.
|
||||
|
||||
|
|
@ -658,31 +655,38 @@ class ConnectorGoogleSpeech:
|
|||
Dict with success status and audio data
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Converting text to speech: '{text[:50]}...' in {language_code}")
|
||||
logger.info(f"Converting text to speech: '{text[:50]}...' in {languageCode}")
|
||||
|
||||
# Set up the synthesis input
|
||||
synthesis_input = texttospeech.SynthesisInput(text=text)
|
||||
synthesisInput = texttospeech.SynthesisInput(text=text)
|
||||
|
||||
# Build the voice request
|
||||
selected_voice = voice_name or self._get_default_voice(language_code)
|
||||
logger.info(f"Using TTS voice: {selected_voice} for language: {language_code}")
|
||||
selectedVoice = voiceName or self._getDefaultVoice(languageCode)
|
||||
|
||||
if not selectedVoice:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"No voice specified for language {languageCode}. Please select a voice."
|
||||
}
|
||||
|
||||
logger.info(f"Using TTS voice: {selectedVoice} for language: {languageCode}")
|
||||
|
||||
voice = texttospeech.VoiceSelectionParams(
|
||||
language_code=language_code,
|
||||
name=selected_voice,
|
||||
language_code=languageCode,
|
||||
name=selectedVoice,
|
||||
ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL
|
||||
)
|
||||
|
||||
# Select the type of audio file to return
|
||||
audio_config = texttospeech.AudioConfig(
|
||||
audioConfig = texttospeech.AudioConfig(
|
||||
audio_encoding=texttospeech.AudioEncoding.MP3
|
||||
)
|
||||
|
||||
# Perform the text-to-speech request
|
||||
response = self.tts_client.synthesize_speech(
|
||||
input=synthesis_input,
|
||||
input=synthesisInput,
|
||||
voice=voice,
|
||||
audio_config=audio_config
|
||||
audio_config=audioConfig
|
||||
)
|
||||
|
||||
# Return the audio content
|
||||
|
|
@ -690,7 +694,7 @@ class ConnectorGoogleSpeech:
|
|||
"success": True,
|
||||
"audio_content": response.audio_content,
|
||||
"audio_format": "mp3",
|
||||
"language_code": language_code,
|
||||
"language_code": languageCode,
|
||||
"voice_name": voice.name
|
||||
}
|
||||
|
||||
|
|
@ -701,124 +705,121 @@ class ConnectorGoogleSpeech:
|
|||
"error": f"Text-to-Speech failed: {str(e)}"
|
||||
}
|
||||
|
||||
def _get_default_voice(self, language_code: str) -> str:
|
||||
def _getDefaultVoice(self, languageCode: str) -> str:
|
||||
"""
|
||||
Get default voice name for a language code.
|
||||
Uses female voices as default for better user experience.
|
||||
Returns None - no defaults, let the frontend handle voice selection.
|
||||
"""
|
||||
voice_mapping = {
|
||||
# European Languages
|
||||
'de-DE': 'de-DE-Wavenet-B', # German, female
|
||||
'en-US': 'en-US-Wavenet-B', # English US, female
|
||||
'en-GB': 'en-GB-Wavenet-B', # English UK, female
|
||||
'en-AU': 'en-AU-Wavenet-B', # English Australia, female
|
||||
'en-CA': 'en-CA-Wavenet-B', # English Canada, female
|
||||
'en-IN': 'en-IN-Wavenet-B', # English India, female
|
||||
'fr-FR': 'fr-FR-Wavenet-B', # French, female
|
||||
'fr-CA': 'fr-CA-Wavenet-B', # French Canada, female
|
||||
'es-ES': 'es-ES-Wavenet-B', # Spanish Spain, female
|
||||
'es-MX': 'es-MX-Wavenet-B', # Spanish Mexico, female
|
||||
'es-AR': 'es-AR-Wavenet-B', # Spanish Argentina, female
|
||||
'es-CO': 'es-CO-Wavenet-B', # Spanish Colombia, female
|
||||
'es-PE': 'es-PE-Wavenet-B', # Spanish Peru, female
|
||||
'es-VE': 'es-VE-Wavenet-B', # Spanish Venezuela, female
|
||||
'es-CL': 'es-CL-Wavenet-B', # Spanish Chile, female
|
||||
'es-UY': 'es-UY-Wavenet-B', # Spanish Uruguay, female
|
||||
'es-BO': 'es-BO-Wavenet-B', # Spanish Bolivia, female
|
||||
'es-CR': 'es-CR-Wavenet-B', # Spanish Costa Rica, female
|
||||
'es-EC': 'es-EC-Wavenet-B', # Spanish Ecuador, female
|
||||
'es-GT': 'es-GT-Wavenet-B', # Spanish Guatemala, female
|
||||
'es-HN': 'es-HN-Wavenet-B', # Spanish Honduras, female
|
||||
'es-NI': 'es-NI-Wavenet-B', # Spanish Nicaragua, female
|
||||
'es-PA': 'es-PA-Wavenet-B', # Spanish Panama, female
|
||||
'es-PY': 'es-PY-Wavenet-B', # Spanish Paraguay, female
|
||||
'es-PR': 'es-PR-Wavenet-B', # Spanish Puerto Rico, female
|
||||
'es-DO': 'es-DO-Wavenet-B', # Spanish Dominican Republic, female
|
||||
'es-SV': 'es-SV-Wavenet-B', # Spanish El Salvador, female
|
||||
'it-IT': 'it-IT-Wavenet-B', # Italian, female
|
||||
'pt-PT': 'pt-PT-Wavenet-B', # Portuguese Portugal, female
|
||||
'pt-BR': 'pt-BR-Wavenet-B', # Portuguese Brazil, female
|
||||
'nl-NL': 'nl-NL-Wavenet-B', # Dutch, female
|
||||
'pl-PL': 'pl-PL-Wavenet-B', # Polish, female
|
||||
'ru-RU': 'ru-RU-Wavenet-B', # Russian, female
|
||||
'uk-UA': 'uk-UA-Wavenet-B', # Ukrainian, female
|
||||
'cs-CZ': 'cs-CZ-Wavenet-B', # Czech, female
|
||||
'sk-SK': 'sk-SK-Wavenet-B', # Slovak, female
|
||||
'hu-HU': 'hu-HU-Wavenet-B', # Hungarian, female
|
||||
'ro-RO': 'ro-RO-Wavenet-B', # Romanian, female
|
||||
'bg-BG': 'bg-BG-Wavenet-B', # Bulgarian, female
|
||||
'hr-HR': 'hr-HR-Wavenet-B', # Croatian, female
|
||||
'sr-RS': 'sr-RS-Wavenet-B', # Serbian, female
|
||||
'sl-SI': 'sl-SI-Wavenet-B', # Slovenian, female
|
||||
'et-EE': 'et-EE-Wavenet-B', # Estonian, female
|
||||
'lv-LV': 'lv-LV-Wavenet-B', # Latvian, female
|
||||
'lt-LT': 'lt-LT-Wavenet-B', # Lithuanian, female
|
||||
'fi-FI': 'fi-FI-Wavenet-B', # Finnish, female
|
||||
'sv-SE': 'sv-SE-Wavenet-B', # Swedish, female
|
||||
'no-NO': 'no-NO-Wavenet-B', # Norwegian, female
|
||||
'da-DK': 'da-DK-Wavenet-B', # Danish, female
|
||||
'is-IS': 'is-IS-Wavenet-B', # Icelandic, female
|
||||
'el-GR': 'el-GR-Wavenet-B', # Greek, female
|
||||
'ca-ES': 'ca-ES-Wavenet-B', # Catalan, female
|
||||
'eu-ES': 'eu-ES-Wavenet-B', # Basque, female
|
||||
'gl-ES': 'gl-ES-Wavenet-B', # Galician, female
|
||||
'cy-GB': 'cy-GB-Wavenet-B', # Welsh, female
|
||||
'ga-IE': 'ga-IE-Wavenet-B', # Irish, female
|
||||
'mt-MT': 'mt-MT-Wavenet-B', # Maltese, female
|
||||
return None
|
||||
|
||||
async def getAvailableLanguages(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get available languages from Google Cloud Text-to-Speech.
|
||||
|
||||
Returns:
|
||||
Dict containing success status and list of available languages
|
||||
"""
|
||||
try:
|
||||
logger.info("🌐 Getting available languages from Google Cloud TTS")
|
||||
|
||||
# Asian Languages
|
||||
'ja-JP': 'ja-JP-Wavenet-B', # Japanese, female
|
||||
'ko-KR': 'ko-KR-Wavenet-B', # Korean, female
|
||||
'zh-CN': 'cmn-CN-Wavenet-B', # Chinese Mandarin, female
|
||||
'zh-TW': 'cmn-TW-Wavenet-B', # Chinese Traditional, female
|
||||
'zh-HK': 'cmn-HK-Wavenet-B', # Chinese Hong Kong, female
|
||||
'hi-IN': 'hi-IN-Wavenet-B', # Hindi, female
|
||||
'bn-IN': 'bn-IN-Wavenet-B', # Bengali, female
|
||||
'te-IN': 'te-IN-Wavenet-B', # Telugu, female
|
||||
'ta-IN': 'ta-IN-Wavenet-B', # Tamil, female
|
||||
'gu-IN': 'gu-IN-Wavenet-B', # Gujarati, female
|
||||
'kn-IN': 'kn-IN-Wavenet-B', # Kannada, female
|
||||
'ml-IN': 'ml-IN-Wavenet-B', # Malayalam, female
|
||||
'pa-IN': 'pa-IN-Wavenet-B', # Punjabi, female
|
||||
'or-IN': 'or-IN-Wavenet-B', # Odia, female
|
||||
'as-IN': 'as-IN-Wavenet-B', # Assamese, female
|
||||
'ne-NP': 'ne-NP-Wavenet-B', # Nepali, female
|
||||
'si-LK': 'si-LK-Wavenet-B', # Sinhala, female
|
||||
'th-TH': 'th-TH-Wavenet-B', # Thai, female
|
||||
'vi-VN': 'vi-VN-Wavenet-B', # Vietnamese, female
|
||||
'id-ID': 'id-ID-Wavenet-B', # Indonesian, female
|
||||
'ms-MY': 'ms-MY-Wavenet-B', # Malay, female
|
||||
'tl-PH': 'fil-PH-Wavenet-B', # Filipino, female
|
||||
'tr-TR': 'tr-TR-Wavenet-B', # Turkish, female
|
||||
# List voices from Google Cloud TTS
|
||||
response = self.tts_client.list_voices()
|
||||
|
||||
# Middle Eastern & African Languages
|
||||
'ar-SA': 'ar-SA-Wavenet-B', # Arabic Saudi Arabia, female
|
||||
'ar-EG': 'ar-EG-Wavenet-B', # Arabic Egypt, female
|
||||
'ar-AE': 'ar-AE-Wavenet-B', # Arabic UAE, female
|
||||
'ar-JO': 'ar-JO-Wavenet-B', # Arabic Jordan, female
|
||||
'ar-KW': 'ar-KW-Wavenet-B', # Arabic Kuwait, female
|
||||
'ar-LB': 'ar-LB-Wavenet-B', # Arabic Lebanon, female
|
||||
'ar-QA': 'ar-QA-Wavenet-B', # Arabic Qatar, female
|
||||
'ar-BH': 'ar-BH-Wavenet-B', # Arabic Bahrain, female
|
||||
'ar-OM': 'ar-OM-Wavenet-B', # Arabic Oman, female
|
||||
'ar-IQ': 'ar-IQ-Wavenet-B', # Arabic Iraq, female
|
||||
'ar-PS': 'ar-PS-Wavenet-B', # Arabic Palestine, female
|
||||
'ar-SY': 'ar-SY-Wavenet-B', # Arabic Syria, female
|
||||
'ar-YE': 'ar-YE-Wavenet-B', # Arabic Yemen, female
|
||||
'ar-MA': 'ar-MA-Wavenet-B', # Arabic Morocco, female
|
||||
'ar-DZ': 'ar-DZ-Wavenet-B', # Arabic Algeria, female
|
||||
'ar-TN': 'ar-TN-Wavenet-B', # Arabic Tunisia, female
|
||||
'ar-LY': 'ar-LY-Wavenet-B', # Arabic Libya, female
|
||||
'ar-SD': 'ar-SD-Wavenet-B', # Arabic Sudan, female
|
||||
'he-IL': 'he-IL-Wavenet-B', # Hebrew, female
|
||||
'fa-IR': 'fa-IR-Wavenet-B', # Persian, female
|
||||
'ur-PK': 'ur-PK-Wavenet-B', # Urdu, female
|
||||
'af-ZA': 'af-ZA-Wavenet-B', # Afrikaans, female
|
||||
'sw-KE': 'sw-KE-Wavenet-B', # Swahili Kenya, female
|
||||
'am-ET': 'am-ET-Wavenet-B', # Amharic, female
|
||||
'sw-TZ': 'sw-TZ-Wavenet-B', # Swahili Tanzania, female
|
||||
'zu-ZA': 'zu-ZA-Wavenet-B', # Zulu, female
|
||||
'xh-ZA': 'xh-ZA-Wavenet-B', # Xhosa, female
|
||||
}
|
||||
return voice_mapping.get(language_code, 'en-US-Wavenet-B')
|
||||
# Extract unique language codes
|
||||
# Note: Google TTS API doesn't provide language descriptions, only codes
|
||||
language_codes = set()
|
||||
for voice in response.voices:
|
||||
if voice.language_codes:
|
||||
language_codes.update(voice.language_codes)
|
||||
|
||||
# Convert to sorted list of language codes
|
||||
available_languages = sorted(list(language_codes))
|
||||
|
||||
logger.info(f"✅ Found {len(available_languages)} available languages")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"languages": available_languages
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to get available languages: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"languages": []
|
||||
}
|
||||
|
||||
async def getAvailableVoices(self, languageCode: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get available voices from Google Cloud Text-to-Speech.
|
||||
|
||||
Args:
|
||||
language_code: Optional language code to filter voices (e.g., 'de-DE', 'en-US')
|
||||
|
||||
Returns:
|
||||
Dict containing success status and list of available voices
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🎤 Getting available voices from Google Cloud TTS, language filter: {languageCode}")
|
||||
|
||||
# List voices from Google Cloud TTS
|
||||
response = self.tts_client.list_voices()
|
||||
|
||||
availableVoices = []
|
||||
|
||||
for voice in response.voices:
|
||||
# Extract language code from voice name (e.g., 'de-DE-Wavenet-A' -> 'de-DE')
|
||||
voiceLanguage = voice.language_codes[0] if voice.language_codes else None
|
||||
|
||||
# Filter by language if specified
|
||||
if languageCode and voiceLanguage != languageCode:
|
||||
continue
|
||||
|
||||
# Determine gender from voice name (A/C = male, B/D = female)
|
||||
gender = "Unknown"
|
||||
if voice.name:
|
||||
if voice.name.endswith(('-A', '-C')):
|
||||
gender = "Male"
|
||||
elif voice.name.endswith(('-B', '-D')):
|
||||
gender = "Female"
|
||||
|
||||
# Create voice info with all available fields from Google API
|
||||
voiceInfo = {
|
||||
"name": voice.name,
|
||||
"language_code": voiceLanguage,
|
||||
"language_codes": list(voice.language_codes) if voice.language_codes else [],
|
||||
"gender": gender,
|
||||
"ssml_gender": voice.ssml_gender.name if voice.ssml_gender else "NEUTRAL",
|
||||
"natural_sample_rate_hertz": voice.natural_sample_rate_hertz
|
||||
}
|
||||
|
||||
# Include any additional fields if available from Google API
|
||||
# Check for common fields that might exist
|
||||
for field_name in ['description', 'display_name', 'labels']:
|
||||
if hasattr(voice, field_name):
|
||||
field_value = getattr(voice, field_name, None)
|
||||
if field_value:
|
||||
voiceInfo[field_name] = field_value
|
||||
|
||||
availableVoices.append(voiceInfo)
|
||||
|
||||
# Sort by language code, then by gender, then by name
|
||||
availableVoices.sort(key=lambda x: (x["language_code"], x["gender"], x["name"]))
|
||||
|
||||
logger.info(f"✅ Found {len(availableVoices)} voices for language filter: {languageCode}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"voices": availableVoices,
|
||||
"total_count": len(availableVoices)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to get available voices: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"voices": []
|
||||
}
|
||||
|
||||
|
|
@ -1,268 +0,0 @@
|
|||
"""Tavily web search class."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from modules.interfaces.interfaceWebModel import (
|
||||
WebCrawlBase,
|
||||
WebCrawlDocumentData,
|
||||
WebCrawlRequest,
|
||||
WebCrawlResultItem,
|
||||
WebScrapeActionDocument,
|
||||
WebScrapeActionResult,
|
||||
WebScrapeBase,
|
||||
WebScrapeDocumentData,
|
||||
WebScrapeRequest,
|
||||
WebScrapeResultItem,
|
||||
WebSearchBase,
|
||||
WebSearchRequest,
|
||||
WebSearchActionResult,
|
||||
WebSearchActionDocument,
|
||||
WebSearchDocumentData,
|
||||
WebSearchResultItem,
|
||||
WebCrawlActionDocument,
|
||||
WebCrawlActionResult,
|
||||
get_web_search_min_results,
|
||||
get_web_search_max_results,
|
||||
)
|
||||
|
||||
# from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument
|
||||
from tavily import AsyncTavilyClient
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration loading functions
|
||||
def get_web_crawl_timeout() -> int:
|
||||
"""Get web crawl timeout from configuration"""
|
||||
return int(APP_CONFIG.get("Web_Crawl_TIMEOUT", "30"))
|
||||
|
||||
|
||||
def get_web_crawl_max_retries() -> int:
|
||||
"""Get web crawl max retries from configuration"""
|
||||
return int(APP_CONFIG.get("Web_Crawl_MAX_RETRIES", "3"))
|
||||
|
||||
|
||||
def get_web_crawl_retry_delay() -> int:
|
||||
"""Get web crawl retry delay from configuration"""
|
||||
return int(APP_CONFIG.get("Web_Crawl_RETRY_DELAY", "2"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class TavilySearchResult:
|
||||
title: str
|
||||
url: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TavilyCrawlResult:
|
||||
url: str
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectorTavily(WebSearchBase, WebCrawlBase, WebScrapeBase):
|
||||
client: AsyncTavilyClient = None
|
||||
|
||||
@classmethod
|
||||
async def create(cls):
|
||||
api_key = APP_CONFIG.get("Connector_WebTavily_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("Tavily API key not configured. Please set Connector_WebTavily_API_KEY in config.ini")
|
||||
return cls(client=AsyncTavilyClient(api_key=api_key))
|
||||
|
||||
async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult:
|
||||
"""Handles the web search request.
|
||||
|
||||
Takes a query and returns a list of URLs.
|
||||
"""
|
||||
# Step 1: Search
|
||||
try:
|
||||
search_results = await self._search(request.query, request.max_results)
|
||||
except Exception as e:
|
||||
return WebSearchActionResult(success=False, error=str(e))
|
||||
|
||||
# Step 2: Build ActionResult
|
||||
try:
|
||||
result = self._build_search_action_result(search_results, request.query)
|
||||
except Exception as e:
|
||||
return WebSearchActionResult(success=False, error=str(e))
|
||||
|
||||
return result
|
||||
|
||||
async def crawl_urls(self, request: WebCrawlRequest) -> WebCrawlActionResult:
|
||||
"""Crawls the given URLs and returns the extracted text content."""
|
||||
# Step 1: Crawl
|
||||
try:
|
||||
crawl_results = await self._crawl(request.urls)
|
||||
except Exception as e:
|
||||
return WebCrawlActionResult(success=False, error=str(e))
|
||||
|
||||
# Step 2: Build ActionResult
|
||||
try:
|
||||
result = self._build_crawl_action_result(crawl_results, request.urls)
|
||||
except Exception as e:
|
||||
return WebCrawlActionResult(success=False, error=str(e))
|
||||
|
||||
return result
|
||||
|
||||
async def scrape(self, request: WebScrapeRequest) -> WebScrapeActionResult:
|
||||
"""Turns a query in a list of urls with extracted content."""
|
||||
# Step 1: Search
|
||||
try:
|
||||
search_results = await self._search(request.query, request.max_results)
|
||||
except Exception as e:
|
||||
return WebScrapeActionResult(success=False, error=str(e))
|
||||
|
||||
# Step 2: Crawl
|
||||
try:
|
||||
urls = [result.url for result in search_results]
|
||||
crawl_results = await self._crawl(urls)
|
||||
except Exception as e:
|
||||
return WebScrapeActionResult(success=False, error=str(e))
|
||||
|
||||
# Step 3: Build ActionResult
|
||||
try:
|
||||
result = self._build_scrape_action_result(crawl_results, request.query)
|
||||
except Exception as e:
|
||||
return WebScrapeActionResult(success=False, error=str(e))
|
||||
|
||||
return result
|
||||
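A hypothetical end-to-end use of the scrape pipeline above (search, then crawl, then build the result document); the request field names mirror how scrape() reads them, and the query string is a placeholder.

async def runScrape() -> None:
    connector = await ConnectorTavily.create()
    request = WebScrapeRequest(query="latest aiohttp release notes", max_results=3)
    result = await connector.scrape(request)
    if result.success:
        for document in result.documents:
            print(document.documentName)
    else:
        print("scrape failed:", result.error)
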
|
||||
async def _search(self, query: str, max_results: int) -> list[TavilySearchResult]:
|
||||
"""Calls the Tavily API to perform a web search."""
|
||||
# Make sure max_results is within the allowed range
|
||||
min_results = get_web_search_min_results()
|
||||
max_allowed_results = get_web_search_max_results()
|
||||
if max_results < min_results or max_results > max_allowed_results:
|
||||
raise ValueError(f"max_results must be between {min_results} and {max_allowed_results}")
|
||||
|
||||
# Perform actual API call
|
||||
response = await self.client.search(query=query, max_results=max_results)
|
||||
|
||||
return [
|
||||
TavilySearchResult(title=result["title"], url=result["url"])
|
||||
for result in response["results"]
|
||||
]
|
||||
|
||||
def _build_search_action_result(
|
||||
self, search_results: list[TavilySearchResult], query: str = ""
|
||||
) -> WebSearchActionResult:
|
||||
"""Builds the ActionResult from the search results."""
|
||||
# Convert to result items
|
||||
result_items = [
|
||||
WebSearchResultItem(title=result.title, url=result.url)
|
||||
for result in search_results
|
||||
]
|
||||
|
||||
# Create document data with all results
|
||||
document_data = WebSearchDocumentData(
|
||||
query=query, results=result_items, total_count=len(result_items)
|
||||
)
|
||||
|
||||
# Create single document
|
||||
document = WebSearchActionDocument(
|
||||
documentName=f"web_search_results_{get_utc_timestamp()}.json",
|
||||
documentData=document_data,
|
||||
mimeType="application/json",
|
||||
)
|
||||
|
||||
return WebSearchActionResult(
|
||||
success=True, documents=[document], resultLabel="web_search_results"
|
||||
)
|
||||
|
||||
async def _crawl(self, urls: list) -> list[TavilyCrawlResult]:
|
||||
"""Calls the Tavily API to extract text content from URLs with retry logic."""
|
||||
import asyncio
|
||||
|
||||
max_retries = get_web_crawl_max_retries()
|
||||
retry_delay = get_web_crawl_retry_delay()
|
||||
timeout = get_web_crawl_timeout()
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
# Use asyncio.wait_for for timeout
|
||||
response = await asyncio.wait_for(
|
||||
self.client.extract(urls=urls, extract_depth="advanced", format="text"),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
return [
|
||||
TavilyCrawlResult(url=result["url"], content=result["raw_content"])
|
||||
for result in response["results"]
|
||||
]
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Crawl attempt {attempt + 1} timed out after {timeout} seconds")
|
||||
if attempt < max_retries:
|
||||
logger.info(f"Retrying in {retry_delay} seconds...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
raise Exception(f"Crawl failed after {max_retries + 1} attempts due to timeout")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Crawl attempt {attempt + 1} failed: {str(e)}")
|
||||
if attempt < max_retries:
|
||||
logger.info(f"Retrying in {retry_delay} seconds...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
raise Exception(f"Crawl failed after {max_retries + 1} attempts: {str(e)}")
|
||||
|
||||
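The retry handling in _crawl generalizes to a small helper; a sketch of that pattern, with defaults mirroring the Web_Crawl_* settings above (the helper itself is an illustration, not part of the connector).

import asyncio

async def withRetries(makeCall, *, timeout: int = 30, maxRetries: int = 3, retryDelay: int = 2):
    # Run an awaitable factory with a per-attempt timeout and fixed retry delay
    for attempt in range(maxRetries + 1):
        try:
            return await asyncio.wait_for(makeCall(), timeout=timeout)
        except Exception as error:
            if attempt >= maxRetries:
                raise Exception(f"call failed after {maxRetries + 1} attempts: {error}") from error
            await asyncio.sleep(retryDelay)
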
def _build_crawl_action_result(
|
||||
self, crawl_results: list[TavilyCrawlResult], urls: list[str] = None
|
||||
) -> WebCrawlActionResult:
|
||||
"""Builds the ActionResult from the crawl results."""
|
||||
# Convert to result items
|
||||
result_items = [
|
||||
WebCrawlResultItem(url=result.url, content=result.content)
|
||||
for result in crawl_results
|
||||
]
|
||||
|
||||
# Create document data with all results
|
||||
document_data = WebCrawlDocumentData(
|
||||
urls=urls or [result.url for result in crawl_results],
|
||||
results=result_items,
|
||||
total_count=len(result_items),
|
||||
)
|
||||
|
||||
# Create single document
|
||||
document = WebCrawlActionDocument(
|
||||
documentName=f"web_crawl_results_{get_utc_timestamp()}.json",
|
||||
documentData=document_data,
|
||||
mimeType="application/json",
|
||||
)
|
||||
|
||||
return WebCrawlActionResult(
|
||||
success=True, documents=[document], resultLabel="web_crawl_results"
|
||||
)
|
||||
|
||||
def _build_scrape_action_result(
|
||||
self, crawl_results: list[TavilyCrawlResult], query: str = ""
|
||||
) -> WebScrapeActionResult:
|
||||
"""Builds the ActionResult from the scrape results."""
|
||||
# Convert to result items
|
||||
result_items = [
|
||||
WebScrapeResultItem(url=result.url, content=result.content)
|
||||
for result in crawl_results
|
||||
]
|
||||
|
||||
# Create document data with all results
|
||||
document_data = WebScrapeDocumentData(
|
||||
query=query,
|
||||
results=result_items,
|
||||
total_count=len(result_items),
|
||||
)
|
||||
|
||||
# Create single document
|
||||
document = WebScrapeActionDocument(
|
||||
documentName=f"web_scrape_results_{get_utc_timestamp()}.json",
|
||||
documentData=document_data,
|
||||
mimeType="application/json",
|
||||
)
|
||||
|
||||
return WebScrapeActionResult(
|
||||
success=True, documents=[document], resultLabel="web_scrape_results"
|
||||
)
|
||||
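
A minimal driver sketch for the connector methods above, not part of the diff: the connector object and the _search keyword signature are assumptions, and only the validation/retry behaviour shown in the methods themselves comes from the source.

# Hypothetical usage of the Tavily connector methods above (names outside the diff are invented).
import asyncio

async def searchThenCrawl(connector, query: str):
    # _search validates max_results against the configured min/max bounds before calling Tavily
    searchResults = await connector._search(query=query, max_results=5)
    searchAction = connector._build_search_action_result(searchResults, query=query)

    # _crawl retries on timeouts and errors for up to get_web_crawl_max_retries() extra attempts
    crawlResults = await connector._crawl([r.url for r in searchResults])
    crawlAction = connector._build_crawl_action_result(crawlResults)
    return searchAction, crawlAction

# asyncio.run(searchThenCrawl(tavilyConnector, "pydantic model validation"))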
15
modules/datamodels/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""
Unified modules.datamodels package.

Usage examples:
    from modules.datamodels import ai
    from modules.datamodels import uam
"""
from . import datamodelAi as ai
from . import datamodelUam as uam
from . import datamodelSecurity as security
from . import datamodelNeutralizer as neutralizer
from . import datamodelChat as chat
from . import datamodelFiles as files
from . import datamodelVoice as voice
from . import datamodelUtils as utils
233
modules/datamodels/datamodelAi.py
Normal file
@@ -0,0 +1,233 @@
from typing import Optional, List, Dict, Any, Callable, TYPE_CHECKING, Tuple
from pydantic import BaseModel, Field
from enum import Enum

# Import ContentPart for runtime use (needed for Pydantic model rebuilding)
from modules.datamodels.datamodelExtraction import ContentPart


# Operation Types
class OperationTypeEnum(str, Enum):

    # Planning Operation
    PLAN = "plan"

    # Data Operations
    DATA_ANALYSE = "dataAnalyse"
    DATA_GENERATE = "dataGenerate"
    DATA_EXTRACT = "dataExtract"

    # Image Operations
    IMAGE_ANALYSE = "imageAnalyse"
    IMAGE_GENERATE = "imageGenerate"

    # Web Operations
    WEB_SEARCH = "webSearch"  # Returns list of URLs only
    WEB_CRAWL = "webCrawl"  # Web crawl for a given URL


# Operation Type Rating - Helper class for capability ratings
class OperationTypeRating(BaseModel):
    """Represents an operation type with its capability rating (1-10)."""
    operationType: OperationTypeEnum = Field(description="The operation type")
    rating: int = Field(ge=1, le=10, description="Capability rating (1-10, higher = better for this operation type)")

    def __str__(self) -> str:
        return f"{self.operationType.value}({self.rating})"

    def __repr__(self) -> str:
        return f"OperationTypeRating({self.operationType.value}, {self.rating})"


# Helper function to create operation type ratings easily
def createOperationTypeRatings(*ratings: Tuple[OperationTypeEnum, int]) -> List[OperationTypeRating]:
    """
    Helper function to create operation type ratings easily.

    Usage:
        operationTypes = createOperationTypeRatings(
            (OperationTypeEnum.DATA_ANALYSE, 8),
            (OperationTypeEnum.WEB_SEARCH, 10),
            (OperationTypeEnum.WEB_CRAWL, 9)
        )
    """
    return [OperationTypeRating(operationType=ot, rating=rating) for ot, rating in ratings]


# Processing Modes
class ProcessingModeEnum(str, Enum):
    BASIC = "basic"
    ADVANCED = "advanced"
    DETAILED = "detailed"

# Priority Levels
class PriorityEnum(str, Enum):
    SPEED = "speed"
    QUALITY = "quality"
    COST = "cost"
    BALANCED = "balanced"


# Model Capabilities - REMOVED: Not used in business logic


class AiModel(BaseModel):
    """Enhanced AI model definition with dynamic capabilities."""

    # Core identification
    name: str = Field(description="Actual LLM model name used for API calls")
    displayName: str = Field(description="Human-readable model name with module prefix")
    connectorType: str = Field(description="Type of connector (openai, anthropic, perplexity, tavily, etc.)")

    # API configuration
    apiUrl: str = Field(description="API endpoint URL for this model")
    temperature: float = Field(default=0.2, ge=0.0, le=2.0, description="Default temperature for this model")

    # Token and context limits
    maxTokens: int = Field(description="Maximum tokens this model can generate")
    contextLength: int = Field(description="Maximum context length this model can handle")

    # Cost information
    costPer1kTokensInput: float = Field(default=0.0, description="Cost per 1000 input tokens")
    costPer1kTokensOutput: float = Field(default=0.0, description="Cost per 1000 output tokens")

    # Performance ratings
    speedRating: int = Field(ge=1, le=10, description="Speed rating (1-10, higher = faster)")
    qualityRating: int = Field(ge=1, le=10, description="Quality rating (1-10, higher = better)")

    # Function reference (not serialized)
    functionCall: Optional[Callable] = Field(default=None, exclude=True, description="Function to call for this model")
    calculatePriceUsd: Optional[Callable] = Field(default=None, exclude=True, description="Function to calculate price in USD")

    # Selection criteria - capabilities with ratings
    priority: PriorityEnum = Field(default=PriorityEnum.BALANCED, description="Default priority for this model. See PriorityEnum for available values.")
    processingMode: ProcessingModeEnum = Field(default=ProcessingModeEnum.BASIC, description="Default processing mode. See ProcessingModeEnum for available values.")
    operationTypes: List[OperationTypeRating] = Field(default=[], description="Operation types this model can handle with capability ratings (1-10)")
    minContextLength: Optional[int] = Field(default=None, description="Minimum context length required")
    isAvailable: bool = Field(default=True, description="Whether model is currently available")

    # Metadata
    version: Optional[str] = Field(default=None, description="Model version")
    lastUpdated: Optional[str] = Field(default=None, description="Last update timestamp")

    class Config:
        arbitraryTypesAllowed = True  # Allow Callable type


class SelectionRule(BaseModel):
    """A rule for model selection."""
    name: str = Field(description="Rule name identifier")
    condition: str = Field(description="Description of when this rule applies")
    weight: float = Field(description="Weight for scoring (higher = more important)")
    operationTypes: List[OperationTypeEnum] = Field(description="Operation types this rule applies to")
    priority: PriorityEnum = Field(default=PriorityEnum.BALANCED, description="Priority level for this rule")
    minQualityRating: Optional[int] = Field(default=None, description="Minimum quality rating")
    maxCost: Optional[float] = Field(default=None, description="Maximum cost threshold")
    minContextLength: Optional[int] = Field(default=None, description="Minimum context length required")


class AiCallOptions(BaseModel):
    """Options for centralized AI processing with clear operation types and tags."""
    operationType: OperationTypeEnum = Field(default=OperationTypeEnum.DATA_ANALYSE, description="Type of operation")
    priority: PriorityEnum = Field(default=PriorityEnum.BALANCED, description="Priority level")
    compressPrompt: bool = Field(default=True, description="Whether to compress the prompt")
    compressContext: bool = Field(default=True, description="If False: process each chunk; If True: summarize and work on summary")
    processDocumentsIndividually: bool = Field(default=True, description="If True, process each document separately; else pool docs")
    maxCost: Optional[float] = Field(default=None, description="Max cost budget")
    maxProcessingTime: Optional[int] = Field(default=None, description="Max processing time in seconds")
    processingMode: ProcessingModeEnum = Field(default=ProcessingModeEnum.BASIC, description="Processing mode")
    resultFormat: Optional[str] = Field(default=None, description="Expected result format: txt, json, csv, xml, etc.")

    safetyMargin: float = Field(default=0.1, ge=0.0, le=0.5, description="Safety margin for token limits (0.0-0.5)")

    # Model generation parameters
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Temperature for response generation (0.0-2.0, lower = more consistent)")
    maxParts: Optional[int] = Field(default=1000, ge=1, le=1000, description="Maximum number of continuation parts to fetch")


class AiCallRequest(BaseModel):
    """Centralized AI call request payload for interface use."""

    prompt: str = Field(description="The user prompt")
    context: Optional[str] = Field(default=None, description="Optional external context (e.g., extracted docs)")
    options: AiCallOptions = Field(default_factory=AiCallOptions)
    contentParts: Optional[List['ContentPart']] = None  # NEW: Content parts for model-aware chunking


class AiCallResponse(BaseModel):
    """Standardized AI call response."""

    content: str = Field(description="AI response content")
    modelName: str = Field(description="Selected model name")
    priceUsd: float = Field(default=0.0, description="Calculated price in USD")
    processingTime: float = Field(default=0.0, description="Duration in seconds")
    bytesSent: int = Field(default=0, description="Input data size in bytes")
    bytesReceived: int = Field(default=0, description="Output data size in bytes")
    errorCount: int = Field(default=0, description="0 for success, 1+ for errors")


class AiModelCall(BaseModel):
    """Standardized input for AI model calls."""

    messages: List[Dict[str, Any]] = Field(description="Messages in OpenAI format (role, content)")
    model: Optional[AiModel] = Field(default=None, description="The AI model being called")
    options: AiCallOptions = Field(default_factory=AiCallOptions, description="Additional model-specific options")

    class Config:
        arbitraryTypesAllowed = True


class AiModelResponse(BaseModel):
    """Standardized output from AI model calls."""

    content: str = Field(description="The AI response content")
    success: bool = Field(default=True, description="Whether the call was successful")
    error: Optional[str] = Field(default=None, description="Error message if success=False")

    # Optional metadata that models can include
    modelId: Optional[str] = Field(default=None, description="Model identifier used")
    processingTime: Optional[float] = Field(default=None, description="Processing time in seconds")
    tokensUsed: Optional[Dict[str, int]] = Field(default=None, description="Token usage (input, output, total)")
    metadata: Optional[Dict[str, Any]] = Field(default=None, description="Additional model-specific metadata")

    class Config:
        arbitraryTypesAllowed = True


# Structured prompt models for specialized operations

class AiCallPromptWebSearch(BaseModel):
    """Structured prompt format for WEB_SEARCH operation - returns list of URLs."""

    instruction: str = Field(description="Search instruction/query for finding relevant URLs")
    country: Optional[str] = Field(default=None, description="Two-digit country code (lowercase, e.g., ch, us, de, fr)")
    maxNumberPages: Optional[int] = Field(default=10, description="Maximum number of pages to search (default: 10)")
    language: Optional[str] = Field(default=None, description="Language code (lowercase, e.g., de, en, fr)")
    researchDepth: Optional[str] = Field(default="general", description="Research depth: fast (maxDepth=1), general (maxDepth=2), deep (maxDepth=3)")

    class Config:
        pass


class AiCallPromptWebCrawl(BaseModel):
    """Structured prompt format for WEB_CRAWL operation - crawls ONE specific URL and returns content."""

    instruction: str = Field(description="Instruction for what content to extract from URL")
    url: str = Field(description="Single URL to crawl")
    maxDepth: Optional[int] = Field(default=2, description="Maximum number of hops from starting page (default: 2)")
    maxWidth: Optional[int] = Field(default=10, description="Maximum pages to crawl per level (default: 10)")

    class Config:
        pass


class AiCallPromptImage(BaseModel):
    """Structured prompt format for image generation."""

    prompt: str = Field(description="Text description of the image to generate")
    size: Optional[str] = Field(default="1024x1024", description="Image size (1024x1024, 1792x1024, 1024x1792)")
    quality: Optional[str] = Field(default="standard", description="Image quality (standard, hd)")
    style: Optional[str] = Field(default="vivid", description="Image style (vivid, natural)")

    class Config:
        pass
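
A construction sketch for the models above; the model values are invented placeholders rather than configuration from this commit, and the snippet assumes the unified package import from modules/datamodels/__init__.py resolves.

from modules.datamodels import ai

# Hypothetical model registration using the rating helper
webModel = ai.AiModel(
    name="sonar-pro",
    displayName="perplexity: sonar-pro",
    connectorType="perplexity",
    apiUrl="https://api.perplexity.ai",
    maxTokens=4096,
    contextLength=128000,
    speedRating=7,
    qualityRating=8,
    operationTypes=ai.createOperationTypeRatings(
        (ai.OperationTypeEnum.WEB_SEARCH, 10),
        (ai.OperationTypeEnum.WEB_CRAWL, 9),
    ),
)

# A centralized call request built from the same option models
request = ai.AiCallRequest(
    prompt="Summarize the latest release notes",
    options=ai.AiCallOptions(
        operationType=ai.OperationTypeEnum.WEB_SEARCH,
        priority=ai.PriorityEnum.SPEED,
    ),
)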
1098
modules/datamodels/datamodelChat.py
Normal file
File diff suppressed because it is too large
109
modules/datamodels/datamodelDocument.py
Normal file
@@ -0,0 +1,109 @@
from typing import Any, Dict, List, Optional, Literal, Union
from pydantic import BaseModel, Field
from datetime import datetime


class DocumentMetadata(BaseModel):
    """Metadata for the entire document."""
    title: str = Field(description="Document title")
    author: Optional[str] = Field(default=None, description="Document author")
    createdAt: datetime = Field(default_factory=datetime.now, description="Creation timestamp")
    sourceDocuments: List[str] = Field(default_factory=list, description="Source document IDs")
    extractionMethod: str = Field(default="ai_extraction", description="Method used for extraction")
    version: str = Field(default="1.0", description="Document version")


class TableData(BaseModel):
    """Structured table data."""
    headers: List[str] = Field(description="Table column headers")
    rows: List[List[str]] = Field(description="Table data rows")
    caption: Optional[str] = Field(default=None, description="Table caption")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Table metadata")


class ListItem(BaseModel):
    """Individual list item with optional sub-items."""
    text: str = Field(description="List item text")
    subitems: Optional[List['ListItem']] = Field(default=None, description="Nested sub-items")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Item metadata")


class BulletList(BaseModel):
    """Bulleted or numbered list."""
    items: List[ListItem] = Field(description="List items")
    listType: Literal["bullet", "numbered", "checklist"] = Field(default="bullet", description="List type")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="List metadata")


class Paragraph(BaseModel):
    """Text paragraph with optional formatting."""
    text: str = Field(description="Paragraph text")
    formatting: Optional[Dict[str, Any]] = Field(default=None, description="Text formatting (bold, italic, etc.)")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Paragraph metadata")


class Heading(BaseModel):
    """Document heading."""
    text: str = Field(description="Heading text")
    level: int = Field(ge=1, le=6, description="Heading level (1-6)")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Heading metadata")


class CodeBlock(BaseModel):
    """Code block with syntax highlighting."""
    code: str = Field(description="Code content")
    language: Optional[str] = Field(default=None, description="Programming language")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Code block metadata")


class Image(BaseModel):
    """Image with metadata."""
    data: str = Field(description="Base64 encoded image data")
    altText: Optional[str] = Field(default=None, description="Alternative text")
    caption: Optional[str] = Field(default=None, description="Image caption")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Image metadata")


class DocumentSection(BaseModel):
    """A section of the document containing one or more content elements."""
    id: str = Field(description="Unique section identifier")
    title: Optional[str] = Field(default=None, description="Section title")
    contentType: Literal["table", "list", "paragraph", "heading", "code", "image", "mixed"] = Field(description="Primary content type")
    elements: List[Union[TableData, BulletList, Paragraph, Heading, CodeBlock, Image]] = Field(description="Content elements in this section")
    order: int = Field(description="Section order in document")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Section metadata")


class StructuredDocument(BaseModel):
    """Complete structured document in JSON format."""
    metadata: DocumentMetadata = Field(description="Document metadata")
    sections: List[DocumentSection] = Field(description="Document sections")
    summary: Optional[str] = Field(default=None, description="Document summary")
    tags: List[str] = Field(default_factory=list, description="Document tags")

    def getSectionsByType(self, contentType: str) -> List[DocumentSection]:
        """Get all sections of a specific content type."""
        return [section for section in self.sections if section.contentType == contentType]

    def getAllTables(self) -> List[TableData]:
        """Get all table data from the document."""
        tables = []
        for section in self.sections:
            for element in section.elements:
                if isinstance(element, TableData):
                    tables.append(element)
        return tables

    def getAllLists(self) -> List[BulletList]:
        """Get all lists from the document."""
        lists = []
        for section in self.sections:
            for element in section.elements:
                if isinstance(element, BulletList):
                    lists.append(element)
        return lists


# Update forward references
ListItem.model_rebuild()
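
A small construction sketch for the structured-document models above; the content values are invented and only illustrate the section/element nesting and the helper accessors.

doc = StructuredDocument(
    metadata=DocumentMetadata(title="Quarterly report"),
    sections=[
        DocumentSection(
            id="s1",
            contentType="table",
            elements=[TableData(headers=["KPI", "Value"], rows=[["Revenue", "1.2M"]])],
            order=0,
        ),
        DocumentSection(
            id="s2",
            contentType="paragraph",
            elements=[Paragraph(text="Narrative summary of the quarter.")],
            order=1,
        ),
    ],
)

assert len(doc.getAllTables()) == 1
assert doc.getSectionsByType("paragraph")[0].id == "s2"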
91
modules/datamodels/datamodelExtraction.py
Normal file
@@ -0,0 +1,91 @@
from typing import Any, Dict, List, Optional, Literal, TYPE_CHECKING
from pydantic import BaseModel, Field

if TYPE_CHECKING:
    from modules.datamodels.datamodelAi import OperationTypeEnum


class ContentPart(BaseModel):
    id: str = Field(description="Unique content part identifier")
    parentId: Optional[str] = Field(default=None, description="Optional parent content part id")
    label: str = Field(description="Human readable label of the part")
    typeGroup: str = Field(description="Logical type group: text, table, structure, binary, ...")
    mimeType: str = Field(description="MIME type of the part payload")
    data: str = Field(default="", description="Primary data payload, often extracted text")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Arbitrary metadata for the part")


class ContentExtracted(BaseModel):
    id: str = Field(description="Extraction id or source document id")
    parts: List[ContentPart] = Field(default_factory=list, description="List of extracted parts")
    summary: Optional[Dict[str, Any]] = Field(default=None, description="Optional extraction summary")


class ChunkResult(BaseModel):
    """Preserves the relationship between a chunk and its AI result."""
    originalChunk: ContentPart
    aiResult: str
    chunkIndex: int
    documentId: str
    processingTime: float = 0.0
    metadata: Dict[str, Any] = Field(default_factory=dict)


class PartResult(BaseModel):
    """Preserves the relationship between a content part and its AI result."""
    originalPart: ContentPart
    aiResult: str
    partIndex: int
    documentId: str
    processingTime: float = 0.0
    metadata: Dict[str, Any] = Field(default_factory=dict)


class MergeStrategy(BaseModel):
    """Strategy configuration for merging content parts and AI results."""
    groupBy: str = Field(default="typeGroup", description="Field to group parts by (typeGroup, parentId, label, etc.)")
    orderBy: str = Field(default="id", description="Field to order parts within groups (id, order, pageIndex, etc.)")
    mergeType: Literal["concatenate", "hierarchical", "intelligent"] = Field(default="concatenate", description="How to merge content within groups")
    maxSize: Optional[int] = Field(default=None, description="Maximum size for merged content in bytes")
    textMerge: Optional[Dict[str, Any]] = Field(default=None, description="Text-specific merge settings (separator, formatting, etc.)")
    tableMerge: Optional[Dict[str, Any]] = Field(default=None, description="Table-specific merge settings (header handling, etc.)")
    structureMerge: Optional[Dict[str, Any]] = Field(default=None, description="Structure-specific merge settings (hierarchy, etc.)")
    aiResultMerge: Optional[Dict[str, Any]] = Field(default=None, description="AI result merging settings (prompt, context, etc.)")
    preserveChunks: bool = Field(default=False, description="Whether to preserve individual chunks or merge them")
    chunkSeparator: str = Field(default="\n\n---\n\n", description="Separator between chunks when merging")
    preserveMetadata: bool = Field(default=True, description="Whether to preserve metadata from original parts")
    metadataFields: Optional[List[str]] = Field(default=None, description="Specific metadata fields to preserve (None = all)")
    onError: Literal["skip", "include", "fail"] = Field(default="skip", description="How to handle errors during merging")
    validateContent: bool = Field(default=True, description="Whether to validate content before merging")
    useIntelligentMerging: bool = Field(default=False, description="Whether to use intelligent token-aware merging")
    prompt: Optional[str] = Field(default=None, description="Prompt for intelligent merging")
    capabilities: Optional[Dict[str, Any]] = Field(default=None, description="Model capabilities for intelligent merging")


class ExtractionOptions(BaseModel):
    """Options for document extraction and processing with clear data structures."""

    # Core extraction parameters
    prompt: str = Field(description="Extraction prompt for AI processing")
    operationType: 'OperationTypeEnum' = Field(description="Type of operation for AI processing")
    processDocumentsIndividually: bool = Field(default=True, description="Process each document separately")

    # Image processing parameters
    imageMaxPixels: int = Field(default=1024 * 1024, ge=1, description="Maximum pixels for image processing")
    imageQuality: int = Field(default=85, ge=1, le=100, description="Image quality (1-100)")

    # Merging strategy
    mergeStrategy: MergeStrategy = Field(description="Strategy for merging extraction results")

    # Optional chunking parameters (for backward compatibility)
    chunkAllowed: Optional[bool] = Field(default=None, description="Whether chunking is allowed")
    maxSize: Optional[int] = Field(default=None, description="Maximum size for processing")
    textChunkSize: Optional[int] = Field(default=None, description="Size for text chunks")
    imageChunkSize: Optional[int] = Field(default=None, description="Size for image chunks")

    # Additional processing options
    enableParallelProcessing: bool = Field(default=True, description="Enable parallel processing of chunks")
    maxConcurrentChunks: int = Field(default=5, ge=1, le=20, description="Maximum number of chunks to process concurrently")

    class Config:
        arbitraryTypesAllowed = True  # Allow OperationTypeEnum import
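
A sketch of how extracted parts and a merge strategy are meant to fit together; the ids, labels, and the manual join are illustrative only, not the real merge pipeline.

parts = [
    ContentPart(id="p1", label="Page 1 text", typeGroup="text", mimeType="text/plain", data="First page."),
    ContentPart(id="p2", label="Page 2 text", typeGroup="text", mimeType="text/plain", data="Second page."),
]
extracted = ContentExtracted(id="doc-42", parts=parts)

# Concatenate text parts in id order, using the strategy's default chunk separator
strategy = MergeStrategy(groupBy="typeGroup", orderBy="id", mergeType="concatenate")
merged = strategy.chunkSeparator.join(
    part.data for part in sorted(extracted.parts, key=lambda p: p.id)
)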
73
modules/datamodels/datamodelFiles.py
Normal file
@@ -0,0 +1,73 @@
"""File-related datamodels: FileItem, FilePreview, FileData."""

from typing import Dict, Any, Optional, Union
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timezoneUtils import getUtcTimestamp
import uuid
import base64


class FileItem(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
    mandateId: str = Field(description="ID of the mandate this file belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
    fileName: str = Field(description="Name of the file", frontend_type="text", frontend_readonly=False, frontend_required=True)
    mimeType: str = Field(description="MIME type of the file", frontend_type="text", frontend_readonly=True, frontend_required=False)
    fileHash: str = Field(description="Hash of the file", frontend_type="text", frontend_readonly=True, frontend_required=False)
    fileSize: int = Field(description="Size of the file in bytes", frontend_type="integer", frontend_readonly=True, frontend_required=False)
    creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the file was created (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)

registerModelLabels(
    "FileItem",
    {"en": "File Item", "fr": "Élément de fichier"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
        "fileName": {"en": "fileName", "fr": "Nom de fichier"},
        "mimeType": {"en": "MIME Type", "fr": "Type MIME"},
        "fileHash": {"en": "File Hash", "fr": "Hash du fichier"},
        "fileSize": {"en": "File Size", "fr": "Taille du fichier"},
        "creationDate": {"en": "Creation Date", "fr": "Date de création"},
    },
)

class FilePreview(BaseModel):
    content: Union[str, bytes] = Field(description="File content (text or binary)")
    mimeType: str = Field(description="MIME type of the file")
    fileName: str = Field(description="Original fileName")
    isText: bool = Field(description="Whether the content is text (True) or binary (False)")
    encoding: Optional[str] = Field(None, description="Text encoding if content is text")
    size: int = Field(description="Size of the content in bytes")

    def toDictWithBase64Encoding(self) -> Dict[str, Any]:
        """Convert to dictionary with base64 encoding for binary content."""
        data = self.model_dump()
        if isinstance(data.get("content"), bytes):
            data["content"] = base64.b64encode(data["content"]).decode("utf-8")
        return data

registerModelLabels(
    "FilePreview",
    {"en": "File Preview", "fr": "Aperçu du fichier"},
    {
        "content": {"en": "Content", "fr": "Contenu"},
        "mimeType": {"en": "MIME Type", "fr": "Type MIME"},
        "fileName": {"en": "fileName", "fr": "Nom de fichier"},
        "isText": {"en": "Is Text", "fr": "Est du texte"},
        "encoding": {"en": "Encoding", "fr": "Encodage"},
        "size": {"en": "Size", "fr": "Taille"},
    },
)

class FileData(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    data: str = Field(description="File data content")
    base64Encoded: bool = Field(description="Whether the data is base64 encoded")

registerModelLabels(
    "FileData",
    {"en": "File Data", "fr": "Données de fichier"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "data": {"en": "Data", "fr": "Données"},
        "base64Encoded": {"en": "Base64 Encoded", "fr": "Encodé en Base64"},
    },
)
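
A short sketch of the binary-preview flow above; the PDF bytes are invented and only show that toDictWithBase64Encoding makes the payload JSON-safe.

rawBytes = b"%PDF-1.7 ..."
preview = FilePreview(
    content=rawBytes,
    mimeType="application/pdf",
    fileName="report.pdf",
    isText=False,
    size=len(rawBytes),
)
payload = preview.toDictWithBase64Encoding()
# payload["content"] is now a base64 string that can be serialized to JSON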
90
modules/datamodels/datamodelJson.py
Normal file
@@ -0,0 +1,90 @@
"""
Unified JSON document schema and helpers used by both generation prompts and renderers.

This defines a single canonical template and the supported section types.
"""

from typing import List

# Canonical list of supported section types across the system
supportedSectionTypes: List[str] = [
    "table",
    "bullet_list",
    "heading",
    "paragraph",
    "code_block",
    "image",
]

# Canonical JSON template used for AI generation (documents array + sections)
# Rendering pipelines can select the first document and read its sections.
jsonTemplateDocument: str = """{
  "metadata": {
    "split_strategy": "single_document",
    "source_documents": [],
    "extraction_method": "ai_generation"
  },
  "documents": [
    {
      "id": "doc_1",
      "title": "{{DOCUMENT_TITLE}}",
      "filename": "document.json",
      "sections": [
        {
          "id": "section_heading_example",
          "content_type": "heading",
          "elements": [
            {"level": 1, "text": "Heading Text"}
          ],
          "order": 0
        },
        {
          "id": "section_paragraph_example",
          "content_type": "paragraph",
          "elements": [
            {"text": "Paragraph text content"}
          ],
          "order": 0
        },
        {
          "id": "section_bullet_list_example",
          "content_type": "bullet_list",
          "elements": [
            {
              "items": ["Item 1", "Item 2"]
            }
          ],
          "order": 0
        },
        {
          "id": "section_table_example",
          "content_type": "table",
          "elements": [
            {
              "headers": ["Column 1", "Column 2"],
              "rows": [
                ["Row 1 Col 1", "Row 1 Col 2"],
                ["Row 2 Col 1", "Row 2 Col 2"]
              ],
              "caption": "Table caption"
            }
          ],
          "order": 0
        },
        {
          "id": "section_code_example",
          "content_type": "code_block",
          "elements": [
            {
              "code": "function example() { return true; }",
              "language": "javascript"
            }
          ],
          "order": 0
        }
      ]
    }
  ]
}"""
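
A sketch of how the canonical template is typically consumed; it assumes the substituted title needs no JSON escaping.

import json

filled = jsonTemplateDocument.replace("{{DOCUMENT_TITLE}}", "Demo document")
parsed = json.loads(filled)

firstDocument = parsed["documents"][0]  # rendering pipelines read the first document
assert all(
    section["content_type"] in supportedSectionTypes
    for section in firstDocument["sections"]
)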
51
modules/datamodels/datamodelNeutralizer.py
Normal file
@@ -0,0 +1,51 @@
"""Neutralizer models: DataNeutraliserConfig and DataNeutralizerAttributes."""

import uuid
from typing import Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels


class DataNeutraliserConfig(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the configuration", frontend_type="text", frontend_readonly=True, frontend_required=False)
    mandateId: str = Field(description="ID of the mandate this configuration belongs to", frontend_type="text", frontend_readonly=True, frontend_required=True)
    userId: str = Field(description="ID of the user who created this configuration", frontend_type="text", frontend_readonly=True, frontend_required=True)
    enabled: bool = Field(default=True, description="Whether data neutralization is enabled", frontend_type="checkbox", frontend_readonly=False, frontend_required=False)
    namesToParse: str = Field(default="", description="Multiline list of names to parse for neutralization", frontend_type="textarea", frontend_readonly=False, frontend_required=False)
    sharepointSourcePath: str = Field(default="", description="SharePoint path to read files for neutralization", frontend_type="text", frontend_readonly=False, frontend_required=False)
    sharepointTargetPath: str = Field(default="", description="SharePoint path to store neutralized files", frontend_type="text", frontend_readonly=False, frontend_required=False)

registerModelLabels(
    "DataNeutraliserConfig",
    {"en": "Data Neutralization Config", "fr": "Configuration de neutralisation des données"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "mandateId": {"en": "Mandate ID", "fr": "ID de mandat"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "enabled": {"en": "Enabled", "fr": "Activé"},
        "namesToParse": {"en": "Names to Parse", "fr": "Noms à analyser"},
        "sharepointSourcePath": {"en": "Source Path", "fr": "Chemin source"},
        "sharepointTargetPath": {"en": "Target Path", "fr": "Chemin cible"},
    },
)

class DataNeutralizerAttributes(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the attribute mapping (used as UID in neutralized files)", frontend_type="text", frontend_readonly=True, frontend_required=False)
    mandateId: str = Field(description="ID of the mandate this attribute belongs to", frontend_type="text", frontend_readonly=True, frontend_required=True)
    userId: str = Field(description="ID of the user who created this attribute", frontend_type="text", frontend_readonly=True, frontend_required=True)
    originalText: str = Field(description="Original text that was neutralized", frontend_type="text", frontend_readonly=True, frontend_required=True)
    fileId: Optional[str] = Field(default=None, description="ID of the file this attribute belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
    patternType: str = Field(description="Type of pattern that matched (email, phone, name, etc.)", frontend_type="text", frontend_readonly=True, frontend_required=True)

registerModelLabels(
    "DataNeutralizerAttributes",
    {"en": "Neutralized Data Attribute", "fr": "Attribut de données neutralisées"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "mandateId": {"en": "Mandate ID", "fr": "ID de mandat"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "originalText": {"en": "Original Text", "fr": "Texte original"},
        "fileId": {"en": "File ID", "fr": "ID de fichier"},
        "patternType": {"en": "Pattern Type", "fr": "Type de modèle"},
    },
)
72
modules/datamodels/datamodelPagination.py
Normal file
@@ -0,0 +1,72 @@
"""
Pagination models for server-side pagination, sorting, and filtering.

All models use camelStyle naming convention for consistency with frontend.
"""

from typing import List, Dict, Any, Optional, Generic, TypeVar
from pydantic import BaseModel, Field
import math

T = TypeVar('T')


class SortField(BaseModel):
    """
    Single sort field configuration.
    """
    field: str = Field(..., description="Field name to sort by")
    direction: str = Field(..., description="Sort direction: 'asc' or 'desc'")


class PaginationParams(BaseModel):
    """
    Complete pagination state including page, sorting, and filters.
    """
    page: int = Field(ge=1, description="Current page number (1-based)")
    pageSize: int = Field(ge=1, le=1000, description="Number of items per page")
    sort: List[SortField] = Field(default_factory=list, description="List of sort fields in priority order")
    filters: Optional[Dict[str, Any]] = Field(default=None, description="Filter criteria (structure TBD for future implementation)")


class PaginationRequest(BaseModel):
    """
    Pagination request parameters sent from frontend to backend.
    All fields are optional. If pagination=None, no pagination is applied.
    """
    pagination: Optional[PaginationParams] = None


class PaginatedResult(BaseModel):
    """
    Internal result structure from interface layer.
    Used when pagination is requested.
    """
    items: List[Any]
    totalItems: int
    totalPages: int  # Calculated as: math.ceil(totalItems / pageSize)


class PaginationMetadata(BaseModel):
    """
    Pagination metadata returned to frontend for rendering controls.
    Contains all information needed to render pagination UI and handle user interactions.
    """
    currentPage: int = Field(..., description="Current page number (1-based)")
    pageSize: int = Field(..., description="Number of items per page")
    totalItems: int = Field(..., description="Total number of items across all pages (after filters)")
    totalPages: int = Field(..., description="Total number of pages (calculated from totalItems / pageSize)")
    sort: List[SortField] = Field(..., description="Current sort configuration applied")
    filters: Optional[Dict[str, Any]] = Field(default=None, description="Current filters applied (for future use)")


class PaginatedResponse(BaseModel, Generic[T]):
    """
    Response containing paginated data and metadata.
    """
    items: List[T] = Field(..., description="Array of items for current page")
    pagination: Optional[PaginationMetadata] = Field(..., description="Pagination metadata (None if pagination not applied)")

    class Config:
        arbitrary_types_allowed = True
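
A server-side usage sketch for the pagination models; the item list is invented, and the slicing shows how totalPages relates to pageSize via math.ceil.

import math

params = PaginationParams(page=2, pageSize=10, sort=[SortField(field="fileName", direction="asc")])
allItems = [f"item-{i}" for i in range(35)]

start = (params.page - 1) * params.pageSize
pageItems = allItems[start:start + params.pageSize]

response = PaginatedResponse[str](
    items=pageItems,
    pagination=PaginationMetadata(
        currentPage=params.page,
        pageSize=params.pageSize,
        totalItems=len(allItems),
        totalPages=math.ceil(len(allItems) / params.pageSize),  # 35 items / 10 per page -> 4 pages
        sort=params.sort,
    ),
)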
147
modules/datamodels/datamodelSecurity.py
Normal file
@@ -0,0 +1,147 @@
"""Security models: Token and AuthEvent."""

from typing import Optional
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
from modules.shared.timezoneUtils import getUtcTimestamp
from .datamodelUam import AuthAuthority
from enum import Enum
import uuid


class TokenStatus(str, Enum):
    ACTIVE = "active"
    REVOKED = "revoked"


class Token(BaseModel):
    id: Optional[str] = None
    userId: str
    authority: AuthAuthority
    connectionId: Optional[str] = Field(
        None, description="ID of the connection this token belongs to"
    )
    tokenAccess: str
    tokenType: str = "bearer"
    expiresAt: float = Field(
        description="When the token expires (UTC timestamp in seconds)"
    )
    tokenRefresh: Optional[str] = None
    createdAt: Optional[float] = Field(
        None, description="When the token was created (UTC timestamp in seconds)"
    )
    status: TokenStatus = Field(
        default=TokenStatus.ACTIVE, description="Token status: active/revoked"
    )
    revokedAt: Optional[float] = Field(
        None, description="When the token was revoked (UTC timestamp in seconds)"
    )
    revokedBy: Optional[str] = Field(
        None, description="User ID who revoked the token (admin/self)"
    )
    reason: Optional[str] = Field(None, description="Optional revocation reason")
    sessionId: Optional[str] = Field(
        None, description="Logical session grouping for logout revocation"
    )
    mandateId: Optional[str] = Field(
        None, description="Mandate ID for tenant scoping of the token"
    )

    class Config:
        use_enum_values = True


registerModelLabels(
    "Token",
    {"en": "Token", "fr": "Jeton"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "authority": {"en": "Authority", "fr": "Autorité"},
        "connectionId": {"en": "Connection ID", "fr": "ID de connexion"},
        "tokenAccess": {"en": "Access Token", "fr": "Jeton d'accès"},
        "tokenType": {"en": "Token Type", "fr": "Type de jeton"},
        "expiresAt": {"en": "Expires At", "fr": "Expire le"},
        "tokenRefresh": {"en": "Refresh Token", "fr": "Jeton de rafraîchissement"},
        "createdAt": {"en": "Created At", "fr": "Créé le"},
        "status": {"en": "Status", "fr": "Statut"},
        "revokedAt": {"en": "Revoked At", "fr": "Révoqué le"},
        "revokedBy": {"en": "Revoked By", "fr": "Révoqué par"},
        "reason": {"en": "Reason", "fr": "Raison"},
        "sessionId": {"en": "Session ID", "fr": "ID de session"},
        "mandateId": {"en": "Mandate ID", "fr": "ID de mandat"},
    },
)


class AuthEvent(BaseModel):
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the auth event",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False,
    )
    userId: str = Field(
        description="ID of the user this event belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True,
    )
    eventType: str = Field(
        description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True,
    )
    timestamp: float = Field(
        default_factory=getUtcTimestamp,
        description="Unix timestamp when the event occurred",
        frontend_type="datetime",
        frontend_readonly=True,
        frontend_required=True,
    )
    ipAddress: Optional[str] = Field(
        default=None,
        description="IP address from which the event originated",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False,
    )
    userAgent: Optional[str] = Field(
        default=None,
        description="User agent string from the request",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False,
    )
    success: bool = Field(
        default=True,
        description="Whether the authentication event was successful",
        frontend_type="boolean",
        frontend_readonly=True,
        frontend_required=True,
    )
    details: Optional[str] = Field(
        default=None,
        description="Additional details about the event",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False,
    )


registerModelLabels(
    "AuthEvent",
    {"en": "Authentication Event", "fr": "Événement d'authentification"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "eventType": {"en": "Event Type", "fr": "Type d'événement"},
        "timestamp": {"en": "Timestamp", "fr": "Horodatage"},
        "ipAddress": {"en": "IP Address", "fr": "Adresse IP"},
        "userAgent": {"en": "User Agent", "fr": "Agent utilisateur"},
        "success": {"en": "Success", "fr": "Succès"},
        "details": {"en": "Details", "fr": "Détails"},
    },
)
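
A sketch of the revocation flow these Token fields support; the identifiers are invented, and the in-place update only illustrates which fields change when a token is revoked rather than deleted.

token = Token(
    userId="user-1",
    authority=AuthAuthority.LOCAL,
    tokenAccess="opaque-access-token",
    expiresAt=getUtcTimestamp() + 3600,
)

# On logout or an admin revoke, the record is updated instead of removed
token.status = TokenStatus.REVOKED
token.revokedAt = getUtcTimestamp()
token.revokedBy = "admin-1"
token.reason = "manual revocation"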
22
modules/datamodels/datamodelTickets.py
Normal file
@@ -0,0 +1,22 @@
"""Ticket datamodels used across Jira/ClickUp connectors."""

from typing import Optional
from pydantic import BaseModel, Field
from abc import ABC, abstractmethod


class TicketFieldAttribute(BaseModel):
    fieldName: str = Field(description="Human-readable field name")
    field: str = Field(description="Ticket field ID/key")


class TicketBase(ABC):
    @abstractmethod
    async def readAttributes(self) -> list[TicketFieldAttribute]: ...

    @abstractmethod
    async def readTasks(self, *, limit: int = 0) -> list[dict]: ...

    @abstractmethod
    async def writeTasks(self, tasklist: list[dict]) -> None: ...
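
A hedged sketch of a concrete connector satisfying the abstract TicketBase contract; the in-memory storage stands in for the real Jira/ClickUp connectors.

class InMemoryTickets(TicketBase):
    def __init__(self) -> None:
        self._tasks: list[dict] = []

    async def readAttributes(self) -> list[TicketFieldAttribute]:
        return [TicketFieldAttribute(fieldName="Summary", field="summary")]

    async def readTasks(self, *, limit: int = 0) -> list[dict]:
        return self._tasks[:limit] if limit else list(self._tasks)

    async def writeTasks(self, tasklist: list[dict]) -> None:
        self._tasks.extend(tasklist)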
226
modules/datamodels/datamodelTools.py
Normal file
@@ -0,0 +1,226 @@
"""
Utility data models and classes for common tools and mappings.
"""

class CountryCodes:
    """
    Centralized country code mapping for different services.

    Maps ISO-2 country codes to service-specific country names.
    Each service may have different requirements for country names.
    """

    # Mapping: ISO-2 code -> (Tavily country name, Perplexity country name)
    _COUNTRY_MAP = {
        "AF": ("afghanistan", "Afghanistan"),
        "AL": ("albania", "Albania"),
        "DZ": ("algeria", "Algeria"),
        "AD": ("andorra", "Andorra"),
        "AO": ("angola", "Angola"),
        "AR": ("argentina", "Argentina"),
        "AM": ("armenia", "Armenia"),
        "AU": ("australia", "Australia"),
        "AT": ("austria", "Austria"),
        "AZ": ("azerbaijan", "Azerbaijan"),
        "BS": ("bahamas", "Bahamas"),
        "BH": ("bahrain", "Bahrain"),
        "BD": ("bangladesh", "Bangladesh"),
        "BB": ("barbados", "Barbados"),
        "BY": ("belarus", "Belarus"),
        "BE": ("belgium", "Belgium"),
        "BZ": ("belize", "Belize"),
        "BJ": ("benin", "Benin"),
        "BT": ("bhutan", "Bhutan"),
        "BO": ("bolivia", "Bolivia"),
        "BA": ("bosnia and herzegovina", "Bosnia and Herzegovina"),
        "BW": ("botswana", "Botswana"),
        "BR": ("brazil", "Brazil"),
        "BN": ("brunei", "Brunei"),
        "BG": ("bulgaria", "Bulgaria"),
        "BF": ("burkina faso", "Burkina Faso"),
        "BI": ("burundi", "Burundi"),
        "KH": ("cambodia", "Cambodia"),
        "CM": ("cameroon", "Cameroon"),
        "CA": ("canada", "Canada"),
        "CV": ("cape verde", "Cape Verde"),
        "CF": ("central african republic", "Central African Republic"),
        "TD": ("chad", "Chad"),
        "CL": ("chile", "Chile"),
        "CN": ("china", "China"),
        "CO": ("colombia", "Colombia"),
        "KM": ("comoros", "Comoros"),
        "CG": ("congo", "Congo"),
        "CR": ("costa rica", "Costa Rica"),
        "HR": ("croatia", "Croatia"),
        "CU": ("cuba", "Cuba"),
        "CY": ("cyprus", "Cyprus"),
        "CZ": ("czech republic", "Czech Republic"),
        "DK": ("denmark", "Denmark"),
        "DJ": ("djibouti", "Djibouti"),
        "DO": ("dominican republic", "Dominican Republic"),
        "EC": ("ecuador", "Ecuador"),
        "EG": ("egypt", "Egypt"),
        "SV": ("el salvador", "El Salvador"),
        "GQ": ("equatorial guinea", "Equatorial Guinea"),
        "ER": ("eritrea", "Eritrea"),
        "EE": ("estonia", "Estonia"),
        "ET": ("ethiopia", "Ethiopia"),
        "FJ": ("fiji", "Fiji"),
        "FI": ("finland", "Finland"),
        "FR": ("france", "France"),
        "GA": ("gabon", "Gabon"),
        "GM": ("gambia", "Gambia"),
        "GE": ("georgia", "Georgia"),
        "DE": ("germany", "Germany"),
        "GH": ("ghana", "Ghana"),
        "GR": ("greece", "Greece"),
        "GT": ("guatemala", "Guatemala"),
        "GN": ("guinea", "Guinea"),
        "HT": ("haiti", "Haiti"),
        "HN": ("honduras", "Honduras"),
        "HU": ("hungary", "Hungary"),
        "IS": ("iceland", "Iceland"),
        "IN": ("india", "India"),
        "ID": ("indonesia", "Indonesia"),
        "IR": ("iran", "Iran"),
        "IQ": ("iraq", "Iraq"),
        "IE": ("ireland", "Ireland"),
        "IL": ("israel", "Israel"),
        "IT": ("italy", "Italy"),
        "JM": ("jamaica", "Jamaica"),
        "JP": ("japan", "Japan"),
        "JO": ("jordan", "Jordan"),
        "KZ": ("kazakhstan", "Kazakhstan"),
        "KE": ("kenya", "Kenya"),
        "KW": ("kuwait", "Kuwait"),
        "KG": ("kyrgyzstan", "Kyrgyzstan"),
        "LV": ("latvia", "Latvia"),
        "LB": ("lebanon", "Lebanon"),
        "LS": ("lesotho", "Lesotho"),
        "LR": ("liberia", "Liberia"),
        "LY": ("libya", "Libya"),
        "LI": ("liechtenstein", "Liechtenstein"),
        "LT": ("lithuania", "Lithuania"),
        "LU": ("luxembourg", "Luxembourg"),
        "MG": ("madagascar", "Madagascar"),
        "MW": ("malawi", "Malawi"),
        "MY": ("malaysia", "Malaysia"),
        "MV": ("maldives", "Maldives"),
        "ML": ("mali", "Mali"),
        "MT": ("malta", "Malta"),
        "MR": ("mauritania", "Mauritania"),
        "MU": ("mauritius", "Mauritius"),
        "MX": ("mexico", "Mexico"),
        "MD": ("moldova", "Moldova"),
        "MC": ("monaco", "Monaco"),
        "MN": ("mongolia", "Mongolia"),
        "ME": ("montenegro", "Montenegro"),
        "MA": ("morocco", "Morocco"),
        "MZ": ("mozambique", "Mozambique"),
        "MM": ("myanmar", "Myanmar"),
        "NA": ("namibia", "Namibia"),
        "NP": ("nepal", "Nepal"),
        "NL": ("netherlands", "Netherlands"),
        "NZ": ("new zealand", "New Zealand"),
        "NI": ("nicaragua", "Nicaragua"),
        "NE": ("niger", "Niger"),
        "NG": ("nigeria", "Nigeria"),
        "KP": ("north korea", "North Korea"),
        "MK": ("north macedonia", "North Macedonia"),
        "NO": ("norway", "Norway"),
        "OM": ("oman", "Oman"),
        "PK": ("pakistan", "Pakistan"),
        "PA": ("panama", "Panama"),
        "PG": ("papua new guinea", "Papua New Guinea"),
        "PY": ("paraguay", "Paraguay"),
        "PE": ("peru", "Peru"),
        "PH": ("philippines", "Philippines"),
        "PL": ("poland", "Poland"),
        "PT": ("portugal", "Portugal"),
        "QA": ("qatar", "Qatar"),
        "RO": ("romania", "Romania"),
        "RU": ("russia", "Russia"),
        "RW": ("rwanda", "Rwanda"),
        "SA": ("saudi arabia", "Saudi Arabia"),
        "SN": ("senegal", "Senegal"),
        "RS": ("serbia", "Serbia"),
        "SG": ("singapore", "Singapore"),
        "SK": ("slovakia", "Slovakia"),
        "SI": ("slovenia", "Slovenia"),
        "SO": ("somalia", "Somalia"),
        "ZA": ("south africa", "South Africa"),
        "KR": ("south korea", "South Korea"),
        "SS": ("south sudan", "South Sudan"),
        "ES": ("spain", "Spain"),
        "LK": ("sri lanka", "Sri Lanka"),
        "SD": ("sudan", "Sudan"),
        "SE": ("sweden", "Sweden"),
        "CH": ("switzerland", "Switzerland"),
        "SY": ("syria", "Syria"),
        "TW": ("taiwan", "Taiwan"),
        "TJ": ("tajikistan", "Tajikistan"),
        "TZ": ("tanzania", "Tanzania"),
        "TH": ("thailand", "Thailand"),
        "TG": ("togo", "Togo"),
        "TT": ("trinidad and tobago", "Trinidad and Tobago"),
        "TN": ("tunisia", "Tunisia"),
        "TR": ("turkey", "Turkey"),
        "TM": ("turkmenistan", "Turkmenistan"),
        "UG": ("uganda", "Uganda"),
        "UA": ("ukraine", "Ukraine"),
        "AE": ("united arab emirates", "United Arab Emirates"),
        "GB": ("united kingdom", "United Kingdom"),
        "US": ("united states", "United States"),
        "UY": ("uruguay", "Uruguay"),
        "UZ": ("uzbekistan", "Uzbekistan"),
        "VE": ("venezuela", "Venezuela"),
        "VN": ("vietnam", "Vietnam"),
        "YE": ("yemen", "Yemen"),
        "ZM": ("zambia", "Zambia"),
        "ZW": ("zimbabwe", "Zimbabwe"),
    }

    @classmethod
    def getForTavily(cls, isoCode: str) -> str:
        """
        Get Tavily-compatible country name from ISO-2 code.

        Args:
            isoCode: ISO-2 country code (e.g., "CH", "ch", "US", "us")

        Returns:
            Country name in lowercase as required by Tavily (e.g., "switzerland", "united states")
        """
        # Convert to uppercase for lookup
        isoCodeUpper = isoCode.upper() if isoCode else ""
        mapping = cls._COUNTRY_MAP.get(isoCodeUpper)
        return mapping[0] if mapping else isoCode

    @classmethod
    def getForPerplexity(cls, isoCode: str) -> str:
        """
        Get Perplexity-compatible country name from ISO-2 code.

        Args:
            isoCode: ISO-2 country code (e.g., "CH", "US")

        Returns:
            Full country name as required by Perplexity (e.g., "Switzerland", "United States")
        """
        isoCodeUpper = isoCode.upper()
        mapping = cls._COUNTRY_MAP.get(isoCodeUpper)
        return mapping[1] if mapping else isoCode

    @classmethod
    def isValid(cls, isoCode: str) -> bool:
        """
        Check if ISO-2 code is valid.

        Args:
            isoCode: ISO-2 country code to check

        Returns:
            True if valid, False otherwise
        """
        return isoCode.upper() in cls._COUNTRY_MAP
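
A quick usage sketch for the mapping helpers above; behaviour for unknown codes follows the fall-through in the methods themselves.

assert CountryCodes.isValid("ch") is True
assert CountryCodes.getForTavily("CH") == "switzerland"       # Tavily expects lowercase names
assert CountryCodes.getForPerplexity("ch") == "Switzerland"   # Perplexity expects full names
assert CountryCodes.getForTavily("XX") == "XX"                # unknown codes pass through unchanged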
138
modules/datamodels/datamodelUam.py
Normal file
138
modules/datamodels/datamodelUam.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""UAM models: User, Mandate, UserConnection."""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp
|
||||
|
||||
|
||||
class AuthAuthority(str, Enum):
|
||||
LOCAL = "local"
|
||||
GOOGLE = "google"
|
||||
MSFT = "msft"
|
||||
|
||||
class UserPrivilege(str, Enum):
|
||||
SYSADMIN = "sysadmin"
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
class ConnectionStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
REVOKED = "revoked"
|
||||
PENDING = "pending"
|
||||
|
||||
class Mandate(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the mandate", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
name: str = Field(description="Name of the mandate", frontend_type="text", frontend_readonly=False, frontend_required=True)
|
||||
language: str = Field(default="en", description="Default language of the mandate", frontend_type="select", frontend_readonly=False, frontend_required=True, frontend_options=[
|
||||
{"value": "de", "label": {"en": "Deutsch", "fr": "Allemand"}},
|
||||
{"value": "en", "label": {"en": "English", "fr": "Anglais"}},
|
||||
{"value": "fr", "label": {"en": "Français", "fr": "Français"}},
|
||||
{"value": "it", "label": {"en": "Italiano", "fr": "Italien"}},
|
||||
])
|
||||
enabled: bool = Field(default=True, description="Indicates whether the mandate is enabled", frontend_type="checkbox", frontend_readonly=False, frontend_required=False)
|
||||
registerModelLabels(
|
||||
"Mandate",
|
||||
{"en": "Mandate", "fr": "Mandat"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"name": {"en": "Name", "fr": "Nom"},
|
||||
"language": {"en": "Language", "fr": "Langue"},
|
||||
"enabled": {"en": "Enabled", "fr": "Activé"},
|
||||
},
|
||||
)
|
||||
|
||||
class UserConnection(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the connection", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
userId: str = Field(description="ID of the user this connection belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
authority: AuthAuthority = Field(description="Authentication authority", frontend_type="select", frontend_readonly=True, frontend_required=False, frontend_options=[
|
||||
{"value": "local", "label": {"en": "Local", "fr": "Local"}},
|
||||
{"value": "google", "label": {"en": "Google", "fr": "Google"}},
|
||||
{"value": "msft", "label": {"en": "Microsoft", "fr": "Microsoft"}},
|
||||
])
|
||||
externalId: str = Field(description="User ID in the external system", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
externalUsername: str = Field(description="Username in the external system", frontend_type="text", frontend_readonly=False, frontend_required=False)
|
||||
externalEmail: Optional[EmailStr] = Field(None, description="Email in the external system", frontend_type="email", frontend_readonly=False, frontend_required=False)
|
||||
status: ConnectionStatus = Field(default=ConnectionStatus.ACTIVE, description="Connection status", frontend_type="select", frontend_readonly=False, frontend_required=False, frontend_options=[
|
||||
{"value": "active", "label": {"en": "Active", "fr": "Actif"}},
|
||||
{"value": "inactive", "label": {"en": "Inactive", "fr": "Inactif"}},
|
||||
{"value": "expired", "label": {"en": "Expired", "fr": "Expiré"}},
|
||||
{"value": "pending", "label": {"en": "Pending", "fr": "En attente"}},
|
||||
])
|
||||
connectedAt: float = Field(default_factory=getUtcTimestamp, description="When the connection was established (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
lastChecked: float = Field(default_factory=getUtcTimestamp, description="When the connection was last verified (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
expiresAt: Optional[float] = Field(None, description="When the connection expires (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
tokenStatus: Optional[str] = Field(None, description="Current token status: active, expired, none", frontend_type="select", frontend_readonly=True, frontend_required=False, frontend_options=[
|
||||
{"value": "active", "label": {"en": "Active", "fr": "Actif"}},
|
||||
{"value": "expired", "label": {"en": "Expired", "fr": "Expiré"}},
|
||||
{"value": "none", "label": {"en": "None", "fr": "Aucun"}},
|
||||
])
|
||||
tokenExpiresAt: Optional[float] = Field(None, description="When the current token expires (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
registerModelLabels(
|
||||
"UserConnection",
|
||||
{"en": "User Connection", "fr": "Connexion utilisateur"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"userId": {"en": "User ID", "fr": "ID utilisateur"},
|
||||
"authority": {"en": "Authority", "fr": "Autorité"},
|
||||
"externalId": {"en": "External ID", "fr": "ID externe"},
|
||||
"externalUsername": {"en": "External Username", "fr": "Nom d'utilisateur externe"},
|
||||
"externalEmail": {"en": "External Email", "fr": "Email externe"},
|
||||
"status": {"en": "Status", "fr": "Statut"},
|
||||
"connectedAt": {"en": "Connected At", "fr": "Connecté le"},
|
||||
"lastChecked": {"en": "Last Checked", "fr": "Dernière vérification"},
|
||||
"expiresAt": {"en": "Expires At", "fr": "Expire le"},
|
||||
"tokenStatus": {"en": "Connection Status", "fr": "Statut de connexion"},
|
||||
"tokenExpiresAt": {"en": "Expires At", "fr": "Expire le"},
|
||||
},
|
||||
)
|
||||
|
||||
class User(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the user", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
username: str = Field(description="Username for login", frontend_type="text", frontend_readonly=False, frontend_required=True)
|
||||
email: Optional[EmailStr] = Field(None, description="Email address of the user", frontend_type="email", frontend_readonly=False, frontend_required=True)
|
||||
fullName: Optional[str] = Field(None, description="Full name of the user", frontend_type="text", frontend_readonly=False, frontend_required=False)
|
||||
language: str = Field(default="en", description="Preferred language of the user", frontend_type="select", frontend_readonly=False, frontend_required=True, frontend_options=[
|
||||
{"value": "de", "label": {"en": "Deutsch", "fr": "Allemand"}},
|
||||
{"value": "en", "label": {"en": "English", "fr": "Anglais"}},
|
||||
{"value": "fr", "label": {"en": "Français", "fr": "Français"}},
|
||||
{"value": "it", "label": {"en": "Italiano", "fr": "Italien"}},
|
||||
])
|
||||
enabled: bool = Field(default=True, description="Indicates whether the user is enabled", frontend_type="checkbox", frontend_readonly=False, frontend_required=False)
|
||||
privilege: UserPrivilege = Field(default=UserPrivilege.USER, description="Permission level", frontend_type="select", frontend_readonly=False, frontend_required=True, frontend_options=[
|
||||
{"value": "user", "label": {"en": "User", "fr": "Utilisateur"}},
|
||||
{"value": "admin", "label": {"en": "Admin", "fr": "Administrateur"}},
|
||||
{"value": "sysadmin", "label": {"en": "SysAdmin", "fr": "Administrateur système"}},
|
||||
])
|
||||
authenticationAuthority: AuthAuthority = Field(default=AuthAuthority.LOCAL, description="Primary authentication authority", frontend_type="select", frontend_readonly=True, frontend_required=False, frontend_options=[
|
||||
{"value": "local", "label": {"en": "Local", "fr": "Local"}},
|
||||
{"value": "google", "label": {"en": "Google", "fr": "Google"}},
|
||||
{"value": "msft", "label": {"en": "Microsoft", "fr": "Microsoft"}},
|
||||
])
|
||||
mandateId: Optional[str] = Field(None, description="ID of the mandate this user belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
registerModelLabels(
|
||||
"User",
|
||||
{"en": "User", "fr": "Utilisateur"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"username": {"en": "Username", "fr": "Nom d'utilisateur"},
|
||||
"email": {"en": "Email", "fr": "Email"},
|
||||
"fullName": {"en": "Full Name", "fr": "Nom complet"},
|
||||
"language": {"en": "Language", "fr": "Langue"},
|
||||
"enabled": {"en": "Enabled", "fr": "Activé"},
|
||||
"privilege": {"en": "Privilege", "fr": "Privilège"},
|
||||
"authenticationAuthority": {"en": "Auth Authority", "fr": "Autorité d'authentification"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID de mandat"},
|
||||
},
|
||||
)
|
||||
|
||||
class UserInDB(User):
|
||||
hashedPassword: Optional[str] = Field(None, description="Hash of the user password")
|
||||
registerModelLabels(
|
||||
"UserInDB",
|
||||
{"en": "User Access", "fr": "Accès de l'utilisateur"},
|
||||
{"hashedPassword": {"en": "Password hash", "fr": "Hachage de mot de passe"}},
|
||||
)
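A minimal sketch of how the UAM models above compose; field names are taken from the definitions in this file, and all values are illustrative only.

# Illustrative only: create a mandate, a user bound to it, and a Microsoft connection for that user.
mandate = Mandate(name="Acme AG", language="de")
user = User(
    username="jdoe",
    email="jdoe@example.com",
    privilege=UserPrivilege.ADMIN,
    authenticationAuthority=AuthAuthority.LOCAL,
    mandateId=mandate.id,
)
connection = UserConnection(
    userId=user.id,
    authority=AuthAuthority.MSFT,
    externalId="external-object-id",
    externalUsername="jdoe@contoso.example",
    status=ConnectionStatus.ACTIVE,
)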
24
modules/datamodels/datamodelUtils.py
Normal file
@@ -0,0 +1,24 @@
"""Utility datamodels: Prompt."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
import uuid
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
mandateId: str = Field(description="ID of the mandate this prompt belongs to", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
content: str = Field(description="Content of the prompt", frontend_type="textarea", frontend_readonly=False, frontend_required=True)
|
||||
name: str = Field(description="Name of the prompt", frontend_type="text", frontend_readonly=False, frontend_required=True)
|
||||
registerModelLabels(
|
||||
"Prompt",
|
||||
{"en": "Prompt", "fr": "Invite"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
"name": {"en": "Name", "fr": "Nom"},
|
||||
},
|
||||
)
39
modules/datamodels/datamodelVoice.py
Normal file
@@ -0,0 +1,39 @@
"""Voice settings datamodel."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp
|
||||
import uuid
|
||||
|
||||
|
||||
class VoiceSettings(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
userId: str = Field(description="ID of the user these settings belong to", frontend_type="text", frontend_readonly=True, frontend_required=True)
|
||||
mandateId: str = Field(description="ID of the mandate these settings belong to", frontend_type="text", frontend_readonly=True, frontend_required=True)
|
||||
sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", frontend_type="select", frontend_readonly=False, frontend_required=True)
|
||||
ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", frontend_type="select", frontend_readonly=False, frontend_required=True)
|
||||
ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", frontend_type="select", frontend_readonly=False, frontend_required=True)
|
||||
translationEnabled: bool = Field(default=True, description="Whether translation is enabled", frontend_type="checkbox", frontend_readonly=False, frontend_required=False)
|
||||
targetLanguage: str = Field(default="en-US", description="Target language for translation", frontend_type="select", frontend_readonly=False, frontend_required=False)
|
||||
creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
lastModified: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were last modified (UTC timestamp in seconds)", frontend_type="timestamp", frontend_readonly=True, frontend_required=False)
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"VoiceSettings",
|
||||
{"en": "Voice Settings", "fr": "Paramètres vocaux"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"userId": {"en": "User ID", "fr": "ID utilisateur"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"sttLanguage": {"en": "STT Language", "fr": "Langue STT"},
|
||||
"ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"},
|
||||
"ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"},
|
||||
"translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"},
|
||||
"targetLanguage": {"en": "Target Language", "fr": "Langue cible"},
|
||||
"creationDate": {"en": "Creation Date", "fr": "Date de création"},
|
||||
"lastModified": {"en": "Last Modified", "fr": "Dernière modification"},
|
||||
},
|
||||
)
41
modules/features/chatPlayground/mainChatPlayground.py
Normal file
@@ -0,0 +1,41 @@
import logging
|
||||
from typing import Optional
|
||||
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelChat import ChatWorkflow, UserInputRequest, WorkflowModeEnum
|
||||
from modules.workflows.workflowManager import WorkflowManager
|
||||
from modules.services import getInterface as getServices
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None) -> ChatWorkflow:
|
||||
"""
|
||||
Starts a new chat or continues an existing one, then launches processing asynchronously.
|
||||
|
||||
Args:
|
||||
currentUser: Current user
|
||||
userInput: User input request
|
||||
workflowMode: "Actionplan" for traditional task planning, "Dynamic" for iterative dynamic-style processing, "Template" for template-based processing
workflowId: Optional workflow ID to continue an existing workflow
|
||||
|
||||
Example usage for Dynamic mode:
|
||||
workflow = await chatStart(currentUser, userInput, workflowMode=WorkflowModeEnum.WORKFLOW_DYNAMIC)
|
||||
"""
|
||||
try:
|
||||
services = getServices(currentUser, None)
|
||||
workflowManager = WorkflowManager(services)
|
||||
workflow = await workflowManager.workflowStart(userInput, workflowMode, workflowId)
|
||||
return workflow
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting chat: {str(e)}")
|
||||
raise
|
||||
|
||||
async def chatStop(currentUser: User, workflowId: str) -> ChatWorkflow:
|
||||
"""Stops a running chat."""
|
||||
try:
|
||||
services = getServices(currentUser, None)
|
||||
workflowManager = WorkflowManager(services)
|
||||
return await workflowManager.workflowStop(workflowId)
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping chat: {str(e)}")
|
||||
raise
39
modules/features/featuresLifecycle.py
Normal file
@@ -0,0 +1,39 @@
import logging
|
||||
from modules.interfaces.interfaceDbAppObjects import getRootInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def start() -> bool:
|
||||
""" Start feature triggers and background managers """
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
eventUser = rootInterface.getUserByUsername("event")
|
||||
|
||||
# Feature SyncDelta
|
||||
from modules.features.syncDelta import mainSyncDelta
|
||||
mainSyncDelta.startSyncManager(eventUser)
|
||||
|
||||
# Feature Automation Events
|
||||
if eventUser:
|
||||
try:
|
||||
from modules.interfaces.interfaceDbChatObjects import getInterface as getChatInterface
|
||||
chatInterface = getChatInterface(eventUser)
|
||||
if hasattr(chatInterface, 'syncAutomationEvents'):
|
||||
await chatInterface.syncAutomationEvents()
|
||||
logger.info("Automation events synced on startup")
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing automation events on startup: {str(e)}")
|
||||
# Don't fail startup if automation sync fails
|
||||
|
||||
# Feature ...
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
async def stop() -> bool:
|
||||
""" Stop feature triggers and background managers """
|
||||
|
||||
# Feature ...
|
||||
|
||||
return True
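A minimal sketch of how start()/stop() are expected to be wired into the application lifecycle. This assumes a FastAPI lifespan context and is not part of the change set; adapt to the actual startup code.

# Assumed wiring, for illustration only.
from contextlib import asynccontextmanager
from fastapi import FastAPI
from modules.features import featuresLifecycle

@asynccontextmanager
async def lifespan(app: FastAPI):
    await featuresLifecycle.start()   # start feature triggers and background managers
    yield
    await featuresLifecycle.stop()    # stop them on shutdown

app = FastAPI(lifespan=lifespan)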
@@ -0,0 +1,344 @@
import logging
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelNeutralizer import DataNeutralizerAttributes, DataNeutraliserConfig
|
||||
from modules.services import getInterface as getServices
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NeutralizationPlayground:
|
||||
"""Feature/UI wrapper around NeutralizationService for playground & routes."""
|
||||
|
||||
def __init__(self, currentUser: User):
|
||||
self.currentUser = currentUser
|
||||
self.services = getServices(currentUser, None)
|
||||
|
||||
def processText(self, text: str) -> Dict[str, Any]:
|
||||
return self.services.neutralization.processText(text)
|
||||
|
||||
def processFiles(self, fileIds: List[str]) -> Dict[str, Any]:
|
||||
results: List[Dict[str, Any]] = []
|
||||
errors: List[str] = []
|
||||
for fileId in fileIds:
|
||||
try:
|
||||
res = self.services.neutralization.processFile(fileId)
|
||||
results.append({
|
||||
'file_id': fileId,
|
||||
'neutralized_file_name': res.get('neutralized_file_name'),
|
||||
'attributes_count': len(res.get('attributes', []))
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {fileId}: {str(e)}")
|
||||
errors.append(f"{fileId}: {str(e)}")
|
||||
return {
|
||||
'success': len(errors) == 0,
|
||||
'total_files': len(fileIds),
|
||||
'successful_files': len(results),
|
||||
'failed_files': len(errors),
|
||||
'results': results,
|
||||
'errors': errors,
|
||||
}
|
||||
|
||||
|
||||
# Cleanup attributes
|
||||
def cleanAttributes(self, fileId: str) -> bool:
|
||||
return self.services.neutralization.deleteNeutralizationAttributes(fileId)
|
||||
|
||||
# Stats
|
||||
def getStats(self) -> Dict[str, Any]:
|
||||
try:
|
||||
allAttributes = self.services.neutralization.getAttributes()
|
||||
patternCounts: Dict[str, int] = {}
|
||||
for attr in allAttributes:
|
||||
# Handle both dict and object access patterns
|
||||
if isinstance(attr, dict):
|
||||
patternType = attr.get('patternType', 'unknown')
|
||||
fileId = attr.get('fileId')
|
||||
else:
|
||||
patternType = getattr(attr, 'patternType', 'unknown')
|
||||
fileId = getattr(attr, 'fileId', None)
|
||||
|
||||
if patternType:
|
||||
patternCounts[patternType] = patternCounts.get(patternType, 0) + 1
|
||||
|
||||
# Get unique files - handle both dict and object
|
||||
uniqueFiles = set()
|
||||
for attr in allAttributes:
|
||||
if isinstance(attr, dict):
|
||||
fileId = attr.get('fileId')
|
||||
else:
|
||||
fileId = getattr(attr, 'fileId', None)
|
||||
if fileId:
|
||||
uniqueFiles.add(fileId)
|
||||
|
||||
return {
|
||||
'total_attributes': len(allAttributes),
|
||||
'unique_files': len(uniqueFiles),
|
||||
'pattern_counts': patternCounts,
|
||||
'mandate_id': self.currentUser.mandateId if self.currentUser else None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting stats: {str(e)}")
|
||||
return {
|
||||
'total_attributes': 0,
|
||||
'unique_files': 0,
|
||||
'pattern_counts': {},
|
||||
'error': str(e),
|
||||
}
|
||||
|
||||
# Additional methods needed by the route
|
||||
def getConfig(self) -> Optional[DataNeutraliserConfig]:
|
||||
"""Get neutralization configuration"""
|
||||
return self.services.neutralization.getConfig()
|
||||
|
||||
def saveConfig(self, configData: Dict[str, Any]) -> DataNeutraliserConfig:
|
||||
"""Save neutralization configuration"""
|
||||
return self.services.neutralization.saveConfig(configData)
|
||||
|
||||
def neutralizeText(self, text: str, fileId: str = None) -> Dict[str, Any]:
|
||||
"""Neutralize text content"""
|
||||
return self.services.neutralization.processText(text)
|
||||
|
||||
def resolveText(self, text: str) -> str:
|
||||
"""Resolve UIDs in neutralized text back to original text"""
|
||||
return self.services.neutralization.resolveText(text)
|
||||
|
||||
def getAttributes(self, fileId: str = None) -> List[DataNeutralizerAttributes]:
|
||||
"""Get neutralization attributes, optionally filtered by file ID"""
|
||||
try:
|
||||
allAttributes = self.services.neutralization.getAttributes()
|
||||
if fileId:
|
||||
return [attr for attr in allAttributes if attr.fileId == fileId]
|
||||
return allAttributes
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting attributes: {str(e)}")
|
||||
return []
|
||||
|
||||
async def processSharepointFiles(self, sourcePath: str, targetPath: str) -> Dict[str, Any]:
|
||||
"""Process files from SharePoint source path and store neutralized files in target path"""
|
||||
processor = SharepointProcessor(self.currentUser, self.services)
|
||||
return await processor.processSharepointFiles(sourcePath, targetPath)
|
||||
|
||||
def batchNeutralizeFiles(self, filesData: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Process multiple files for neutralization"""
|
||||
fileIds = [fileData.get('fileId') for fileData in filesData if fileData.get('fileId')]
|
||||
return self.processFiles(fileIds)
|
||||
|
||||
def getProcessingStats(self) -> Dict[str, Any]:
|
||||
"""Get neutralization processing statistics"""
|
||||
return self.getStats()
|
||||
|
||||
def cleanupFileAttributes(self, fileId: str) -> bool:
|
||||
"""Clean up neutralization attributes for a specific file"""
|
||||
return self.cleanAttributes(fileId)
|
||||
|
||||
|
||||
# Internal SharePoint helper module separated to keep feature logic tidy
|
||||
class SharepointProcessor:
|
||||
def __init__(self, currentUser: User, services):
|
||||
self.currentUser = currentUser
|
||||
self.services = services
|
||||
|
||||
async def processSharepointFiles(self, sourcePath: str, targetPath: str) -> Dict[str, Any]:
|
||||
try:
|
||||
logger.info(f"Processing SharePoint files from {sourcePath} to {targetPath}")
|
||||
|
||||
# Get SharePoint connection
|
||||
connection = await self._getSharepointConnection(sourcePath)
|
||||
if not connection:
|
||||
return {
|
||||
'success': False,
|
||||
'message': 'No SharePoint connection found for user',
|
||||
'processed_files': 0,
|
||||
'errors': ['No SharePoint connection found'],
|
||||
}
|
||||
|
||||
# Set access token for SharePoint service
|
||||
if not self.services.sharepoint.setAccessTokenFromConnection(connection):
|
||||
return {
|
||||
'success': False,
|
||||
'message': 'Failed to set SharePoint access token',
|
||||
'processed_files': 0,
|
||||
'errors': ['Failed to set SharePoint access token'],
|
||||
}
|
||||
|
||||
return await self._processSharepointFilesAsync(sourcePath, targetPath)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SharePoint files: {str(e)}")
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'Error processing SharePoint files: {str(e)}',
|
||||
'processed_files': 0,
|
||||
'errors': [str(e)],
|
||||
}
|
||||
|
||||
async def _getSharepointConnection(self, sharepointPath: str = None):
|
||||
try:
|
||||
from modules.datamodels.datamodelUam import UserConnection
|
||||
connections = self.services.interfaceDbApp.db.getRecordset(
|
||||
UserConnection,
|
||||
recordFilter={"userId": self.services.interfaceDbApp.userId}
|
||||
)
|
||||
msftConnections = [c for c in connections if c.get('authority') == 'msft']
|
||||
if not msftConnections:
|
||||
logger.warning('No Microsoft connections found for user')
|
||||
return None
|
||||
if len(msftConnections) == 1:
|
||||
logger.info(f"Found single Microsoft connection: {msftConnections[0].get('id')}")
|
||||
return msftConnections[0]
|
||||
if sharepointPath:
|
||||
return await self._matchConnectionToPath(msftConnections, sharepointPath)
|
||||
logger.info(f"Multiple Microsoft connections found, using first one: {msftConnections[0].get('id')}")
|
||||
return msftConnections[0]
|
||||
except Exception:
|
||||
logger.error('Error getting SharePoint connection')
|
||||
return None
|
||||
|
||||
async def _matchConnectionToPath(self, connections: list, sharepointPath: str):
|
||||
try:
|
||||
if not sharepointPath or not sharepointPath.startswith('https://'):
|
||||
logger.warning(f"Invalid sharepointPath for matching: {sharepointPath}")
|
||||
return connections[0] if connections else None
|
||||
|
||||
targetDomain = urlparse(sharepointPath).netloc.lower()
|
||||
if not targetDomain:
|
||||
logger.warning(f"Could not extract domain from path: {sharepointPath}")
|
||||
return connections[0] if connections else None
|
||||
|
||||
logger.info(f"Looking for connection matching domain: {targetDomain}")
|
||||
|
||||
for connection in connections:
|
||||
try:
|
||||
if not self.services.sharepoint.setAccessTokenFromConnection(connection):
|
||||
continue
|
||||
if await self._testSharepointAccess(sharepointPath):
|
||||
logger.info(f"Found matching connection for domain {targetDomain}: {connection.get('id')}")
|
||||
return connection
|
||||
except Exception:
|
||||
continue
|
||||
logger.warning(f"No specific connection match found for {targetDomain}, using first available")
|
||||
return connections[0]
|
||||
except Exception:
|
||||
logger.error('Error matching connection to path')
|
||||
return connections[0] if connections else None
|
||||
|
||||
async def _testSharepointAccess(self, sharepointPath: str) -> bool:
|
||||
try:
|
||||
siteUrl, _ = self._parseSharepointPath(sharepointPath)
|
||||
if not siteUrl:
|
||||
return False
|
||||
siteInfo = await self.services.sharepoint.findSiteByWebUrl(siteUrl)
|
||||
return siteInfo is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _processSharepointFilesAsync(self, sourcePath: str, targetPath: str) -> Dict[str, Any]:
|
||||
try:
|
||||
sourceSite, sourceFolder = self._parseSharepointPath(sourcePath)
|
||||
targetSite, targetFolder = self._parseSharepointPath(targetPath)
|
||||
if not sourceSite or not targetSite:
|
||||
return {'success': False, 'message': 'Invalid SharePoint path format', 'processed_files': 0, 'errors': ['Invalid SharePoint path format']}
|
||||
sourceSiteInfo = await self.services.sharepoint.findSiteByWebUrl(sourceSite)
|
||||
if not sourceSiteInfo:
|
||||
return {'success': False, 'message': f'Source site not found: {sourceSite}', 'processed_files': 0, 'errors': [f'Source site not found: {sourceSite}']}
|
||||
targetSiteInfo = await self.services.sharepoint.findSiteByWebUrl(targetSite)
|
||||
if not targetSiteInfo:
|
||||
return {'success': False, 'message': f'Target site not found: {targetSite}', 'processed_files': 0, 'errors': [f'Target site not found: {targetSite}']}
|
||||
logger.info(f"Listing files in folder: {sourceFolder} for site: {sourceSiteInfo['id']}")
|
||||
files = await self.services.sharepoint.listFolderContents(sourceSiteInfo['id'], sourceFolder)
|
||||
if not files:
|
||||
logger.warning(f"No files found in folder '{sourceFolder}', trying root folder")
|
||||
files = await self.services.sharepoint.listFolderContents(sourceSiteInfo['id'], '')
|
||||
if files:
|
||||
folders = [f for f in files if f.get('type') == 'folder']
|
||||
folderNames = [f.get('name') for f in folders]
|
||||
logger.info(f"Available folders in root: {folderNames}")
|
||||
folderList = ", ".join(folderNames) if folderNames else "None"
|
||||
return {
|
||||
'success': False,
|
||||
'message': f"Folder '{sourceFolder}' not found. Available folders in root: {folderList}",
|
||||
'processed_files': 0,
|
||||
'errors': [f"Folder '{sourceFolder}' not found. Available folders: {folderList}"],
|
||||
'available_folders': folderNames,
|
||||
}
|
||||
else:
|
||||
return {'success': False, 'message': f'No files found in source folder: {sourceFolder}', 'processed_files': 0, 'errors': [f'No files found in source folder: {sourceFolder}']}
|
||||
|
||||
textFiles = [f for f in files if f.get('type') == 'file']
|
||||
processed: List[Dict[str, Any]] = []
|
||||
errors: List[str] = []
|
||||
|
||||
async def _processSingle(fileInfo: Dict[str, Any]):
|
||||
try:
|
||||
fileContent = await self.services.sharepoint.downloadFile(sourceSiteInfo['id'], fileInfo['id'])
|
||||
if not fileContent:
|
||||
return {'error': f"Failed to download file: {fileInfo['name']}"}
|
||||
try:
|
||||
textContent = fileContent.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
textContent = fileContent.decode('latin-1')
|
||||
result = self.services.neutralization.processText(textContent)
|
||||
neutralizedFilename = f"neutralized_{fileInfo['name']}"
|
||||
uploadResult = await self.services.sharepoint.uploadFile(targetSiteInfo['id'], targetFolder, neutralizedFilename, result['neutralized_text'].encode('utf-8'))
|
||||
if 'error' in uploadResult:
|
||||
return {'error': f"Failed to upload neutralized file: {neutralizedFilename} - {uploadResult['error']}"}
|
||||
return {
|
||||
'success': True,
|
||||
'original_name': fileInfo['name'],
|
||||
'neutralized_name': neutralizedFilename,
|
||||
'attributes_count': len(result.get('attributes', [])),
|
||||
}
|
||||
except Exception as e:
|
||||
return {'error': f"Error processing file {fileInfo['name']}: {str(e)}"}
|
||||
|
||||
tasks = [ _processSingle(f) for f in textFiles ]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for i, r in enumerate(results):
|
||||
if isinstance(r, Exception):
|
||||
errors.append(f"Exception processing file {textFiles[i]['name']}: {str(r)}")
|
||||
elif isinstance(r, dict) and 'error' in r:
|
||||
errors.append(r['error'])
|
||||
elif isinstance(r, dict) and r.get('success'):
|
||||
processed.append({
|
||||
'original_name': r['original_name'],
|
||||
'neutralized_name': r['neutralized_name'],
|
||||
'attributes_count': r['attributes_count'],
|
||||
})
|
||||
else:
|
||||
errors.append(f"Unknown result processing file {textFiles[i]['name']}: {r}")
|
||||
return {
|
||||
'success': len(processed) > 0,
|
||||
'message': f"Processed {len(processed)} files successfully",
|
||||
'processed_files': len(processed),
|
||||
'files': processed,
|
||||
'errors': errors,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in async SharePoint processing: {str(e)}")
|
||||
return {'success': False, 'message': f'Error in async SharePoint processing: {str(e)}', 'processed_files': 0, 'errors': [str(e)]}
|
||||
|
||||
def _parseSharepointPath(self, path: str) -> tuple[str | None, str | None]:
|
||||
try:
|
||||
if not path.startswith('https://'):
|
||||
return None, None
|
||||
if '?' in path:
|
||||
path = path.split('?')[0]
|
||||
if '/sites/' not in path:
|
||||
return None, None
|
||||
parts = path.split('/sites/', 1)
|
||||
if len(parts) != 2:
|
||||
return None, None
|
||||
domain = parts[0].replace('https://', '')
|
||||
siteName = parts[1].split('/')[0]
|
||||
siteUrl = f"https://{domain}/sites/{siteName}"
|
||||
folderParts = parts[1].split('/')[1:]
|
||||
folderPath = unquote('/'.join(folderParts) if folderParts else '')
|
||||
return siteUrl, folderPath
|
||||
except Exception:
|
||||
logger.error(f"Error parsing SharePoint path '{path}'")
|
||||
return None, None
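A worked example of _parseSharepointPath above, using an illustrative URL that is not from the change set.

# _parseSharepointPath("https://contoso.sharepoint.com/sites/TeamSite/Shared%20Documents/Reports?web=1")
# -> ("https://contoso.sharepoint.com/sites/TeamSite", "Shared Documents/Reports")
# Non-https inputs or URLs without a "/sites/" segment return (None, None).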
830
modules/features/syncDelta/mainSyncDelta.py
Normal file
@@ -0,0 +1,830 @@
"""
|
||||
Delta Group Sync Manager
|
||||
|
||||
This module handles the synchronization of tickets to SharePoint using the new
|
||||
Graph API-based connector architecture.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import io
|
||||
import pandas as pd
|
||||
import csv as csv_module
|
||||
from io import StringIO, BytesIO
|
||||
from datetime import datetime, UTC
|
||||
from modules.services import getInterface as getServices
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ManagerSyncDelta:
|
||||
"""Manages Tickets to SharePoint synchronization for Delta Group.
|
||||
|
||||
Supports two sync modes:
|
||||
- CSV mode: Uses CSV files for synchronization
- Excel mode: Uses Excel (.xlsx) files for synchronization (current default; see SYNC_MODE)
|
||||
|
||||
To change sync mode, use the setSyncMode() method or modify SYNC_MODE class variable.
|
||||
"""
|
||||
|
||||
SHAREPOINT_SITE_NAME = "SteeringBPM"
|
||||
SHAREPOINT_SITE_PATH = "SteeringBPM"
|
||||
SHAREPOINT_HOSTNAME = "deltasecurityag.sharepoint.com"
|
||||
SHAREPOINT_MAIN_FOLDER = "/General/50 Docs hosted by SELISE"
|
||||
SHAREPOINT_BACKUP_FOLDER = "/General/50 Docs hosted by SELISE/SyncHistory"
|
||||
SHAREPOINT_AUDIT_FOLDER = "/General/50 Docs hosted by SELISE/SyncHistory"
|
||||
SHAREPOINT_USER_ID = "patrick.motsch@delta.ch"
|
||||
|
||||
SYNC_MODE = "xlsx" # Can be "csv" or "xlsx"
|
||||
# File names for different sync modes
|
||||
SYNC_FILE_CSV = "DELTAgroup x SELISE Ticket Exchange List.csv"
|
||||
SYNC_FILE_XLSX = "DELTAgroup x SELISE Ticket Exchange List.xlsx"
|
||||
|
||||
# Tickets connection parameters
|
||||
JIRA_USERNAME = "p.motsch@valueon.ch"
|
||||
JIRA_API_TOKEN = "" # Will be set in __init__
|
||||
JIRA_URL = "https://deltasecurity.atlassian.net"
|
||||
JIRA_PROJECT_CODE = "DCS"
|
||||
JIRA_ISSUE_TYPE = "Task"
|
||||
|
||||
# Task sync definition for field mapping
|
||||
|
||||
TASK_SYNC_DEFINITION={
|
||||
#key=excel-header, [get:ticket>excel | put: excel>ticket, tickets-xml-field-list]
|
||||
'ID': ['get', ['key']],
|
||||
'Module Category': ['get', ['fields', 'customfield_10058', 'value']],
|
||||
'Summary': ['get', ['fields', 'summary']],
|
||||
'Description': ['get', ['fields', 'description']], # ADF format - needs conversion to text
|
||||
'References': ['get', ['fields', 'customfield_10066']], # Field exists, may be None
|
||||
'Priority': ['get', ['fields', 'priority', 'name']],
|
||||
'Issue Status': ['get', ['fields', 'status', 'name']],
|
||||
'Assignee': ['get', ['fields', 'assignee', 'displayName']],
|
||||
'Issue Created': ['get', ['fields', 'created']],
|
||||
'Due Date': ['get', ['fields', 'duedate']], # Field exists, may be None
|
||||
'DELTA Comments': ['get', ['fields', 'customfield_10167']], # Field exists, may be None
|
||||
'SELISE Ticket References': ['put', ['fields', 'customfield_10067']],
|
||||
'SELISE Status Values': ['put', ['fields', 'customfield_10065']],
|
||||
'SELISE Comments': ['put', ['fields', 'customfield_10168']],
|
||||
}
|
||||
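# Sketch (assumption): the ticket connector presumably resolves each field path by walking
# nested dicts of the Jira issue payload, roughly as outlined below; the real lookup lives in
# the connector created via services.ticket.connectTicket(), not in this file.
#
#     def _resolveFieldPath(issue: dict, path: list) -> object:
#         value = issue
#         for key in path:                  # e.g. ['fields', 'status', 'name']
#             if not isinstance(value, dict):
#                 return None
#             value = value.get(key)
#         return value                      # -> issue['fields']['status']['name']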
|
||||
def __init__(self, eventUser=None):
|
||||
self.targetSite = None
|
||||
self.services = None
|
||||
self.sharepointConnection = None
|
||||
self.eventUser = eventUser
|
||||
self.sync_audit_log = [] # Store audit log entries in memory
|
||||
|
||||
try:
|
||||
if not eventUser:
|
||||
logger.error("Event user not found - SharePoint connection required")
|
||||
self._logAuditEvent("SYNC_INIT", "FAILED", "Event user not found")
|
||||
else:
|
||||
self.services = getServices(eventUser, None)
|
||||
# Read config values using services
|
||||
self.APP_ENV_TYPE = self.services.utils.configGet("APP_ENV_TYPE", "dev")
|
||||
self.JIRA_API_TOKEN = self.services.utils.configGet("Feature_SyncDelta_JIRA_DELTA_TOKEN_SECRET", "")
|
||||
# Resolve SharePoint connection for the configured user id
|
||||
self.sharepointConnection = self.services.chat.getUserConnectionByExternalUsername("msft", self.SHAREPOINT_USER_ID)
|
||||
if not self.sharepointConnection:
|
||||
logger.error(
|
||||
f"No SharePoint connection found for user: {self.SHAREPOINT_USER_ID}"
|
||||
)
|
||||
self._logAuditEvent("SYNC_INIT", "FAILED", f"No SharePoint connection for user: {self.SHAREPOINT_USER_ID}")
|
||||
else:
|
||||
# Configure SharePoint service token and set connector reference
|
||||
if not self.services.sharepoint.setAccessTokenFromConnection(
|
||||
self.sharepointConnection
|
||||
):
|
||||
logger.error("Failed to set SharePoint token from UserConnection")
|
||||
self._logAuditEvent("SYNC_INIT", "FAILED", "Failed to set SharePoint token")
|
||||
else:
|
||||
logger.info(
|
||||
f"SharePoint token configured for connection: {self.sharepointConnection.id}"
|
||||
)
|
||||
self._logAuditEvent("SYNC_INIT", "SUCCESS", f"SharePoint token configured for connection: {self.sharepointConnection.id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Initialization error in ManagerSyncDelta.__init__: {e}")
|
||||
self._logAuditEvent("SYNC_INIT", "ERROR", f"Initialization error: {str(e)}")
|
||||
|
||||
def _logAuditEvent(self, action: str, status: str, details: str):
|
||||
"""Log audit events for sync operations to memory."""
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(self.services.utils.timestampGetUtc(), UTC).strftime("%Y-%m-%d %H:%M:%S UTC")
|
||||
userId = str(self.eventUser.id) if self.eventUser else "system"
|
||||
logEntry = f"{timestamp} | {userId} | {action} | {status} | {details}"
|
||||
self.sync_audit_log.append(logEntry)
|
||||
logger.info(f"Sync Audit: {logEntry}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to log audit event: {str(e)}")
|
||||
|
||||
def _logSyncChanges(self, mergeDetails: dict, syncMode: str):
|
||||
"""Log detailed field changes for sync operations."""
|
||||
try:
|
||||
# Log summary statistics
|
||||
summary = f"Sync {syncMode} - Updated: {mergeDetails['updated']}, Added: {mergeDetails['added']}, Unchanged: {mergeDetails['unchanged']}"
|
||||
self._logAuditEvent("SYNC_CHANGES_SUMMARY", "INFO", summary)
|
||||
|
||||
# Log individual field changes (limit to first 10 to avoid spam)
|
||||
for change in mergeDetails['changes'][:10]:
|
||||
# Truncate very long changes to avoid logging issues
|
||||
if len(change) > 500:
|
||||
change = change[:500] + "... [truncated]"
|
||||
self._logAuditEvent("SYNC_FIELD_CHANGE", "INFO", f"{syncMode}: {change}")
|
||||
|
||||
# Log count if there were more changes
|
||||
if len(mergeDetails['changes']) > 10:
|
||||
self._logAuditEvent("SYNC_FIELD_CHANGE", "INFO", f"{syncMode}: ... and {len(mergeDetails['changes']) - 10} more changes")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to log sync changes: {str(e)}")
|
||||
|
||||
async def _saveAuditLogToSharepoint(self):
|
||||
"""Save the sync audit log to SharePoint."""
|
||||
try:
|
||||
if not self.sync_audit_log or not self.targetSite:
|
||||
return False
|
||||
|
||||
# Generate log filename with current timestamp
|
||||
timestamp = datetime.fromtimestamp(self.services.utils.timestampGetUtc(), UTC).strftime("%Y%m%d_%H%M%S")
|
||||
log_filename = f"log_{timestamp}.log"
|
||||
|
||||
# Create log content
|
||||
log_content = "\n".join(self.sync_audit_log)
|
||||
log_bytes = log_content.encode('utf-8')
|
||||
|
||||
# Upload to SharePoint audit folder
|
||||
await self.services.sharepoint.uploadFile(
|
||||
siteId=self.targetSite['id'],
|
||||
folderPath=self.SHAREPOINT_AUDIT_FOLDER,
|
||||
fileName=log_filename,
|
||||
content=log_bytes
|
||||
)
|
||||
|
||||
logger.info(f"Sync audit log saved to SharePoint: {log_filename}")
|
||||
self._logAuditEvent("AUDIT_LOG_SAVE", "SUCCESS", f"Audit log saved to SharePoint: {log_filename}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save audit log to SharePoint: {str(e)}")
|
||||
self._logAuditEvent("AUDIT_LOG_SAVE", "FAILED", f"Failed to save audit log: {str(e)}")
|
||||
return False
|
||||
|
||||
def getSyncFileName(self) -> str:
|
||||
"""Get the appropriate sync file name based on the sync mode."""
|
||||
if self.SYNC_MODE == "xlsx":
|
||||
return self.SYNC_FILE_XLSX
|
||||
else: # Default to CSV
|
||||
return self.SYNC_FILE_CSV
|
||||
|
||||
def setSyncMode(self, mode: str) -> bool:
|
||||
"""Set the sync mode to either 'csv' or 'xlsx'.
|
||||
|
||||
Args:
|
||||
mode: Either 'csv' or 'xlsx'
|
||||
|
||||
Returns:
|
||||
bool: True if mode was set successfully, False if invalid mode
|
||||
"""
|
||||
if mode.lower() in ["csv", "xlsx"]:
|
||||
self.SYNC_MODE = mode.lower()
|
||||
logger.info(f"Sync mode changed to: {self.SYNC_MODE}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Invalid sync mode: {mode}. Must be 'csv' or 'xlsx'")
|
||||
return False
|
||||
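# Example: ManagerSyncDelta(eventUser).setSyncMode("csv") switches the exchange file to the
# CSV variant; invalid values are rejected and the current mode is kept.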
|
||||
async def initializeInterface(self) -> bool:
|
||||
"""Initialize SharePoint connector; tickets connector is created by interface on demand."""
|
||||
try:
|
||||
# Validate init-prepared members
|
||||
if not self.services or not self.sharepointConnection or not self.services.sharepoint:
|
||||
logger.error("Service or SharePoint connection not initialized")
|
||||
return False
|
||||
|
||||
# Resolve the site by hostname + site path to get the real site ID
|
||||
logger.info(
|
||||
f"Resolving site ID via hostname+path: {self.SHAREPOINT_HOSTNAME}:/sites/{self.SHAREPOINT_SITE_PATH}"
|
||||
)
|
||||
resolved = await self.services.sharepoint.findSiteByUrl(
|
||||
hostname=self.SHAREPOINT_HOSTNAME,
|
||||
sitePath=self.SHAREPOINT_SITE_PATH
|
||||
)
|
||||
|
||||
if not resolved:
|
||||
logger.error(
|
||||
f"Failed to resolve site. Hostname: {self.SHAREPOINT_HOSTNAME}, Path: {self.SHAREPOINT_SITE_PATH}"
|
||||
)
|
||||
return False
|
||||
|
||||
self.targetSite = {
|
||||
"id": resolved.get("id"),
|
||||
"displayName": resolved.get("displayName", self.SHAREPOINT_SITE_NAME),
|
||||
"name": resolved.get("name", self.SHAREPOINT_SITE_NAME)
|
||||
}
|
||||
|
||||
# Test site access by listing root of the drive
|
||||
logger.info("Testing site access using resolved site ID...")
|
||||
test_result = await self.services.sharepoint.listFolderContents(
|
||||
siteId=self.targetSite["id"],
|
||||
folderPath=""
|
||||
)
|
||||
|
||||
if test_result is not None:
|
||||
logger.info(
|
||||
f"Site access confirmed: {self.targetSite['displayName']} (ID: {self.targetSite['id']})"
|
||||
)
|
||||
else:
|
||||
logger.error("Could not access site drive - check permissions")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing connectors: {str(e)}")
|
||||
return False
|
||||
|
||||
async def syncTicketsOverSharepoint(self) -> bool:
|
||||
"""Perform Tickets to SharePoint synchronization using list-based interface and local CSV/XLSX handling."""
|
||||
try:
|
||||
logger.info(f"Starting JIRA to SharePoint synchronization (Mode: {self.SYNC_MODE})")
|
||||
self._logAuditEvent("SYNC_START", "INFO", f"Starting JIRA to SharePoint sync (Mode: {self.SYNC_MODE})")
|
||||
|
||||
# Initialize interface
|
||||
if not await self.initializeInterface():
|
||||
logger.error("Failed to initialize connectors")
|
||||
self._logAuditEvent("SYNC_INTERFACE", "FAILED", "Failed to initialize connectors")
|
||||
return False
|
||||
|
||||
# Dump current Jira fields to text file for reference
|
||||
try:
|
||||
pass # await dump_jira_fields_to_file()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to dump JIRA fields (non-blocking): {str(e)}")
|
||||
|
||||
# Dump actual JIRA data for debugging
|
||||
try:
|
||||
pass # await dump_jira_data_to_file()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to dump JIRA data (non-blocking): {str(e)}")
|
||||
|
||||
# Get the appropriate sync file name based on mode
|
||||
sync_file_name = self.getSyncFileName()
|
||||
logger.info(f"Using sync file: {sync_file_name}")
|
||||
|
||||
# Create list-based ticket interface (initialize connector by type)
|
||||
sync_interface = await self.services.ticket.connectTicket(
|
||||
taskSyncDefinition=self.TASK_SYNC_DEFINITION,
|
||||
connectorType="Jira",
|
||||
connectorParams={
|
||||
"apiUsername": self.JIRA_USERNAME,
|
||||
"apiToken": self.JIRA_API_TOKEN,
|
||||
"apiUrl": self.JIRA_URL,
|
||||
"projectCode": self.JIRA_PROJECT_CODE,
|
||||
"ticketType": self.JIRA_ISSUE_TYPE,
|
||||
},
|
||||
)
|
||||
|
||||
# Perform the sophisticated sync based on mode
|
||||
if self.SYNC_MODE == "xlsx":
|
||||
# Export tickets to list
|
||||
data_list = await sync_interface.exportTicketsAsList()
|
||||
self._logAuditEvent("SYNC_EXPORT", "INFO", f"Exported {len(data_list)} tickets from JIRA")
|
||||
# Read existing Excel headers/content
|
||||
existing_data = []
|
||||
existing_headers = {"header1": "Header 1", "header2": "Header 2"}
|
||||
try:
|
||||
file_path = f"{self.SHAREPOINT_MAIN_FOLDER}/{sync_file_name}"
|
||||
excel_content = await self.services.sharepoint.downloadFileByPath(
|
||||
siteId=self.targetSite['id'], filePath=file_path
|
||||
)
|
||||
existing_data, existing_headers = self.parseExcelContent(excel_content)
|
||||
except Exception:
|
||||
pass
|
||||
# Merge and write
|
||||
merged_data, merge_details = self.mergeJiraWithExistingDetailed(data_list, existing_data)
|
||||
|
||||
# Log detailed changes for Excel mode
|
||||
self._logSyncChanges(merge_details, "EXCEL")
|
||||
|
||||
await self.backupSharepointFile(filename=sync_file_name)
|
||||
excel_bytes = self.createExcelContent(merged_data, existing_headers)
|
||||
await self.services.sharepoint.uploadFile(
|
||||
siteId=self.targetSite['id'],
|
||||
folderPath=self.SHAREPOINT_MAIN_FOLDER,
|
||||
fileName=sync_file_name,
|
||||
content=excel_bytes,
|
||||
)
|
||||
# Import back to tickets
|
||||
try:
|
||||
excel_content = await self.services.sharepoint.downloadFileByPath(
|
||||
siteId=self.targetSite['id'], filePath=file_path
|
||||
)
|
||||
excel_rows, _ = self.parseExcelContent(excel_content)
|
||||
self._logAuditEvent("SYNC_IMPORT", "INFO", f"Importing {len(excel_rows)} Excel rows back to tickets")
|
||||
except Exception as e:
|
||||
excel_rows = []
|
||||
self._logAuditEvent("SYNC_IMPORT", "WARNING", f"Failed to download Excel for import: {str(e)}")
|
||||
await sync_interface.importListToTickets(excel_rows)
|
||||
else: # CSV mode (default)
|
||||
# Export tickets to list
|
||||
data_list = await sync_interface.exportTicketsAsList()
|
||||
self._logAuditEvent("SYNC_EXPORT", "INFO", f"Exported {len(data_list)} tickets from JIRA")
|
||||
# Prepare headers by reading existing CSV if present
|
||||
existing_headers = {"header1": "Header 1", "header2": "Header 2"}
|
||||
existing_data: list[dict] = []
|
||||
try:
|
||||
file_path = f"{self.SHAREPOINT_MAIN_FOLDER}/{sync_file_name}"
|
||||
csv_content = await self.services.sharepoint.downloadFileByPath(
|
||||
siteId=self.targetSite['id'], filePath=file_path
|
||||
)
|
||||
csv_lines = csv_content.decode('utf-8').split('\n')
|
||||
if len(csv_lines) >= 2:
|
||||
existing_headers["header1"] = csv_lines[0].rstrip('\r\n')
|
||||
existing_headers["header2"] = csv_lines[1].rstrip('\r\n')
|
||||
# Parse existing CSV rows after the two header lines
|
||||
df_existing = pd.read_csv(io.BytesIO(csv_content), skiprows=2, quoting=1, escapechar='\\', on_bad_lines='skip', engine='python')
|
||||
existing_data = df_existing.to_dict('records')
|
||||
except Exception:
|
||||
pass
|
||||
await self.backupSharepointFile(filename=sync_file_name)
|
||||
merged_data, _ = self.mergeJiraWithExistingDetailed(data_list, existing_data)
|
||||
csv_bytes = self.createCsvContent(merged_data, existing_headers)
|
||||
await self.services.sharepoint.uploadFile(
|
||||
siteId=self.targetSite['id'],
|
||||
folderPath=self.SHAREPOINT_MAIN_FOLDER,
|
||||
fileName=sync_file_name,
|
||||
content=csv_bytes,
|
||||
)
|
||||
# Import from CSV
|
||||
try:
|
||||
csv_content = await self.services.sharepoint.downloadFileByPath(
|
||||
siteId=self.targetSite['id'], filePath=file_path
|
||||
)
|
||||
df = pd.read_csv(io.BytesIO(csv_content), skiprows=2, quoting=1, escapechar='\\', on_bad_lines='skip', engine='python')
|
||||
csv_rows = df.to_dict('records')
|
||||
self._logAuditEvent("SYNC_IMPORT", "INFO", f"Importing {len(csv_rows)} CSV rows back to tickets")
|
||||
except Exception as e:
|
||||
csv_rows = []
|
||||
self._logAuditEvent("SYNC_IMPORT", "WARNING", f"Failed to download CSV for import: {str(e)}")
|
||||
await sync_interface.importListToTickets(csv_rows)
|
||||
|
||||
logger.info(f"JIRA to SharePoint synchronization completed successfully (Mode: {self.SYNC_MODE})")
|
||||
self._logAuditEvent("SYNC_COMPLETE", "SUCCESS", f"JIRA to SharePoint sync completed successfully (Mode: {self.SYNC_MODE})")
|
||||
|
||||
# Save audit log to SharePoint
|
||||
await self._saveAuditLogToSharepoint()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during JIRA to SharePoint synchronization: {str(e)}")
|
||||
self._logAuditEvent("SYNC_ERROR", "FAILED", f"Error during sync: {str(e)}")
|
||||
|
||||
# Save audit log to SharePoint even on error
|
||||
await self._saveAuditLogToSharepoint()
|
||||
|
||||
return False
|
||||
|
||||
async def backupSharepointFile(self, *, filename: str) -> bool:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(self.services.utils.timestampGetUtc(), UTC).strftime("%Y%m%d_%H%M%S")
|
||||
backup_filename = f"backup_{timestamp}_{filename}"
|
||||
await self.services.sharepoint.copyFileAsync(
|
||||
siteId=self.targetSite['id'],
|
||||
sourceFolder=self.SHAREPOINT_MAIN_FOLDER,
|
||||
sourceFile=filename,
|
||||
destFolder=self.SHAREPOINT_BACKUP_FOLDER,
|
||||
destFile=backup_filename,
|
||||
)
|
||||
self._logAuditEvent("SYNC_BACKUP", "SUCCESS", f"Backed up file: {filename} -> {backup_filename}")
|
||||
return True
|
||||
except Exception as e:
|
||||
if "itemNotFound" in str(e) or "404" in str(e):
|
||||
self._logAuditEvent("SYNC_BACKUP", "SKIPPED", f"File not found for backup: {filename}")
|
||||
return True
|
||||
logger.warning(f"Backup failed: {e}")
|
||||
self._logAuditEvent("SYNC_BACKUP", "FAILED", f"Backup failed for {filename}: {str(e)}")
|
||||
return False
|
||||
|
||||
def mergeJiraWithExistingDetailed(self, jira_data: list[dict], existing_data: list[dict]) -> tuple[list[dict], dict]:
|
||||
existing_lookup = {row.get("ID"): row for row in existing_data if row.get("ID")}
|
||||
merged_data: list[dict] = []
|
||||
changes: list[str] = []
|
||||
updated_count = added_count = unchanged_count = 0
|
||||
for jira_row in jira_data:
|
||||
jira_id = jira_row.get("ID")
|
||||
if jira_id and jira_id in existing_lookup:
|
||||
existing_row = existing_lookup[jira_id].copy()
|
||||
row_changes: list[str] = []
|
||||
for field_name, field_config in self.TASK_SYNC_DEFINITION.items():
|
||||
if field_config[0] == 'get':
|
||||
old_value = "" if existing_row.get(field_name) is None else str(existing_row.get(field_name))
|
||||
new_value = "" if jira_row.get(field_name) is None else str(jira_row.get(field_name))
|
||||
|
||||
# Convert ADF data to readable text for logging
|
||||
if isinstance(new_value, dict) and new_value.get("type") == "doc":
|
||||
new_value_readable = self.convertAdfToText(new_value)
|
||||
if old_value != new_value_readable:
|
||||
row_changes.append(f"{field_name}: '{old_value[:100]}...' -> '{new_value_readable[:100]}...'")
|
||||
elif old_value != new_value:
|
||||
# Truncate long values for logging
|
||||
old_truncated = old_value[:100] + "..." if len(old_value) > 100 else old_value
|
||||
new_truncated = new_value[:100] + "..." if len(new_value) > 100 else new_value
|
||||
row_changes.append(f"{field_name}: '{old_truncated}' -> '{new_truncated}'")
|
||||
|
||||
existing_row[field_name] = jira_row.get(field_name)
|
||||
merged_data.append(existing_row)
|
||||
if row_changes:
|
||||
updated_count += 1
|
||||
changes.append(f"Row ID {jira_id} updated: {', '.join(row_changes)}")
|
||||
else:
|
||||
unchanged_count += 1
|
||||
del existing_lookup[jira_id]
|
||||
else:
|
||||
merged_data.append(jira_row)
|
||||
added_count += 1
|
||||
changes.append(f"Row ID {jira_id} added as new record")
|
||||
for remaining in existing_lookup.values():
|
||||
merged_data.append(remaining)
|
||||
unchanged_count += 1
|
||||
details = {"updated": updated_count, "added": added_count, "unchanged": unchanged_count, "changes": changes}
|
||||
return merged_data, details
|
||||
|
||||
def createCsvContent(self, data: list[dict], existing_headers: dict | None = None) -> bytes:
|
||||
timestamp = datetime.fromtimestamp(self.services.utils.timestampGetUtc(), UTC).strftime("%Y-%m-%d %H:%M:%S UTC")
|
||||
if existing_headers is None:
|
||||
existing_headers = {"header1": "Header 1", "header2": "Header 2"}
|
||||
if not data:
|
||||
cols = list(self.TASK_SYNC_DEFINITION.keys())
|
||||
df = pd.DataFrame(columns=cols)
|
||||
else:
|
||||
df = pd.DataFrame(data)
|
||||
for column in df.columns:
|
||||
df[column] = df[column].astype("object").fillna("")
|
||||
df[column] = df[column].astype(str).str.replace('\n', '\\n', regex=False).str.replace('"', '""', regex=False)
|
||||
header1_row = next(csv_module.reader([existing_headers.get("header1", "Header 1")]), [])
|
||||
header2_row = next(csv_module.reader([existing_headers.get("header2", "Header 2")]), [])
|
||||
if len(header2_row) > 1:
|
||||
header2_row[1] = timestamp
|
||||
header_row1 = pd.DataFrame([header1_row + [""] * (len(df.columns) - len(header1_row))], columns=df.columns)
|
||||
header_row2 = pd.DataFrame([header2_row + [""] * (len(df.columns) - len(header2_row))], columns=df.columns)
|
||||
table_headers = pd.DataFrame([df.columns.tolist()], columns=df.columns)
|
||||
final_df = pd.concat([header_row1, header_row2, table_headers, df], ignore_index=True)
|
||||
out = StringIO()
|
||||
final_df.to_csv(out, index=False, header=False, quoting=1, escapechar='\\')
|
||||
return out.getvalue().encode('utf-8')
|
||||
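# File layout produced by createCsvContent / createExcelContent (derived from the code above):
#     row 1: preserved "header1" metadata line
#     row 2: preserved "header2" metadata line, with its second cell replaced by the sync timestamp
#     row 3: column names of the merged ticket data
#     row 4+: one row per ticket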
|
||||
def createExcelContent(self, data: list[dict], existing_headers: dict | None = None) -> bytes:
|
||||
timestamp = datetime.fromtimestamp(self.services.utils.timestampGetUtc(), UTC).strftime("%Y-%m-%d %H:%M:%S UTC")
|
||||
if existing_headers is None:
|
||||
existing_headers = {"header1": "Header 1", "header2": "Header 2"}
|
||||
if not data:
|
||||
cols = list(self.TASK_SYNC_DEFINITION.keys())
|
||||
df = pd.DataFrame(columns=cols)
|
||||
else:
|
||||
df = pd.DataFrame(data)
|
||||
for column in df.columns:
|
||||
df[column] = df[column].astype("object").fillna("")
|
||||
df[column] = df[column].astype(str).str.replace('\n', '\\n', regex=False).str.replace('"', '""', regex=False)
|
||||
header1_row = next(csv_module.reader([existing_headers.get("header1", "Header 1")]), [])
|
||||
header2_row = next(csv_module.reader([existing_headers.get("header2", "Header 2")]), [])
|
||||
if len(header2_row) > 1:
|
||||
header2_row[1] = timestamp
|
||||
header_row1 = pd.DataFrame([header1_row + [""] * (len(df.columns) - len(header1_row))], columns=df.columns)
|
||||
header_row2 = pd.DataFrame([header2_row + [""] * (len(df.columns) - len(header2_row))], columns=df.columns)
|
||||
table_headers = pd.DataFrame([df.columns.tolist()], columns=df.columns)
|
||||
final_df = pd.concat([header_row1, header_row2, table_headers, df], ignore_index=True)
|
||||
buf = BytesIO()
|
||||
final_df.to_excel(buf, index=False, header=False, engine='openpyxl')
|
||||
return buf.getvalue()
|
||||
|
||||
def parseExcelContent(self, excel_content: bytes) -> tuple[list[dict], dict]:
|
||||
df = pd.read_excel(BytesIO(excel_content), engine='openpyxl', header=None)
|
||||
header_row1 = df.iloc[0:1].copy()
|
||||
header_row2 = df.iloc[1:2].copy()
|
||||
table_headers = df.iloc[2:3].copy()
|
||||
df_data = df.iloc[3:].copy()
|
||||
df_data.columns = table_headers.iloc[0]
|
||||
df_data = df_data.reset_index(drop=True)
|
||||
for column in df_data.columns:
|
||||
df_data[column] = df_data[column].astype('object').fillna('')
|
||||
data = df_data.to_dict(orient='records')
|
||||
headers = {
|
||||
"header1": ",".join([str(x) if pd.notna(x) else "" for x in header_row1.iloc[0].tolist()]),
|
||||
"header2": ",".join([str(x) if pd.notna(x) else "" for x in header_row2.iloc[0].tolist()]),
|
||||
}
|
||||
return data, headers
|
||||
|
||||
def convertAdfToText(self, adf_data):
    """Convert Atlassian Document Format (ADF) to plain text.

    Based on Atlassian Document Format specification for JIRA fields.
    Handles paragraphs, lists, text formatting, and other ADF node types.

    Args:
        adf_data: ADF object or None

    Returns:
        str: Plain text content, or empty string if None/invalid
    """
    if not adf_data or not isinstance(adf_data, dict):
        return ""

    if adf_data.get("type") != "doc":
        return str(adf_data) if adf_data else ""

    content = adf_data.get("content", [])
    if not isinstance(content, list):
        return ""

    def extractTextFromContent(contentList, listLevel=0):
        """Recursively extract text from ADF content with proper formatting."""
        textParts = []
        listCounter = 1

        for item in contentList:
            if not isinstance(item, dict):
                continue

            itemType = item.get("type", "")

            if itemType == "text":
                # Extract text content, preserving formatting
                text = item.get("text", "")
                marks = item.get("marks", [])

                # Handle text formatting (bold, italic, etc.)
                if marks:
                    for mark in marks:
                        if mark.get("type") == "strong":
                            text = f"**{text}**"
                        elif mark.get("type") == "em":
                            text = f"*{text}*"
                        elif mark.get("type") == "code":
                            text = f"`{text}`"
                        elif mark.get("type") == "link":
                            attrs = mark.get("attrs", {})
                            href = attrs.get("href", "")
                            if href:
                                text = f"[{text}]({href})"

                textParts.append(text)

            elif itemType == "hardBreak":
                textParts.append("\n")

            elif itemType == "paragraph":
                paragraphContent = item.get("content", [])
                if paragraphContent:
                    paragraphText = extractTextFromContent(paragraphContent, listLevel)
                    if paragraphText.strip():
                        textParts.append(paragraphText)

            elif itemType == "bulletList":
                listContent = item.get("content", [])
                for listItem in listContent:
                    if listItem.get("type") == "listItem":
                        listItemContent = listItem.get("content", [])
                        for listParagraph in listItemContent:
                            if listParagraph.get("type") == "paragraph":
                                listParagraphContent = listParagraph.get("content", [])
                                if listParagraphContent:
                                    indent = " " * listLevel
                                    bulletText = extractTextFromContent(listParagraphContent, listLevel + 1)
                                    if bulletText.strip():
                                        textParts.append(f"{indent}• {bulletText}")

            elif itemType == "orderedList":
                listContent = item.get("content", [])
                for listItem in listContent:
                    if listItem.get("type") == "listItem":
                        listItemContent = listItem.get("content", [])
                        for listParagraph in listItemContent:
                            if listParagraph.get("type") == "paragraph":
                                listParagraphContent = listParagraph.get("content", [])
                                if listParagraphContent:
                                    indent = " " * listLevel
                                    orderedText = extractTextFromContent(listParagraphContent, listLevel + 1)
                                    if orderedText.strip():
                                        textParts.append(f"{indent}{listCounter}. {orderedText}")
                                        listCounter += 1

            elif itemType == "listItem":
                # Handle nested list items
                listItemContent = item.get("content", [])
                if listItemContent:
                    textParts.append(extractTextFromContent(listItemContent, listLevel))

            elif itemType == "embedCard":
                # Handle embedded content (videos, etc.)
                attrs = item.get("attrs", {})
                url = attrs.get("url", "")
                if url:
                    textParts.append(f"[Embedded Content: {url}]")

            elif itemType == "codeBlock":
                # Handle code blocks
                codeContent = item.get("content", [])
                if codeContent:
                    codeText = extractTextFromContent(codeContent, listLevel)
                    if codeText.strip():
                        textParts.append(f"```\n{codeText}\n```")

            elif itemType == "blockquote":
                # Handle blockquotes
                quoteContent = item.get("content", [])
                if quoteContent:
                    quoteText = extractTextFromContent(quoteContent, listLevel)
                    if quoteText.strip():
                        textParts.append(f"> {quoteText}")

            elif itemType == "heading":
                # Handle headings
                headingContent = item.get("content", [])
                if headingContent:
                    headingText = extractTextFromContent(headingContent, listLevel)
                    if headingText.strip():
                        level = item.get("attrs", {}).get("level", 1)
                        textParts.append(f"{'#' * level} {headingText}")

            elif itemType == "rule":
                # Handle horizontal rules
                textParts.append("---")

            else:
                # Handle unknown types by trying to extract content
                if "content" in item:
                    contentText = extractTextFromContent(item.get("content", []), listLevel)
                    if contentText.strip():
                        textParts.append(contentText)

        return "\n".join(textParts)

    result = extractTextFromContent(content)
    return result.strip()

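A minimal sketch of the conversion above on an illustrative ADF document (not taken from a real JIRA field):

sample_adf = {
    "type": "doc",
    "version": 1,
    "content": [
        {"type": "paragraph", "content": [
            {"type": "text", "text": "Deploy the gateway", "marks": [{"type": "strong"}]},
        ]},
        {"type": "bulletList", "content": [
            {"type": "listItem", "content": [
                {"type": "paragraph", "content": [{"type": "text", "text": "check logs"}]},
            ]},
        ]},
    ],
}
# convertAdfToText(sample_adf) would return roughly:
# **Deploy the gateway**
# • check logs
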
# Utility: dump all ticket fields (name -> field id) to a text file (generic)
async def dumpTicketFieldsToFile(self,
    *,
    filepath: str = "ticket_sync_fields.txt",
    connectorType: str = "Jira",
    connectorParams: dict | None = None,
    taskSyncDefinition: dict | None = None,
) -> bool:
    """Write available ticket fields (name -> field id) to a text file (generic)."""
    try:
        connectorParams = connectorParams or {}
        taskSyncDefinition = taskSyncDefinition or self.TASK_SYNC_DEFINITION
        ticket_interface = await self.services.ticket.connectTicket(
            taskSyncDefinition=taskSyncDefinition,
            connectorType=connectorType,
            connectorParams=connectorParams,
        )
        attributes = await ticket_interface.connector_ticket.readAttributes()
        if not attributes:
            logger.warning("No ticket attributes returned; nothing to write.")
            return False
        dir_name = os.path.dirname(filepath)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        with open(filepath, "w", encoding="utf-8") as f:
            for attr in attributes:
                f.write(f"'{attr.field_name}': ['get', ['fields', '{attr.field}']]\n")
        logger.info(f"Wrote {len(attributes)} ticket fields to {filepath}")
        return True
    except Exception as e:
        logger.error(f"Failed to dump ticket fields: {str(e)}")
        return False

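A hedged usage sketch (syncManager stands for an already-constructed instance of this manager class; the 'summary' field id is illustrative):

# ok = await syncManager.dumpTicketFieldsToFile(filepath="debug/ticket_sync_fields.txt")
# Each written line already has the shape used by TASK_SYNC_DEFINITION entries, e.g.:
# 'Summary': ['get', ['fields', 'summary']]
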
# Utility: dump actual ticket data for debugging (generic)
async def dumpTicketDataToFile(self,
    *,
    filepath: str = "ticket_sync_data.txt",
    connectorType: str = "Jira",
    connectorParams: dict | None = None,
    taskSyncDefinition: dict | None = None,
    sampleLimit: int = 5,
) -> bool:
    """Write actual ticket data to a text file for debugging field mapping (generic)."""
    try:
        connectorParams = connectorParams or {}
        taskSyncDefinition = taskSyncDefinition or self.TASK_SYNC_DEFINITION
        ticket_interface = await self.services.ticket.connectTicket(
            taskSyncDefinition=taskSyncDefinition,
            connectorType=connectorType,
            connectorParams=connectorParams,
        )
        tickets = await ticket_interface.connector_ticket.readTasks(limit=sampleLimit)
        if not tickets:
            logger.warning("No tickets returned; nothing to write.")
            return False
        dir_name = os.path.dirname(filepath)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        with open(filepath, "w", encoding="utf-8") as f:
            f.write("=== TICKET DATA DEBUG ===\n\n")
            for i, ticket in enumerate(tickets):
                f.write(f"--- TICKET {i+1} ---\n")
                f.write("Raw ticket data:\n")
                f.write(f"{ticket.data}\n\n")
                f.write("Field mapping analysis:\n")
                for fieldName, fieldPath in taskSyncDefinition.items():
                    if fieldPath[0] == 'get':
                        try:
                            value = ticket.data
                            for key in fieldPath[1]:
                                if isinstance(value, dict) and key in value:
                                    value = value[key]
                                else:
                                    value = f"KEY_NOT_FOUND: {key}"
                                    break
                            if isinstance(value, dict) and value.get("type") == "doc":
                                pass  # value = self.convertAdfToText(value)
                            elif value is None:
                                value = ""
                            f.write(f" {fieldName}: {value}\n")
                        except Exception as e:
                            f.write(f" {fieldName}: ERROR - {str(e)}\n")
                f.write("\n" + "="*50 + "\n\n")
        logger.info(f"Wrote ticket data for {len(tickets)} tickets to {filepath}")
        return True
    except Exception as e:
        logger.error(f"Failed to dump ticket data: {str(e)}")
        return False

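To make the field-mapping walk above concrete, a minimal illustration (the field ids are examples, not the project's real definition):

sample_definition = {"Summary": ["get", ["fields", "summary"]]}
sample_ticket_data = {"fields": {"summary": "Fix login timeout"}}
# Walking ["fields", "summary"] through sample_ticket_data yields "Fix login timeout";
# a missing key would instead be reported as "KEY_NOT_FOUND: <key>" in the dump file.
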
# Main part of the module

async def performSync(eventUser) -> bool:
    """Perform tickets to SharePoint synchronization

    This function is called by the scheduler and can be used independently.

    Args:
        eventUser: Event user to use for synchronization

    Returns:
        bool: True if synchronization was successful, False otherwise
    """
    try:
        logger.info("Starting DG tickets sync...")

        if not eventUser:
            logger.error("Event user not provided - cannot perform sync")
            return False

        # Sync audit logging is handled by ManagerSyncDelta instance
        syncManager = ManagerSyncDelta(eventUser)
        success = await syncManager.syncTicketsOverSharepoint()

        if success:
            logger.info("DG tickets sync completed successfully")
        else:
            logger.error("DG tickets sync failed")

        return success

    except Exception as e:
        logger.error(f"Error in performing DG tickets sync: {str(e)}")
        return False


# Create a global instance of ManagerSyncDelta to use for scheduled runs
_sync_manager = None


def startSyncManager(eventUser):
    """Initialize the global sync manager with the eventUser."""
    global _sync_manager
    if _sync_manager is None:
        _sync_manager = ManagerSyncDelta(eventUser)
        logger.info("Global sync manager initialized with eventUser")
    try:
        # Register scheduled job based on environment using the manager's services
        if _sync_manager.APP_ENV_TYPE == "prod":
            _sync_manager.services.utils.eventRegisterCron(
                job_id="syncDelta.syncTicket",
                func=scheduledSync,
                cron_kwargs={"minute": "0,20,40"},
                replace_existing=True,
                coalesce=True,
                max_instances=1,
                misfire_grace_time=1800,
            )
            logger.info("Registered DG scheduler (every 20 minutes)")
        else:
            logger.info(f"Skipping DG scheduler registration for ticket sync in env: {_sync_manager.APP_ENV_TYPE}")
    except Exception as e:
        logger.error(f"Failed to register scheduler for DG sync: {str(e)}")
    return _sync_manager


async def scheduledSync():
    """Scheduled sync function that uses the global sync manager."""
    try:
        global _sync_manager
        if _sync_manager and _sync_manager.eventUser:
            return await performSync(_sync_manager.eventUser)
        else:
            logger.error("Sync manager not properly initialized - no eventUser")
            return False
    except Exception as e:
        logger.error(f"Error in scheduled sync: {str(e)}")
        return False


# Scheduler registration and initialization are triggered by startSyncManager(eventUser)
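A hedged sketch of how this module is expected to be wired at startup (the surrounding lifecycle hook is hypothetical; only startSyncManager, performSync, and scheduledSync are defined above):

# async def onFeatureStartup(eventUser):      # hypothetical lifecycle hook
#     startSyncManager(eventUser)             # registers the 20-minute cron job in prod
#     await performSync(eventUser)            # optionally run one sync immediately
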
@@ -1,527 +0,0 @@
import logging
from typing import Dict, Any, List, Union, Optional
from modules.connectors.connectorAiOpenai import AiOpenai, ContextLengthExceededException
from modules.connectors.connectorAiAnthropic import AiAnthropic
from modules.chat.documents.documentExtraction import DocumentExtraction
from modules.interfaces.interfaceChatModel import ChatDocument

logger = logging.getLogger(__name__)

# AI Model Registry with Performance Data
AI_MODELS = {
    "openai_gpt4o": {
        "connector": "openai",
        "max_tokens": 128000,
        "cost_per_1k_tokens": 0.03,  # Input
        "cost_per_1k_tokens_output": 0.06,  # Output
        "speed_rating": 8,  # 1-10
        "quality_rating": 9,  # 1-10
        "supports_images": True,
        "supports_documents": True,
        "context_length": 128000,
        "model_name": "gpt-4o"
    },
    "openai_gpt35": {
        "connector": "openai",
        "max_tokens": 16000,
        "cost_per_1k_tokens": 0.0015,
        "cost_per_1k_tokens_output": 0.002,
        "speed_rating": 9,
        "quality_rating": 7,
        "supports_images": False,
        "supports_documents": True,
        "context_length": 16000,
        "model_name": "gpt-3.5-turbo"
    },
    "anthropic_claude": {
        "connector": "anthropic",
        "max_tokens": 200000,
        "cost_per_1k_tokens": 0.015,
        "cost_per_1k_tokens_output": 0.075,
        "speed_rating": 7,
        "quality_rating": 10,
        "supports_images": True,
        "supports_documents": True,
        "context_length": 200000,
        "model_name": "claude-3-sonnet-20240229"
    }
}


class AiCalls:
    """Interface for AI service interactions with centralized call method"""

    def __init__(self):
        self.openaiService = AiOpenai()
        self.anthropicService = AiAnthropic()
        self.document_extractor = DocumentExtraction()

    async def callAi(
        self,
        prompt: str,
        documents: List[ChatDocument] = None,
        operation_type: str = "general",
        priority: str = "balanced",  # "speed", "quality", "cost", "balanced"
        compress_prompt: bool = True,
        compress_documents: bool = True,
        process_documents_individually: bool = False,
        max_cost: float = None,
        max_processing_time: int = None
    ) -> str:
        """
        Central AI call method with intelligent model selection and content processing.

        Args:
            prompt: The main prompt for the AI
            documents: List of documents to process
            operation_type: Kind of operation ("general", "document_analysis", "image_analysis", etc.)
            priority: Priority for model selection ("speed", "quality", "cost", "balanced")
            compress_prompt: Whether the prompt should be compressed
            compress_documents: Whether documents should be compressed
            process_documents_individually: Whether documents should be processed one by one
            max_cost: Maximum cost for the call
            max_processing_time: Maximum processing time in seconds

        Returns:
            AI response as a string
        """
        try:
            # 1. Process documents if any are provided
            document_content = ""
            if documents:
                document_content = await self._process_documents_for_ai(
                    documents,
                    operation_type,
                    compress_documents,
                    process_documents_individually
                )

            # 2. Select the best model based on priority and content
            selected_model = self._select_optimal_model(
                prompt,
                document_content,
                priority,
                operation_type,
                max_cost,
                max_processing_time
            )

            # 3. Optimize the content for the selected model
            optimized_prompt, optimized_content = await self._optimize_content_for_model(
                prompt,
                document_content,
                selected_model,
                compress_prompt,
                compress_documents
            )

            # 4. Execute the AI call with failover
            return await self._execute_ai_call_with_failover(
                selected_model,
                optimized_prompt,
                optimized_content
            )

        except Exception as e:
            logger.error(f"Error in centralized AI call: {str(e)}")
            return f"Error: {str(e)}"

    def _select_optimal_model(
        self,
        prompt: str,
        document_content: str,
        priority: str,
        operation_type: str,
        max_cost: float = None,
        max_processing_time: int = None
    ) -> str:
        """Select the optimal model based on priority and content."""

        # Calculate the content size
        total_content_size = len(prompt.encode('utf-8')) + len(document_content.encode('utf-8'))

        # Filter the available models
        available_models = {}
        for model_name, model_info in AI_MODELS.items():
            # Check whether the model can handle the content size
            if total_content_size > model_info["context_length"] * 0.8:  # 80% for content
                continue

            # Check the cost limit
            if max_cost:
                estimated_cost = self._estimate_cost(model_info, total_content_size)
                if estimated_cost > max_cost:
                    continue

            # Check operation-type compatibility
            if operation_type == "image_analysis" and not model_info["supports_images"]:
                continue

            available_models[model_name] = model_info

        if not available_models:
            # Fall back to the smallest model
            return "openai_gpt35"

        # Select the model based on priority
        if priority == "speed":
            return max(available_models.keys(), key=lambda x: available_models[x]["speed_rating"])
        elif priority == "quality":
            return max(available_models.keys(), key=lambda x: available_models[x]["quality_rating"])
        elif priority == "cost":
            return min(available_models.keys(), key=lambda x: available_models[x]["cost_per_1k_tokens"])
        else:  # balanced
            # Weighted score: 40% quality, 30% speed, 30% cost
            def balanced_score(model_name):
                model_info = available_models[model_name]
                quality_score = model_info["quality_rating"] * 0.4
                speed_score = model_info["speed_rating"] * 0.3
                cost_score = (10 - (model_info["cost_per_1k_tokens"] * 1000)) * 0.3  # lower cost = higher score
                return quality_score + speed_score + cost_score

            return max(available_models.keys(), key=balanced_score)

    def _estimate_cost(self, model_info: Dict, content_size: int) -> float:
        """Estimate the cost of an AI call."""
        # Rough estimate: 1 token ≈ 4 characters
        estimated_tokens = content_size / 4
        input_cost = (estimated_tokens / 1000) * model_info["cost_per_1k_tokens"]
        output_cost = (estimated_tokens / 1000) * model_info["cost_per_1k_tokens_output"] * 0.1  # 10% for output
        return input_cost + output_cost

    async def _process_documents_for_ai(
        self,
        documents: List[ChatDocument],
        operation_type: str,
        compress_documents: bool,
        process_individually: bool
    ) -> str:
        """Process documents for an AI call using documentExtraction.py."""

        if not documents:
            return ""

        processed_contents = []

        for doc in documents:
            try:
                # Extract content via documentExtraction.py
                extracted = await self.document_extractor.processFileData(
                    doc.fileData,
                    doc.fileName,
                    doc.mimeType,
                    prompt=f"Extract relevant content for {operation_type}",
                    documentId=doc.id,
                    enableAI=True
                )

                # Combine all content items
                doc_content = []
                for content_item in extracted.contents:
                    if content_item.data and content_item.data.strip():
                        doc_content.append(content_item.data)

                if doc_content:
                    combined_doc_content = "\n\n".join(doc_content)

                    # Compress if requested
                    if compress_documents and len(combined_doc_content.encode('utf-8')) > 10000:  # 10KB limit
                        combined_doc_content = await self._compress_content(
                            combined_doc_content,
                            10000,
                            "document"
                        )

                    processed_contents.append(f"Document: {doc.fileName}\n{combined_doc_content}")

            except Exception as e:
                logger.warning(f"Error processing document {doc.fileName}: {str(e)}")
                processed_contents.append(f"Document: {doc.fileName}\n[Error processing document: {str(e)}]")

        return "\n\n---\n\n".join(processed_contents)

    async def _optimize_content_for_model(
        self,
        prompt: str,
        document_content: str,
        model_name: str,
        compress_prompt: bool,
        compress_documents: bool
    ) -> tuple[str, str]:
        """Optimize content for the selected model."""

        model_info = AI_MODELS[model_name]
        max_content_size = model_info["context_length"] * 0.7  # 70% for content

        optimized_prompt = prompt
        optimized_content = document_content

        # Compress the prompt if requested
        if compress_prompt and len(prompt.encode('utf-8')) > 2000:  # 2KB limit for the prompt
            optimized_prompt = await self._compress_content(prompt, 2000, "prompt")

        # Compress the document content if requested
        if compress_documents and document_content:
            content_size = len(document_content.encode('utf-8'))
            if content_size > max_content_size:
                optimized_content = await self._compress_content(
                    document_content,
                    int(max_content_size),
                    "document"
                )

        return optimized_prompt, optimized_content

    async def _compress_content(self, content: str, target_size: int, content_type: str) -> str:
        """Compress content intelligently based on its type."""

        if len(content.encode('utf-8')) <= target_size:
            return content

        try:
            # Use AI for intelligent compression
            compression_prompt = f"""
            Compress the following {content_type} to at most {target_size} characters,
            but keep all important information:

            {content}

            Return only the compressed content, without additional explanations.
            """

            # Use the fastest available model for compression
            compression_model = "openai_gpt35"
            model_info = AI_MODELS[compression_model]
            connector = getattr(self, f"{model_info['connector']}Service")

            messages = [{"role": "user", "content": compression_prompt}]

            if model_info["connector"] == "openai":
                compressed = await connector.callAiBasic(messages)
            else:
                response = await connector.callAiBasic(messages)
                compressed = response["choices"][0]["message"]["content"]

            return compressed

        except Exception as e:
            logger.warning(f"AI compression failed, using truncation: {str(e)}")
            # Fallback: simple truncation
            return content[:target_size] + "... [truncated]"

    async def _execute_ai_call_with_failover(
        self,
        model_name: str,
        prompt: str,
        document_content: str
    ) -> str:
        """Execute an AI call with automatic failover."""

        try:
            model_info = AI_MODELS[model_name]
            connector = getattr(self, f"{model_info['connector']}Service")

            # Prepare the messages
            messages = []
            if document_content:
                messages.append({
                    "role": "system",
                    "content": f"Context from documents:\n{document_content}"
                })

            messages.append({
                "role": "user",
                "content": prompt
            })

            # Execute the AI call
            if model_info["connector"] == "openai":
                return await connector.callAiBasic(messages)
            else:  # anthropic
                response = await connector.callAiBasic(messages)
                return response["choices"][0]["message"]["content"]

        except ContextLengthExceededException:
            logger.warning(f"Context length exceeded for {model_name}, trying fallback")
            # Fall back to a model with a larger context
            fallback_model = self._find_fallback_model(model_name)
            if fallback_model:
                return await self._execute_ai_call_with_failover(fallback_model, prompt, document_content)
            else:
                # Last resort: compress the content further
                compressed_prompt = await self._compress_content(prompt, 1000, "prompt")
                compressed_content = await self._compress_content(document_content, 5000, "document")
                return await self._execute_ai_call_with_failover("openai_gpt35", compressed_prompt, compressed_content)

        except Exception as e:
            logger.warning(f"AI call failed with {model_name}: {e}")
            # Generic fallback
            return await self._execute_ai_call_with_failover("openai_gpt35", prompt, document_content)

    def _find_fallback_model(self, current_model: str) -> Optional[str]:
        """Find a fallback model with a larger context window."""
        current_context = AI_MODELS[current_model]["context_length"]

        # Look for a model with a larger context
        for model_name, model_info in AI_MODELS.items():
            if model_info["context_length"] > current_context:
                return model_name

        return None

# Legacy methods
|
||||
|
||||
async def callAiTextBasic(self, prompt: str, context: Optional[str] = None) -> str:
|
||||
"""
|
||||
Basic text processing - now uses centralized AI call method.
|
||||
|
||||
Args:
|
||||
prompt: The user prompt to process
|
||||
context: Optional system context/prompt
|
||||
|
||||
Returns:
|
||||
The AI response as text
|
||||
"""
|
||||
# Combine context with prompt if provided
|
||||
full_prompt = prompt
|
||||
if context:
|
||||
full_prompt = f"Context: {context}\n\nUser Request: {prompt}"
|
||||
|
||||
# Use centralized AI call with speed priority for basic calls
|
||||
return await self.callAi(
|
||||
prompt=full_prompt,
|
||||
priority="speed",
|
||||
compress_prompt=True,
|
||||
compress_documents=False
|
||||
)
|
||||
|
||||
async def callAiTextAdvanced(self, prompt: str, context: Optional[str] = None, _is_fallback: bool = False) -> str:
|
||||
"""
|
||||
Advanced text processing - now uses centralized AI call method.
|
||||
|
||||
Args:
|
||||
prompt: The user prompt to process
|
||||
context: Optional system context/prompt
|
||||
_is_fallback: Internal flag (kept for compatibility)
|
||||
|
||||
Returns:
|
||||
The AI response as text
|
||||
"""
|
||||
# Combine context with prompt if provided
|
||||
full_prompt = prompt
|
||||
if context:
|
||||
full_prompt = f"Context: {context}\n\nUser Request: {prompt}"
|
||||
|
||||
# Use centralized AI call with quality priority for advanced calls
|
||||
return await self.callAi(
|
||||
prompt=full_prompt,
|
||||
priority="quality",
|
||||
compress_prompt=False,
|
||||
compress_documents=False
|
||||
)
|
||||
|
||||
async def callAiImageBasic(self, prompt: str, imageData: Union[str, bytes], mimeType: str = None) -> str:
|
||||
"""
|
||||
Basic image processing - now uses centralized AI call method.
|
||||
|
||||
Args:
|
||||
prompt: The prompt for image analysis
|
||||
imageData: The image data (file path or bytes)
|
||||
mimeType: Optional MIME type of the image
|
||||
|
||||
Returns:
|
||||
The AI response as text
|
||||
"""
|
||||
try:
|
||||
# For image processing, use the original connector directly
|
||||
# as the centralized method doesn't handle images yet
|
||||
return await self.openaiService.callAiImage(prompt, imageData, mimeType)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in OpenAI image call: {str(e)}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
async def callAiImageAdvanced(self, prompt: str, imageData: Union[str, bytes], mimeType: str = None) -> str:
|
||||
"""
|
||||
Advanced image processing - now uses centralized AI call method.
|
||||
|
||||
Args:
|
||||
prompt: The prompt for image analysis
|
||||
imageData: The image data (file path or bytes)
|
||||
mimeType: Optional MIME type of the image
|
||||
|
||||
Returns:
|
||||
The AI response as text
|
||||
"""
|
||||
try:
|
||||
# For image processing, use the original connector directly
|
||||
# as the centralized method doesn't handle images yet
|
||||
return await self.anthropicService.callAiImage(prompt, imageData, mimeType)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Anthropic image call: {str(e)}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
# Convenience methods for common use cases
|
||||
|
||||
async def callAiForDocumentAnalysis(
|
||||
self,
|
||||
prompt: str,
|
||||
documents: List[ChatDocument],
|
||||
priority: str = "balanced"
|
||||
) -> str:
|
||||
"""Convenience method for document analysis"""
|
||||
return await self.callAi(
|
||||
prompt=prompt,
|
||||
documents=documents,
|
||||
operation_type="document_analysis",
|
||||
priority=priority,
|
||||
compress_documents=True,
|
||||
process_documents_individually=False
|
||||
)
|
||||
|
||||
async def callAiForReportGeneration(
|
||||
self,
|
||||
prompt: str,
|
||||
documents: List[ChatDocument],
|
||||
priority: str = "quality"
|
||||
) -> str:
|
||||
"""Convenience method for report generation"""
|
||||
return await self.callAi(
|
||||
prompt=prompt,
|
||||
documents=documents,
|
||||
operation_type="report_generation",
|
||||
priority=priority,
|
||||
compress_documents=True,
|
||||
process_documents_individually=True
|
||||
)
|
||||
|
||||
async def callAiForEmailComposition(
|
||||
self,
|
||||
prompt: str,
|
||||
documents: List[ChatDocument] = None,
|
||||
priority: str = "speed"
|
||||
) -> str:
|
||||
"""Convenience method for email composition"""
|
||||
return await self.callAi(
|
||||
prompt=prompt,
|
||||
documents=documents,
|
||||
operation_type="email_composition",
|
||||
priority=priority,
|
||||
compress_prompt=True,
|
||||
compress_documents=True
|
||||
)
|
||||
|
||||
async def callAiForTaskPlanning(
|
||||
self,
|
||||
prompt: str,
|
||||
documents: List[ChatDocument] = None,
|
||||
priority: str = "balanced"
|
||||
) -> str:
|
||||
"""Convenience method for task planning"""
|
||||
return await self.callAi(
|
||||
prompt=prompt,
|
||||
documents=documents,
|
||||
operation_type="task_planning",
|
||||
priority=priority,
|
||||
compress_prompt=False,
|
||||
compress_documents=True
|
||||
)
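A hedged usage sketch of the central callAi entry point defined above; the document object and the cost cap are illustrative:

# ai = AiCalls()
# answer = await ai.callAi(
#     prompt="Summarize the attached contract",
#     documents=[chat_document],            # a ChatDocument instance from the chat module
#     operation_type="document_analysis",
#     priority="cost",
#     max_cost=0.05,
# )
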
|
||||
|
||||
716
modules/interfaces/interfaceAiObjects.py
Normal file
@@ -0,0 +1,716 @@
import logging
import asyncio
import uuid
import base64
from typing import Dict, Any, List, Union, Tuple, Optional
from dataclasses import dataclass
import time

logger = logging.getLogger(__name__)

from modules.aicore.aicoreModelRegistry import modelRegistry
from modules.aicore.aicoreModelSelector import modelSelector
from modules.datamodels.datamodelAi import (
    AiModel,
    AiCallOptions,
    AiCallRequest,
    AiCallResponse,
    OperationTypeEnum,
    AiModelCall,
    AiModelResponse,
)
from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy


# Dynamic model registry - models are now loaded from connectors via aicore system


@dataclass(slots=True)
class AiObjects:
    """Centralized AI interface: dynamically discovers and uses AI models. Includes web functionality."""

    def __post_init__(self) -> None:
        # Auto-discover and register all available connectors
        self._discoverAndRegisterConnectors()

    def _discoverAndRegisterConnectors(self):
        """Auto-discover and register all available AI connectors."""
        logger.info("Auto-discovering AI connectors...")

        # Use the model registry's built-in discovery mechanism
        discoveredConnectors = modelRegistry.discoverConnectors()

        # Register each discovered connector
        for connector in discoveredConnectors:
            modelRegistry.registerConnector(connector)
            logger.info(f"Registered connector: {connector.getConnectorType()}")

        logger.info(f"Total connectors registered: {len(discoveredConnectors)}")
        logger.info("All AI connectors registered with dynamic model registry")

    @classmethod
    async def create(cls) -> "AiObjects":
        """Create AiObjects instance with auto-discovered connectors."""
        # No need to manually create connectors - they're auto-discovered
        return cls()

    def _selectModel(self, prompt: str, context: str, options: AiCallOptions) -> str:
        """Select the best model using dynamic model selection system. Returns displayName (unique identifier)."""
        # Get available models from the dynamic registry
        availableModels = modelRegistry.getAvailableModels()

        if not availableModels:
            logger.error("No models available in the registry")
            raise ValueError("No AI models available")

        # Use the dynamic model selector
        selectedModel = modelSelector.selectModel(prompt, context, options, availableModels)

        if not selectedModel:
            logger.error("No suitable model found for the given criteria")
            raise ValueError("No suitable AI model found")

        logger.info(f"Selected model: {selectedModel.name} ({selectedModel.displayName})")
        return selectedModel.displayName
|
||||
|
||||
|
||||
# AI for Extraction, Processing, Generation
|
||||
async def call(self, request: AiCallRequest, progressCallback=None) -> AiCallResponse:
|
||||
"""Call AI model for text generation with model-aware chunking."""
|
||||
# Handle content parts (unified path)
|
||||
if hasattr(request, 'contentParts') and request.contentParts:
|
||||
return await self._callWithContentParts(request, progressCallback)
|
||||
# Handle traditional text/context calls
|
||||
return await self._callWithTextContext(request)
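A hedged usage sketch of the call() entry point above (field names follow the datamodels imported at the top of this file; the operation type member is assumed, not confirmed):

# ai = await AiObjects.create()
# request = AiCallRequest(
#     prompt="Summarize the findings",
#     context="...extracted report text...",
#     options=AiCallOptions(operationType=someOperationType),   # an OperationTypeEnum member fitting the task
# )
# response = await ai.call(request)
# print(response.content, response.modelName, response.priceUsd)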
|
||||
|
||||
async def _callWithTextContext(self, request: AiCallRequest) -> AiCallResponse:
|
||||
"""Call AI model for traditional text/context calls with fallback mechanism."""
|
||||
prompt = request.prompt
|
||||
context = request.context or ""
|
||||
options = request.options
|
||||
|
||||
# Input bytes will be calculated inside _callWithModel
|
||||
|
||||
# Generation parameters are handled inside _callWithModel
|
||||
|
||||
# Get failover models for this operation type
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
failoverModelList = modelSelector.getFailoverModelList(prompt, context, options, availableModels)
|
||||
|
||||
if not failoverModelList:
|
||||
errorMsg = f"No suitable models found for operation {options.operationType}"
|
||||
logger.error(errorMsg)
|
||||
return AiCallResponse(
|
||||
content=errorMsg,
|
||||
modelName="error",
|
||||
priceUsd=0.0,
|
||||
processingTime=0.0,
|
||||
bytesSent=0,
|
||||
bytesReceived=0,
|
||||
errorCount=1
|
||||
)
|
||||
|
||||
# Try each model in failover sequence
|
||||
lastError = None
|
||||
for attempt, model in enumerate(failoverModelList):
|
||||
try:
|
||||
logger.info(f"Attempting AI call with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
|
||||
|
||||
# Call the model directly - no truncation or compression here
|
||||
response = await self._callWithModel(model, prompt, context, options)
|
||||
|
||||
logger.info(f"✅ AI call successful with model: {model.name}")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
lastError = e
|
||||
logger.warning(f"❌ AI call failed with model {model.name}: {str(e)}")
|
||||
|
||||
# If this is not the last model, try the next one
|
||||
if attempt < len(failoverModelList) - 1:
|
||||
logger.info(f"🔄 Trying next failover model...")
|
||||
continue
|
||||
else:
|
||||
# All models failed
|
||||
logger.error(f"💥 All {len(failoverModelList)} models failed for operation {options.operationType}")
|
||||
break
|
||||
|
||||
# All failover attempts failed - return error response
|
||||
errorMsg = f"All AI models failed for operation {options.operationType}. Last error: {str(lastError)}"
|
||||
logger.error(errorMsg)
|
||||
return AiCallResponse(
|
||||
content=errorMsg,
|
||||
modelName="error",
|
||||
priceUsd=0.0,
|
||||
processingTime=0.0,
|
||||
bytesSent=0,
|
||||
bytesReceived=0,
|
||||
errorCount=1
|
||||
)
|
||||
|
||||
async def _callWithContentParts(self, request: AiCallRequest, progressCallback=None) -> AiCallResponse:
|
||||
"""Process content parts with model-aware chunking (unified for single and multiple parts)."""
|
||||
prompt = request.prompt
|
||||
options = request.options
|
||||
contentParts = request.contentParts
|
||||
|
||||
# Get failover models
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
failoverModelList = modelSelector.getFailoverModelList(prompt, "", options, availableModels)
|
||||
|
||||
if not failoverModelList:
|
||||
return self._createErrorResponse("No suitable models found", 0, 0)
|
||||
|
||||
# Process each content part
|
||||
allResults = []
|
||||
for contentPart in contentParts:
|
||||
partResult = await self._processContentPartWithFallback(contentPart, prompt, options, failoverModelList, progressCallback)
|
||||
allResults.append(partResult)
|
||||
|
||||
# Merge all results
|
||||
mergedContent = self._mergePartResults(allResults)
|
||||
|
||||
return AiCallResponse(
|
||||
content=mergedContent,
|
||||
modelName="multiple",
|
||||
priceUsd=sum(r.priceUsd for r in allResults),
|
||||
processingTime=sum(r.processingTime for r in allResults),
|
||||
bytesSent=sum(r.bytesSent for r in allResults),
|
||||
bytesReceived=sum(r.bytesReceived for r in allResults),
|
||||
errorCount=sum(r.errorCount for r in allResults)
|
||||
)
|
||||
|
||||
async def _processContentPartWithFallback(self, contentPart, prompt: str, options, failoverModelList, progressCallback=None) -> AiCallResponse:
|
||||
"""Process a single content part with model-aware chunking and fallback."""
|
||||
lastError = None
|
||||
|
||||
# Check if this is an image - Vision models need special handling
|
||||
isImage = (contentPart.typeGroup == "image") or (contentPart.mimeType and contentPart.mimeType.startswith("image/"))
|
||||
|
||||
# Determine the correct operation type based on content type
|
||||
# Images should use IMAGE_ANALYSE, not the generic operation type
|
||||
actualOperationType = options.operationType
|
||||
if isImage:
|
||||
actualOperationType = OperationTypeEnum.IMAGE_ANALYSE
|
||||
# Get vision-capable models for images
|
||||
availableModels = modelRegistry.getAvailableModels()
|
||||
visionFailoverList = modelSelector.getFailoverModelList(prompt, "", AiCallOptions(operationType=actualOperationType), availableModels)
|
||||
if visionFailoverList:
|
||||
logger.debug(f"Using {len(visionFailoverList)} vision-capable models for image processing")
|
||||
failoverModelList = visionFailoverList
|
||||
|
||||
for attempt, model in enumerate(failoverModelList):
|
||||
try:
|
||||
logger.info(f"Processing content part with model: {model.name} (attempt {attempt + 1}/{len(failoverModelList)})")
|
||||
|
||||
# Special handling for images with Vision models
|
||||
if isImage and hasattr(model, 'functionCall'):
|
||||
# Call model's functionCall directly (for Vision models this is callAiImage)
|
||||
from modules.datamodels.datamodelAi import AiModelCall, AiCallOptions as AiCallOpts
|
||||
|
||||
try:
|
||||
# Validate and prepare image data
|
||||
if not contentPart.data:
|
||||
raise ValueError("Image content part has no data")
|
||||
|
||||
# Ensure mimeType is valid
|
||||
mimeType = contentPart.mimeType or "image/jpeg"
|
||||
if not mimeType.startswith("image/"):
|
||||
raise ValueError(f"Invalid mimeType for image: {mimeType}")
|
||||
|
||||
# Prepare base64 data
|
||||
if isinstance(contentPart.data, str):
|
||||
# Already base64 encoded - validate it
|
||||
try:
|
||||
base64.b64decode(contentPart.data, validate=True)
|
||||
base64Data = contentPart.data
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid base64 data in contentPart: {str(e)}")
|
||||
elif isinstance(contentPart.data, bytes):
|
||||
# Binary data - encode to base64
|
||||
base64Data = base64.b64encode(contentPart.data).decode('utf-8')
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type for image: {type(contentPart.data)}")
|
||||
|
||||
# Create data URL
|
||||
imageDataUrl = f"data:{mimeType};base64,{base64Data}"
|
||||
|
||||
modelCall = AiModelCall(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prompt or ""},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": imageDataUrl
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
options=AiCallOpts(operationType=actualOperationType)
|
||||
)
|
||||
|
||||
modelResponse = await model.functionCall(modelCall)
|
||||
|
||||
if not modelResponse.success:
|
||||
raise ValueError(f"Model call failed: {modelResponse.error}")
|
||||
|
||||
logger.info(f"✅ Image content part processed successfully with model: {model.name}")
|
||||
|
||||
# Convert to AiCallResponse format
|
||||
return AiCallResponse(
|
||||
content=modelResponse.content,
|
||||
modelName=model.name,
|
||||
priceUsd=modelResponse.priceUsd if hasattr(modelResponse, 'priceUsd') else 0.0,
|
||||
processingTime=modelResponse.processingTime if hasattr(modelResponse, 'processingTime') else 0.0,
|
||||
bytesSent=0, # Will be calculated elsewhere
|
||||
bytesReceived=0, # Will be calculated elsewhere
|
||||
errorCount=0
|
||||
)
|
||||
except Exception as e:
|
||||
# Image processing failed with this model
|
||||
lastError = e
|
||||
logger.warning(f"❌ Image processing failed with model {model.name}: {str(e)}")
|
||||
|
||||
# If this is not the last model, try the next one
|
||||
if attempt < len(failoverModelList) - 1:
|
||||
logger.info(f"🔄 Trying next fallback model for image processing...")
|
||||
continue
|
||||
else:
|
||||
# All models failed
|
||||
logger.error(f"💥 All {len(failoverModelList)} models failed for image processing")
|
||||
raise
|
||||
|
||||
# For non-image parts, check if part fits in model context
|
||||
# Calculate available space accounting for prompt, system message, and output reservation
|
||||
partSize = len(contentPart.data.encode('utf-8')) if contentPart.data else 0
|
||||
|
||||
# Use same calculation as _chunkContentPart to determine actual available space
|
||||
modelContextTokens = model.contextLength
|
||||
modelMaxOutputTokens = model.maxTokens
|
||||
|
||||
# Reserve tokens for prompt, system message, output, and message overhead
|
||||
promptTokens = len(prompt.encode('utf-8')) / 4 if prompt else 0
|
||||
systemMessageTokens = 10 # ~40 bytes = 10 tokens
|
||||
outputTokens = modelMaxOutputTokens
|
||||
messageOverheadTokens = 100
|
||||
totalReservedTokens = promptTokens + systemMessageTokens + messageOverheadTokens + outputTokens
|
||||
|
||||
# Available tokens for content (with 80% safety margin)
|
||||
availableContentTokens = int((modelContextTokens - totalReservedTokens) * 0.8)
|
||||
if availableContentTokens < 100:
|
||||
availableContentTokens = max(100, int(modelContextTokens * 0.1))
|
||||
|
||||
# Convert to bytes (1 token ≈ 4 bytes)
|
||||
availableContentBytes = availableContentTokens * 4
|
||||
|
||||
logger.debug(f"Size check for {model.name}: partSize={partSize} bytes, availableContentBytes={availableContentBytes} bytes (contextLength={modelContextTokens} tokens, reserved={totalReservedTokens:.0f} tokens)")
|
||||
|
||||
if partSize <= availableContentBytes:
|
||||
# Part fits - call AI directly
|
||||
response = await self._callWithModel(model, prompt, contentPart.data, options)
|
||||
logger.info(f"✅ Content part processed successfully with model: {model.name}")
|
||||
return response
|
||||
else:
|
||||
# Part too large - chunk it (pass prompt to account for it in chunk size calculation)
|
||||
chunks = await self._chunkContentPart(contentPart, model, options, prompt)
|
||||
if not chunks:
|
||||
raise ValueError(f"Failed to chunk content part for model {model.name}")
|
||||
|
||||
logger.info(f"Starting to process {len(chunks)} chunks with model {model.name}")
|
||||
|
||||
# Log progress if callback provided
|
||||
if progressCallback:
|
||||
progressCallback(0.0, f"Starting to process {len(chunks)} chunks")
|
||||
|
||||
# Process each chunk
|
||||
chunkResults = []
|
||||
for idx, chunk in enumerate(chunks):
|
||||
chunkNum = idx + 1
|
||||
chunkData = chunk.get('data', '')
|
||||
chunkSize = len(chunkData.encode('utf-8')) if chunkData else 0
|
||||
logger.info(f"Processing chunk {chunkNum}/{len(chunks)} with model {model.name}, chunk size: {chunkSize} bytes")
|
||||
|
||||
# Calculate and log progress
|
||||
if progressCallback:
|
||||
progress = chunkNum / len(chunks)
|
||||
progressCallback(progress, f"Processing chunk {chunkNum}/{len(chunks)}")
|
||||
|
||||
try:
|
||||
chunkResponse = await self._callWithModel(model, prompt, chunkData, options)
|
||||
chunkResults.append(chunkResponse)
|
||||
logger.info(f"✅ Chunk {chunkNum}/{len(chunks)} processed successfully")
|
||||
|
||||
# Log completion progress
|
||||
if progressCallback:
|
||||
progressCallback(chunkNum / len(chunks), f"Chunk {chunkNum}/{len(chunks)} processed")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error processing chunk {chunkNum}/{len(chunks)}: {str(e)}")
|
||||
raise
|
||||
|
||||
# Merge chunk results
|
||||
mergedContent = self._mergeChunkResults(chunkResults)
|
||||
totalPrice = sum(r.priceUsd for r in chunkResults)
|
||||
totalTime = sum(r.processingTime for r in chunkResults)
|
||||
totalBytesSent = sum(r.bytesSent for r in chunkResults)
|
||||
totalBytesReceived = sum(r.bytesReceived for r in chunkResults)
|
||||
totalErrors = sum(r.errorCount for r in chunkResults)
|
||||
|
||||
logger.info(f"✅ Content part chunked and processed with model: {model.name} ({len(chunks)} chunks)")
|
||||
return AiCallResponse(
|
||||
content=mergedContent,
|
||||
modelName=model.name,
|
||||
priceUsd=totalPrice,
|
||||
processingTime=totalTime,
|
||||
bytesSent=totalBytesSent,
|
||||
bytesReceived=totalBytesReceived,
|
||||
errorCount=totalErrors
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
lastError = e
|
||||
error_msg = str(e) if str(e) else f"{type(e).__name__}"
|
||||
error_detail = f"❌ Model {model.name} failed for content part: {error_msg}"
|
||||
if hasattr(e, 'detail') and e.detail:
|
||||
error_detail += f" | Detail: {e.detail}"
|
||||
if hasattr(e, 'status_code'):
|
||||
error_detail += f" | Status: {e.status_code}"
|
||||
logger.warning(error_detail, exc_info=True)
|
||||
|
||||
if attempt < len(failoverModelList) - 1:
|
||||
logger.info(f"🔄 Trying next failover model...")
|
||||
continue
|
||||
else:
|
||||
logger.error(f"💥 All {len(failoverModelList)} models failed for content part")
|
||||
break
|
||||
|
||||
# All models failed
|
||||
return self._createErrorResponse(f"All models failed: {str(lastError)}", 0, 0)
|
||||
|
||||
async def _chunkContentPart(self, contentPart, model, options, prompt: str = "") -> List[Dict[str, Any]]:
|
||||
"""Chunk a content part based on model capabilities, accounting for prompt, system message overhead, and maxTokens output."""
|
||||
# Calculate model-specific chunk sizes
|
||||
modelContextTokens = model.contextLength # Total context in tokens
|
||||
modelMaxOutputTokens = model.maxTokens # Maximum output tokens
|
||||
|
||||
# Reserve tokens for:
|
||||
# 1. Prompt (user message)
|
||||
promptTokens = len(prompt.encode('utf-8')) / 4 if prompt else 0
|
||||
|
||||
# 2. System message wrapper ("Context from documents:\n")
|
||||
systemMessageTokens = 10 # ~40 bytes = 10 tokens
|
||||
|
||||
# 3. Max output tokens (model will reserve space for completion)
|
||||
outputTokens = modelMaxOutputTokens
|
||||
|
||||
# 4. JSON structure and message overhead (~100 tokens)
|
||||
messageOverheadTokens = 100
|
||||
|
||||
# Total reserved tokens = input overhead + output reservation
|
||||
totalReservedTokens = promptTokens + systemMessageTokens + messageOverheadTokens + outputTokens
|
||||
|
||||
# Available tokens for content = context length - reserved tokens
|
||||
# Use 80% of available for safety margin
|
||||
availableContentTokens = int((modelContextTokens - totalReservedTokens) * 0.8)
|
||||
|
||||
# Ensure we have at least some space
|
||||
if availableContentTokens < 100:
|
||||
logger.warning(f"Very limited space for content: {availableContentTokens} tokens available. Model: {model.name}, contextLength: {modelContextTokens}, maxTokens: {modelMaxOutputTokens}, prompt: {promptTokens:.0f} tokens")
|
||||
availableContentTokens = max(100, int(modelContextTokens * 0.1)) # Fallback to 10% of context
|
||||
|
||||
# Convert tokens to bytes (1 token ≈ 4 bytes)
|
||||
availableContentBytes = availableContentTokens * 4
|
||||
|
||||
logger.debug(f"Chunking calculation for {model.name}: contextLength={modelContextTokens} tokens, maxTokens={modelMaxOutputTokens} tokens, prompt={promptTokens:.0f} tokens, reserved={totalReservedTokens:.0f} tokens, available={availableContentTokens} tokens ({availableContentBytes} bytes)")
|
||||
|
||||
# Use 70% of available content bytes for text chunks (conservative)
|
||||
textChunkSize = int(availableContentBytes * 0.7)
|
||||
imageChunkSize = int(availableContentBytes * 0.8) # 80% for image chunks
|
||||
|
||||
# Build chunking options
|
||||
chunkingOptions = {
|
||||
"textChunkSize": textChunkSize,
|
||||
"imageChunkSize": imageChunkSize,
|
||||
"maxSize": availableContentBytes,
|
||||
"chunkAllowed": True
|
||||
}
|
||||
|
||||
# Get appropriate chunker
|
||||
from modules.services.serviceExtraction.subRegistry import ChunkerRegistry
|
||||
chunkerRegistry = ChunkerRegistry()
|
||||
chunker = chunkerRegistry.resolve(contentPart.typeGroup)
|
||||
|
||||
if not chunker:
|
||||
logger.warning(f"No chunker found for typeGroup: {contentPart.typeGroup}")
|
||||
return []
|
||||
|
||||
# Chunk the content part
|
||||
try:
|
||||
chunks = chunker.chunk(contentPart, chunkingOptions)
|
||||
logger.debug(f"Created {len(chunks)} chunks for {contentPart.typeGroup} part")
|
||||
return chunks
|
||||
except Exception as e:
|
||||
logger.error(f"Chunking failed for {contentPart.typeGroup}: {str(e)}")
|
||||
return []
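To make the budget arithmetic above concrete, a worked example under assumed model numbers (contextLength=128000 tokens, maxTokens=4096, a 2000-byte prompt):

# promptTokens           = 2000 / 4                   = 500
# totalReservedTokens    = 500 + 10 + 100 + 4096      = 4706
# availableContentTokens = int((128000 - 4706) * 0.8) = 98635
# availableContentBytes  = 98635 * 4                  = 394540
# textChunkSize          = int(394540 * 0.7)          = 276178 bytes per text chunk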
|
||||
|
||||
def _mergePartResults(self, partResults: List[AiCallResponse]) -> str:
|
||||
"""Merge part results using the existing sophisticated merging system."""
|
||||
if not partResults:
|
||||
return ""
|
||||
|
||||
# Convert AiCallResponse results to ContentParts for merging
|
||||
from modules.datamodels.datamodelExtraction import ContentPart
|
||||
from modules.services.serviceExtraction.subUtils import makeId
|
||||
|
||||
content_parts = []
|
||||
for i, result in enumerate(partResults):
|
||||
if result.content:
|
||||
content_part = ContentPart(
|
||||
id=str(uuid.uuid4()),
|
||||
parentId=None,
|
||||
label=f"ai_result_{i}",
|
||||
typeGroup="text", # Default to text for AI results
|
||||
mimeType="text/plain",
|
||||
data=result.content,
|
||||
metadata={
|
||||
"aiResult": True,
|
||||
"modelName": result.modelName,
|
||||
"priceUsd": result.priceUsd,
|
||||
"processingTime": result.processingTime,
|
||||
"bytesSent": result.bytesSent,
|
||||
"bytesReceived": result.bytesReceived
|
||||
}
|
||||
)
|
||||
content_parts.append(content_part)
|
||||
|
||||
# Use existing merging system
|
||||
merge_strategy = MergeStrategy(
|
||||
useIntelligentMerging=True,
|
||||
groupBy="typeGroup",
|
||||
orderBy="id",
|
||||
mergeType="concatenate"
|
||||
)
|
||||
|
||||
merged_parts = applyMerging(content_parts, merge_strategy)
|
||||
|
||||
# Convert merged parts back to final string
|
||||
final_content = "\n\n".join([part.data for part in merged_parts])
|
||||
|
||||
logger.info(f"Merged {len(partResults)} AI results using existing merging system")
|
||||
return final_content.strip()
|
||||
|
||||
def _mergeChunkResults(self, chunkResults: List[AiCallResponse]) -> str:
|
||||
"""Merge chunk results using the existing sophisticated merging system."""
|
||||
if not chunkResults:
|
||||
return ""
|
||||
|
||||
# Convert AiCallResponse results to ContentParts for merging
|
||||
|
||||
content_parts = []
|
||||
for i, result in enumerate(chunkResults):
|
||||
if result.content:
|
||||
content_part = ContentPart(
|
||||
id=str(uuid.uuid4()),
|
||||
parentId=None,
|
||||
label=f"chunk_result_{i}",
|
||||
typeGroup="text", # Default to text for AI results
|
||||
mimeType="text/plain",
|
||||
data=result.content,
|
||||
metadata={
|
||||
"aiResult": True,
|
||||
"chunk": True,
|
||||
"modelName": result.modelName,
|
||||
"priceUsd": result.priceUsd,
|
||||
"processingTime": result.processingTime,
|
||||
"bytesSent": result.bytesSent,
|
||||
"bytesReceived": result.bytesReceived
|
||||
}
|
||||
)
|
||||
content_parts.append(content_part)
|
||||
|
||||
# Use existing merging system
|
||||
merge_strategy = MergeStrategy(
|
||||
useIntelligentMerging=True,
|
||||
groupBy="typeGroup",
|
||||
orderBy="id",
|
||||
mergeType="concatenate"
|
||||
)
|
||||
|
||||
merged_parts = applyMerging(content_parts, merge_strategy)
|
||||
|
||||
# Convert merged parts back to final string
|
||||
final_content = "\n\n".join([part.data for part in merged_parts])
|
||||
|
||||
logger.info(f"Merged {len(chunkResults)} chunk results using existing merging system")
|
||||
return final_content.strip()
|
||||
|
||||
def _createErrorResponse(self, errorMsg: str, inputBytes: int, outputBytes: int) -> AiCallResponse:
|
||||
"""Create an error response."""
|
||||
return AiCallResponse(
|
||||
content=errorMsg,
|
||||
modelName="error",
|
||||
priceUsd=0.0,
|
||||
processingTime=0.0,
|
||||
bytesSent=inputBytes,
|
||||
bytesReceived=outputBytes,
|
||||
errorCount=1
|
||||
)
|
||||
|
||||
async def _callWithModel(self, model: AiModel, prompt: str, context: str, options: AiCallOptions = None) -> AiCallResponse:
|
||||
"""Call a specific model and return the response."""
|
||||
# Calculate input bytes from prompt and context
|
||||
inputBytes = len((prompt + context).encode('utf-8'))
|
||||
|
||||
# Replace <TOKEN_LIMIT> placeholder with model's maxTokens value
|
||||
if "<TOKEN_LIMIT>" in prompt:
|
||||
if model.maxTokens > 0:
|
||||
tokenLimit = str(model.maxTokens)
|
||||
modelPrompt = prompt.replace("<TOKEN_LIMIT>", tokenLimit)
|
||||
logger.debug(f"Replaced <TOKEN_LIMIT> with {tokenLimit} for model {model.name}")
|
||||
else:
|
||||
raise ValueError(f"Model {model.name} has invalid maxTokens ({model.maxTokens}). Cannot set token limit.")
|
||||
else:
|
||||
modelPrompt = prompt
|
||||
|
||||
# Update messages array with replaced content
|
||||
messages = []
|
||||
if context:
|
||||
messages.append({"role": "system", "content": f"Context from documents:\n{context}"})
|
||||
messages.append({"role": "user", "content": modelPrompt})
|
||||
|
||||
# Start timing
|
||||
startTime = time.time()
|
||||
|
||||
# Call the model's function directly - completely generic
|
||||
if model.functionCall:
|
||||
# Create standardized call object
|
||||
modelCall = AiModelCall(
|
||||
messages=messages,
|
||||
model=model,
|
||||
options=options or {}
|
||||
)
|
||||
|
||||
# Log before calling model
|
||||
contextSize = len(context.encode('utf-8')) if context else 0
|
||||
promptSize = len(modelPrompt.encode('utf-8')) if modelPrompt else 0
|
||||
totalInputSize = contextSize + promptSize
|
||||
logger.debug(f"Calling model {model.name} with {len(messages)} messages, context size: {contextSize} bytes, prompt size: {promptSize} bytes, total input: {totalInputSize} bytes")
|
||||
|
||||
# Call the model with standardized interface
|
||||
modelResponse = await model.functionCall(modelCall)
|
||||
|
||||
# Log after successful call
|
||||
logger.debug(f"Model {model.name} returned successfully")
|
||||
|
||||
# Extract content from standardized response
|
||||
if not modelResponse.success:
|
||||
raise ValueError(f"Model call failed: {modelResponse.error}")
|
||||
content = modelResponse.content
|
||||
else:
|
||||
raise ValueError(f"Model {model.name} has no function call defined")
|
||||
|
||||
# Calculate timing and output bytes
|
||||
endTime = time.time()
|
||||
processingTime = endTime - startTime
|
||||
outputBytes = len(content.encode("utf-8"))
|
||||
|
||||
# Calculate price using model's own price calculation method
|
||||
priceUsd = model.calculatePriceUsd(processingTime, inputBytes, outputBytes)
|
||||
|
||||
return AiCallResponse(
|
||||
content=content,
|
||||
modelName=model.name,
|
||||
priceUsd=priceUsd,
|
||||
processingTime=processingTime,
|
||||
bytesSent=inputBytes,
|
||||
bytesReceived=outputBytes,
|
||||
errorCount=0
|
||||
)
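A short illustration of the <TOKEN_LIMIT> replacement handled above (the prompt text and the 4096 value are illustrative):

# prompt = "Answer in at most <TOKEN_LIMIT> tokens."
# # For a model whose maxTokens is 4096, the user message actually sent becomes:
# # "Answer in at most 4096 tokens."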
|
||||
|
||||
|
||||
# Utility methods
|
||||
async def listAvailableModels(self, connectorType: str = None) -> List[Dict[str, Any]]:
|
||||
"""List available models, optionally filtered by connector type."""
|
||||
models = modelRegistry.getAvailableModels()
|
||||
if connectorType:
|
||||
return [model.model_dump() for model in models if model.connectorType == connectorType]
|
||||
return [model.model_dump() for model in models]
|
||||
|
||||
async def getModelInfo(self, displayName: str) -> Dict[str, Any]:
|
||||
"""Get information about a specific model by displayName."""
|
||||
model = modelRegistry.getModel(displayName)
|
||||
if not model:
|
||||
raise ValueError(f"Model with displayName '{displayName}' not found")
|
||||
return model.model_dump()
|
||||
|
||||
async def getModelsByTag(self, tag: str) -> List[str]:
|
||||
"""Get model displayNames that have a specific tag. Returns displayNames (unique identifiers)."""
|
||||
models = modelRegistry.getModelsByTag(tag)
|
||||
return [model.displayName for model in models]
|
||||
|
||||
|
||||
def applyMerging(parts: List[ContentPart], strategy: MergeStrategy) -> List[ContentPart]:
    """Apply merging strategy to parts with intelligent token-aware merging."""
    logger.debug(f"applyMerging called with {len(parts)} parts")

    # Import merging dependencies
    from modules.services.serviceExtraction.merging.mergerText import TextMerger
    from modules.services.serviceExtraction.merging.mergerTable import TableMerger
    from modules.services.serviceExtraction.merging.mergerDefault import DefaultMerger
    from modules.services.serviceExtraction.subMerger import IntelligentTokenAwareMerger

    # Check if intelligent merging is enabled
    if strategy.useIntelligentMerging:
        modelCapabilities = strategy.capabilities or {}
        subMerger = IntelligentTokenAwareMerger(modelCapabilities)

        # Use intelligent merging for all parts
        merged = subMerger.mergeChunksIntelligently(parts, strategy.prompt or "")

        # Calculate and log optimization stats
        stats = subMerger.calculateOptimizationStats(parts, merged)
        logger.info(f"🧠 Intelligent merging stats: {stats}")
        logger.debug(f"Intelligent merging: {stats['original_ai_calls']} → {stats['optimized_ai_calls']} calls ({stats['reduction_percent']}% reduction)")

        return merged

    # Fallback to traditional merging
    textMerger = TextMerger()
    tableMerger = TableMerger()
    defaultMerger = DefaultMerger()

    # Group by typeGroup
    textParts = [p for p in parts if p.typeGroup == "text"]
    tableParts = [p for p in parts if p.typeGroup == "table"]
    structureParts = [p for p in parts if p.typeGroup == "structure"]
    otherParts = [p for p in parts if p.typeGroup not in ("text", "table", "structure")]

    logger.debug(f"Grouped - text: {len(textParts)}, table: {len(tableParts)}, structure: {len(structureParts)}, other: {len(otherParts)}")

    merged: List[ContentPart] = []

    if textParts:
        textMerged = textMerger.merge(textParts, strategy)
        logger.debug(f"TextMerger merged {len(textParts)} parts into {len(textMerged)} parts")
        merged.extend(textMerged)
    if tableParts:
        tableMerged = tableMerger.merge(tableParts, strategy)
        logger.debug(f"TableMerger merged {len(tableParts)} parts into {len(tableMerged)} parts")
        merged.extend(tableMerged)
    if structureParts:
        # For now, treat structure like text
        structureMerged = textMerger.merge(structureParts, strategy)
        logger.debug(f"StructureMerger merged {len(structureParts)} parts into {len(structureMerged)} parts")
        merged.extend(structureMerged)
    if otherParts:
        otherMerged = defaultMerger.merge(otherParts, strategy)
        logger.debug(f"DefaultMerger merged {len(otherParts)} parts into {len(otherMerged)} parts")
        merged.extend(otherMerged)

    logger.debug(f"applyMerging returning {len(merged)} parts")
    return merged

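# Illustrative only (not part of the original file): a hedged sketch of calling
# applyMerging with the traditional (non-intelligent) path. MergeStrategy's exact
# constructor is not shown in this diff, so the keyword arguments are assumptions.
#
#     strategy = MergeStrategy(useIntelligentMerging=False, prompt="", capabilities={})
#     mergedParts = applyMerging(extractedParts, strategy)   # extractedParts: List[ContentPart]
#
# With useIntelligentMerging=True the function instead delegates everything to
# IntelligentTokenAwareMerger and logs the reduction in planned AI calls.
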
@@ -1,583 +0,0 @@
"""
|
||||
Models for User Service
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pydantic import BaseModel, Field, EmailStr
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from modules.shared.attributeUtils import register_model_labels, AttributeDefinition, ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
|
||||
class AuthAuthority(str, Enum):
|
||||
"""Authentication authority enum"""
|
||||
LOCAL = "local"
|
||||
GOOGLE = "google"
|
||||
MSFT = "msft"
|
||||
|
||||
class UserPrivilege(str, Enum):
|
||||
"""User privilege levels"""
|
||||
SYSADMIN = "sysadmin"
|
||||
ADMIN = "admin"
|
||||
USER = "user"
|
||||
|
||||
class ConnectionStatus(str, Enum):
|
||||
"""Connection status"""
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
REVOKED = "revoked"
|
||||
PENDING = "pending"
|
||||
|
||||
class TokenStatus(str, Enum):
|
||||
"""Status of an issued gateway JWT access token"""
|
||||
ACTIVE = "active"
|
||||
REVOKED = "revoked"
|
||||
|
||||
class Mandate(BaseModel, ModelMixin):
|
||||
"""Data model for a mandate"""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Unique ID of the mandate",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
name: str = Field(
|
||||
description="Name of the mandate",
|
||||
frontend_type="text",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True
|
||||
)
|
||||
language: str = Field(
|
||||
default="en",
|
||||
description="Default language of the mandate",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True,
|
||||
frontend_options=[
|
||||
{"value": "de", "label": {"en": "Deutsch", "fr": "Allemand"}},
|
||||
{"value": "en", "label": {"en": "English", "fr": "Anglais"}},
|
||||
{"value": "fr", "label": {"en": "Français", "fr": "Français"}},
|
||||
{"value": "it", "label": {"en": "Italiano", "fr": "Italien"}}
|
||||
]
|
||||
)
|
||||
enabled: bool = Field(
|
||||
default=True,
|
||||
description="Indicates whether the mandate is enabled",
|
||||
frontend_type="checkbox",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False
|
||||
)
|
||||
|
||||
# Register labels for Mandate
|
||||
register_model_labels(
|
||||
"Mandate",
|
||||
{"en": "Mandate", "fr": "Mandat"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"name": {"en": "Name", "fr": "Nom"},
|
||||
"language": {"en": "Language", "fr": "Langue"},
|
||||
"enabled": {"en": "Enabled", "fr": "Activé"}
|
||||
}
|
||||
)
|
||||
|
||||
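# Illustrative only (not part of the original file): a minimal sketch of creating and
# serialising a Mandate with the standard Pydantic v2 API; the name value is hypothetical.
#
#     mandate = Mandate(name="Acme Corp", language="fr")   # id and enabled use their defaults
#     payload = mandate.model_dump()                        # plain dict for storage or the API layer
#
# The frontend_* keyword arguments appear to be custom Field extras read by the project's
# own frontend tooling (an assumption based on naming), not by Pydantic itself.
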
class UserConnection(BaseModel, ModelMixin):
    """Data model for a user's connection to an external service"""
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the connection",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    userId: str = Field(
        description="ID of the user this connection belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    authority: AuthAuthority = Field(
        description="Authentication authority",
        frontend_type="select",
        frontend_readonly=True,
        frontend_required=False,
        frontend_options=[
            {"value": "local", "label": {"en": "Local", "fr": "Local"}},
            {"value": "google", "label": {"en": "Google", "fr": "Google"}},
            {"value": "msft", "label": {"en": "Microsoft", "fr": "Microsoft"}}
        ]
    )
    externalId: str = Field(
        description="User ID in the external system",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    externalUsername: str = Field(
        description="Username in the external system",
        frontend_type="text",
        frontend_readonly=False,
        frontend_required=False
    )
    externalEmail: Optional[EmailStr] = Field(
        None,
        description="Email in the external system",
        frontend_type="email",
        frontend_readonly=False,
        frontend_required=False
    )
    status: ConnectionStatus = Field(
        default=ConnectionStatus.ACTIVE,
        description="Connection status",
        frontend_type="select",
        frontend_readonly=False,
        frontend_required=False,
        frontend_options=[
            {"value": "active", "label": {"en": "Active", "fr": "Actif"}},
            {"value": "inactive", "label": {"en": "Inactive", "fr": "Inactif"}},
            {"value": "expired", "label": {"en": "Expired", "fr": "Expiré"}},
            {"value": "pending", "label": {"en": "Pending", "fr": "En attente"}}
        ]
    )
    connectedAt: float = Field(
        default_factory=get_utc_timestamp,
        description="When the connection was established (UTC timestamp in seconds)",
        frontend_type="timestamp",
        frontend_readonly=True,
        frontend_required=False
    )
    lastChecked: float = Field(
        default_factory=get_utc_timestamp,
        description="When the connection was last verified (UTC timestamp in seconds)",
        frontend_type="timestamp",
        frontend_readonly=True,
        frontend_required=False
    )
    expiresAt: Optional[float] = Field(
        None,
        description="When the connection expires (UTC timestamp in seconds)",
        frontend_type="timestamp",
        frontend_readonly=True,
        frontend_required=False
    )
    tokenStatus: Optional[str] = Field(
        None,
        description="Current token status: active, expired, none",
        frontend_type="select",
        frontend_readonly=True,
        frontend_required=False,
        frontend_options=[
            {"value": "active", "label": {"en": "Active", "fr": "Actif"}},
            {"value": "expired", "label": {"en": "Expired", "fr": "Expiré"}},
            {"value": "none", "label": {"en": "None", "fr": "Aucun"}}
        ]
    )
    tokenExpiresAt: Optional[float] = Field(
        None,
        description="When the current token expires (UTC timestamp in seconds)",
        frontend_type="timestamp",
        frontend_readonly=True,
        frontend_required=False
    )

# Register labels for UserConnection
register_model_labels(
    "UserConnection",
    {"en": "User Connection", "fr": "Connexion utilisateur"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "authority": {"en": "Authority", "fr": "Autorité"},
        "externalId": {"en": "External ID", "fr": "ID externe"},
        "externalUsername": {"en": "External Username", "fr": "Nom d'utilisateur externe"},
        "externalEmail": {"en": "External Email", "fr": "Email externe"},
        "status": {"en": "Status", "fr": "Statut"},
        "connectedAt": {"en": "Connected At", "fr": "Connecté le"},
        "lastChecked": {"en": "Last Checked", "fr": "Dernière vérification"},
        "expiresAt": {"en": "Expires At", "fr": "Expire le"},
        "tokenStatus": {"en": "Connection Status", "fr": "Statut de connexion"},
        "tokenExpiresAt": {"en": "Expires At", "fr": "Expire le"}
    }
)

class User(BaseModel, ModelMixin):
    """Data model for a user"""
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the user",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    username: str = Field(
        description="Username for login",
        frontend_type="text",
        frontend_readonly=False,
        frontend_required=True
    )
    email: Optional[EmailStr] = Field(
        None,
        description="Email address of the user",
        frontend_type="email",
        frontend_readonly=False,
        frontend_required=True
    )
    fullName: Optional[str] = Field(
        None,
        description="Full name of the user",
        frontend_type="text",
        frontend_readonly=False,
        frontend_required=False
    )
    language: str = Field(
        default="en",
        description="Preferred language of the user",
        frontend_type="select",
        frontend_readonly=False,
        frontend_required=True,
        frontend_options=[
            {"value": "de", "label": {"en": "Deutsch", "fr": "Allemand"}},
            {"value": "en", "label": {"en": "English", "fr": "Anglais"}},
            {"value": "fr", "label": {"en": "Français", "fr": "Français"}},
            {"value": "it", "label": {"en": "Italiano", "fr": "Italien"}}
        ]
    )
    enabled: bool = Field(
        default=True,
        description="Indicates whether the user is enabled",
        frontend_type="checkbox",
        frontend_readonly=False,
        frontend_required=False
    )
    privilege: UserPrivilege = Field(
        default=UserPrivilege.USER,
        description="Permission level",
        frontend_type="select",
        frontend_readonly=False,
        frontend_required=True,
        frontend_options=[
            {"value": "user", "label": {"en": "User", "fr": "Utilisateur"}},
            {"value": "admin", "label": {"en": "Admin", "fr": "Administrateur"}},
            {"value": "sysadmin", "label": {"en": "SysAdmin", "fr": "Administrateur système"}}
        ]
    )
    authenticationAuthority: AuthAuthority = Field(
        default=AuthAuthority.LOCAL,
        description="Primary authentication authority",
        frontend_type="select",
        frontend_readonly=True,
        frontend_required=False,
        frontend_options=[
            {"value": "local", "label": {"en": "Local", "fr": "Local"}},
            {"value": "google", "label": {"en": "Google", "fr": "Google"}},
            {"value": "msft", "label": {"en": "Microsoft", "fr": "Microsoft"}}
        ]
    )
    mandateId: Optional[str] = Field(
        None,
        description="ID of the mandate this user belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )

# Register labels for User
register_model_labels(
    "User",
    {"en": "User", "fr": "Utilisateur"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "username": {"en": "Username", "fr": "Nom d'utilisateur"},
        "email": {"en": "Email", "fr": "Email"},
        "fullName": {"en": "Full Name", "fr": "Nom complet"},
        "language": {"en": "Language", "fr": "Langue"},
        "enabled": {"en": "Enabled", "fr": "Activé"},
        "privilege": {"en": "Privilege", "fr": "Privilège"},
        "authenticationAuthority": {"en": "Auth Authority", "fr": "Autorité d'authentification"},
        "mandateId": {"en": "Mandate ID", "fr": "ID de mandat"}
    }
)

class UserInDB(User):
    """Extended user class with password hash"""
    hashedPassword: Optional[str] = Field(None, description="Hash of the user password")

# Register labels for UserInDB
register_model_labels(
    "UserInDB",
    {"en": "User Access", "fr": "Accès de l'utilisateur"},
    {
        "hashedPassword": {"en": "Password hash", "fr": "Hachage de mot de passe"}
    }
)

# Token Models
class Token(BaseModel, ModelMixin):
    """Token model for all authentication types"""
    id: Optional[str] = None
    userId: str
    authority: AuthAuthority
    connectionId: Optional[str] = Field(None, description="ID of the connection this token belongs to")
    tokenAccess: str
    tokenType: str = "bearer"
    expiresAt: float = Field(description="When the token expires (UTC timestamp in seconds)")
    tokenRefresh: Optional[str] = None
    createdAt: Optional[float] = Field(None, description="When the token was created (UTC timestamp in seconds)")
    # Revocation and session tracking (for LOCAL gateway JWTs)
    status: TokenStatus = Field(default=TokenStatus.ACTIVE, description="Token status: active/revoked")
    revokedAt: Optional[float] = Field(None, description="When the token was revoked (UTC timestamp in seconds)")
    revokedBy: Optional[str] = Field(None, description="User ID who revoked the token (admin/self)")
    reason: Optional[str] = Field(None, description="Optional revocation reason")
    sessionId: Optional[str] = Field(None, description="Logical session grouping for logout revocation")
    mandateId: Optional[str] = Field(None, description="Mandate ID for tenant scoping of the token")

    class Config:
        use_enum_values = True

# Register labels for Token
register_model_labels(
    "Token",
    {"en": "Token", "fr": "Jeton"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "authority": {"en": "Authority", "fr": "Autorité"},
        "connectionId": {"en": "Connection ID", "fr": "ID de connexion"},
        "tokenAccess": {"en": "Access Token", "fr": "Jeton d'accès"},
        "tokenType": {"en": "Token Type", "fr": "Type de jeton"},
        "expiresAt": {"en": "Expires At", "fr": "Expire le"},
        "tokenRefresh": {"en": "Refresh Token", "fr": "Jeton de rafraîchissement"},
        "createdAt": {"en": "Created At", "fr": "Créé le"},
        "status": {"en": "Status", "fr": "Statut"},
        "revokedAt": {"en": "Revoked At", "fr": "Révoqué le"},
        "revokedBy": {"en": "Revoked By", "fr": "Révoqué par"},
        "reason": {"en": "Reason", "fr": "Raison"},
        "sessionId": {"en": "Session ID", "fr": "ID de session"},
        "mandateId": {"en": "Mandate ID", "fr": "ID de mandat"}
    }
)

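# Illustrative only (not part of the original file): a hedged sketch of issuing and later
# revoking a gateway Token, assuming get_utc_timestamp() returns seconds since the epoch
# and that signedJwt / adminUser are provided by the surrounding code.
#
#     token = Token(
#         userId=user.id,
#         authority=AuthAuthority.LOCAL,
#         tokenAccess=signedJwt,                   # hypothetical signed JWT string
#         expiresAt=get_utc_timestamp() + 3600,    # one-hour lifetime
#         sessionId=sessionId,
#     )
#     # later, on logout or an admin action:
#     token.status = TokenStatus.REVOKED
#     token.revokedAt = get_utc_timestamp()
#     token.revokedBy = adminUser.id
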
class LocalToken(Token):
    """Local authentication token model"""
    pass

class GoogleToken(Token):
    """Google OAuth token model"""
    pass

class MsftToken(Token):
    """Microsoft OAuth token model"""
    pass

class AuthEvent(BaseModel, ModelMixin):
    """Data model for authentication events"""
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the auth event",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    userId: str = Field(
        description="ID of the user this event belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )
    eventType: str = Field(
        description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )
    timestamp: float = Field(
        default_factory=get_utc_timestamp,
        description="Unix timestamp when the event occurred",
        frontend_type="datetime",
        frontend_readonly=True,
        frontend_required=True
    )
    ipAddress: Optional[str] = Field(
        default=None,
        description="IP address from which the event originated",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    userAgent: Optional[str] = Field(
        default=None,
        description="User agent string from the request",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    success: bool = Field(
        default=True,
        description="Whether the authentication event was successful",
        frontend_type="boolean",
        frontend_readonly=True,
        frontend_required=True
    )
    details: Optional[str] = Field(
        default=None,
        description="Additional details about the event",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )

# Register labels for AuthEvent
register_model_labels(
    "AuthEvent",
    {"en": "Authentication Event", "fr": "Événement d'authentification"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "eventType": {"en": "Event Type", "fr": "Type d'événement"},
        "timestamp": {"en": "Timestamp", "fr": "Horodatage"},
        "ipAddress": {"en": "IP Address", "fr": "Adresse IP"},
        "userAgent": {"en": "User Agent", "fr": "Agent utilisateur"},
        "success": {"en": "Success", "fr": "Succès"},
        "details": {"en": "Details", "fr": "Détails"}
    }
)

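# Illustrative only (not part of the original file): a minimal sketch of recording a
# failed login attempt, assuming the caller has a Starlette/FastAPI request in scope.
#
#     event = AuthEvent(
#         userId=user.id,
#         eventType="login",
#         success=False,
#         ipAddress=request.client.host,
#         userAgent=request.headers.get("user-agent"),
#         details="invalid password",
#     )
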
class DataNeutraliserConfig(BaseModel, ModelMixin):
    """Data model for data neutralization configuration"""
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the configuration",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    mandateId: str = Field(
        description="ID of the mandate this configuration belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )
    userId: str = Field(
        description="ID of the user who created this configuration",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )
    enabled: bool = Field(
        default=True,
        description="Whether data neutralization is enabled",
        frontend_type="checkbox",
        frontend_readonly=False,
        frontend_required=False
    )
    namesToParse: str = Field(
        default="",
        description="Multiline list of names to parse for neutralization",
        frontend_type="textarea",
        frontend_readonly=False,
        frontend_required=False
    )
    sharepointSourcePath: str = Field(
        default="",
        description="SharePoint path to read files for neutralization",
        frontend_type="text",
        frontend_readonly=False,
        frontend_required=False
    )
    sharepointTargetPath: str = Field(
        default="",
        description="SharePoint path to store neutralized files",
        frontend_type="text",
        frontend_readonly=False,
        frontend_required=False
    )

# Register labels for DataNeutraliserConfig
register_model_labels(
    "DataNeutraliserConfig",
    {"en": "Data Neutralization Config", "fr": "Configuration de neutralisation des données"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "mandateId": {"en": "Mandate ID", "fr": "ID de mandat"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "enabled": {"en": "Enabled", "fr": "Activé"},
        "namesToParse": {"en": "Names to Parse", "fr": "Noms à analyser"},
        "sharepointSourcePath": {"en": "Source Path", "fr": "Chemin source"},
        "sharepointTargetPath": {"en": "Target Path", "fr": "Chemin cible"}
    }
)

class DataNeutralizerAttributes(BaseModel, ModelMixin):
    """Data model for neutralized data attributes mapping"""
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique ID of the attribute mapping (used as UID in neutralized files)",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    mandateId: str = Field(
        description="ID of the mandate this attribute belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )
    userId: str = Field(
        description="ID of the user who created this attribute",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )
    originalText: str = Field(
        description="Original text that was neutralized",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )
    fileId: Optional[str] = Field(
        default=None,
        description="ID of the file this attribute belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    patternType: str = Field(
        description="Type of pattern that matched (email, phone, name, etc.)",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )

# Register labels for DataNeutralizerAttributes
register_model_labels(
    "DataNeutralizerAttributes",
    {"en": "Neutralized Data Attribute", "fr": "Attribut de données neutralisées"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "mandateId": {"en": "Mandate ID", "fr": "ID de mandat"},
        "userId": {"en": "User ID", "fr": "ID utilisateur"},
        "originalText": {"en": "Original Text", "fr": "Texte original"},
        "fileId": {"en": "File ID", "fr": "ID de fichier"},
        "patternType": {"en": "Pattern Type", "fr": "Type de modèle"}
    }
)

class SystemTable(BaseModel, ModelMixin):
    """Data model for system table entries"""
    table_name: str = Field(
        description="Name of the table",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=True
    )
    initial_id: Optional[str] = Field(
        default=None,
        description="Initial ID for the table",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )

@@ -1,893 +0,0 @@
"""
|
||||
Chat model classes for the chat system.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, UTC
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
|
||||
# ===== Method Models =====
|
||||
|
||||
class ActionDocument(BaseModel, ModelMixin):
|
||||
"""Clear document structure for action results"""
|
||||
documentName: str = Field(description="Name of the document")
|
||||
documentData: Any = Field(description="Content/data of the document")
|
||||
mimeType: str = Field(description="MIME type of the document")
|
||||
|
||||
# Register labels for ActionDocument
|
||||
register_model_labels(
|
||||
"ActionDocument",
|
||||
{"en": "Action Document", "fr": "Document d'action"},
|
||||
{
|
||||
"documentName": {"en": "Document Name", "fr": "Nom du document"},
|
||||
"documentData": {"en": "Document Data", "fr": "Données du document"},
|
||||
"mimeType": {"en": "MIME Type", "fr": "Type MIME"}
|
||||
}
|
||||
)
|
||||
|
||||
class ActionResult(BaseModel, ModelMixin):
|
||||
"""Clean action result with documents as primary output
|
||||
|
||||
IMPORTANT: Action methods should NOT set resultLabel in their return value.
|
||||
The resultLabel is managed by the action handler using the action's execResultLabel
|
||||
from the action plan. This ensures consistent document routing throughout the workflow.
|
||||
"""
|
||||
# Core result
|
||||
success: bool = Field(description="Whether execution succeeded")
|
||||
error: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
# Primary output - documents
|
||||
documents: List[ActionDocument] = Field(default_factory=list, description="Document outputs")
|
||||
resultLabel: Optional[str] = Field(None, description="Label for document routing (set by action handler, not by action methods)")
|
||||
|
||||
@classmethod
|
||||
def isSuccess(cls, documents: List[ActionDocument] = None) -> 'ActionResult':
|
||||
"""Create a successful action result
|
||||
|
||||
Note: Do not set resultLabel - this is managed by the action handler
|
||||
"""
|
||||
return cls(
|
||||
success=True,
|
||||
documents=documents or []
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def isFailure(cls, error: str, documents: List[ActionDocument] = None) -> 'ActionResult':
|
||||
"""Create a failed action result
|
||||
|
||||
Note: Do not set resultLabel - this is managed by the action handler
|
||||
"""
|
||||
return cls(
|
||||
success=False,
|
||||
documents=documents or [],
|
||||
error=error
|
||||
)
|
||||
|
||||
# Register labels for ActionResult
|
||||
register_model_labels(
|
||||
"ActionResult",
|
||||
{"en": "Action Result", "fr": "Résultat de l'action"},
|
||||
{
|
||||
"success": {"en": "Success", "fr": "Succès"},
|
||||
"error": {"en": "Error", "fr": "Erreur"},
|
||||
"documents": {"en": "Documents", "fr": "Documents"},
|
||||
"resultLabel": {"en": "Result Label", "fr": "Étiquette du résultat"}
|
||||
}
|
||||
)
|
||||
|
||||
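# Illustrative only (not part of the original file): a hedged sketch of how an action
# method might return its result using the factory helpers above; runReport and the
# markdown variable are hypothetical.
#
#     def runReport(markdown: str) -> ActionResult:
#         try:
#             doc = ActionDocument(documentName="report.md", documentData=markdown, mimeType="text/markdown")
#             return ActionResult.isSuccess(documents=[doc])
#         except Exception as exc:
#             return ActionResult.isFailure(error=str(exc))
#
# resultLabel is deliberately left unset; per the docstring it is filled in by the
# action handler from the plan's execResultLabel.
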
# ===== Base Enums and Simple Models =====

class TaskStatus(str, Enum):
    """Task status enumeration"""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

# Register labels for TaskStatus
register_model_labels(
    "TaskStatus",
    {"en": "Task Status", "fr": "Statut de la tâche"},
    {
        "PENDING": {"en": "Pending", "fr": "En attente"},
        "RUNNING": {"en": "Running", "fr": "En cours"},
        "COMPLETED": {"en": "Completed", "fr": "Terminé"},
        "FAILED": {"en": "Failed", "fr": "Échec"},
        "CANCELLED": {"en": "Cancelled", "fr": "Annulé"},
        "ROLLED_BACK": {"en": "Rolled Back", "fr": "Annulé"}
    }
)

class UserInputRequest(BaseModel, ModelMixin):
    """Data model for a user input request"""
    prompt: str = Field(description="Prompt for the user")
    listFileId: List[str] = Field(default_factory=list, description="List of file IDs")
    userLanguage: str = Field(default="en", description="User's preferred language")

# Register labels for UserInputRequest
register_model_labels(
    "UserInputRequest",
    {"en": "User Input Request", "fr": "Demande de saisie utilisateur"},
    {
        "prompt": {"en": "Prompt", "fr": "Invite"},
        "listFileId": {"en": "File IDs", "fr": "IDs des fichiers"},
        "userLanguage": {"en": "User Language", "fr": "Langue de l'utilisateur"}
    }
)

# ===== Content Models =====

class ContentMetadata(BaseModel, ModelMixin):
    """Metadata for content items"""
    size: int = Field(description="Content size in bytes")
    pages: Optional[int] = Field(None, description="Number of pages for multi-page content")
    error: Optional[str] = Field(None, description="Processing error if any")
    width: Optional[int] = Field(None, description="Width in pixels for images/videos")
    height: Optional[int] = Field(None, description="Height in pixels for images/videos")
    colorMode: Optional[str] = Field(None, description="Color mode (e.g., RGB, CMYK, grayscale)")
    fps: Optional[float] = Field(None, description="Frames per second for videos")
    durationSec: Optional[float] = Field(None, description="Duration in seconds for videos/audio")
    mimeType: str = Field(description="MIME type of the content")
    base64Encoded: bool = Field(description="Whether the data is base64 encoded")

# Register labels for ContentMetadata
register_model_labels(
    "ContentMetadata",
    {"en": "Content Metadata", "fr": "Métadonnées du contenu"},
    {
        "size": {"en": "Size", "fr": "Taille"},
        "pages": {"en": "Pages", "fr": "Pages"},
        "error": {"en": "Error", "fr": "Erreur"},
        "width": {"en": "Width", "fr": "Largeur"},
        "height": {"en": "Height", "fr": "Hauteur"},
        "colorMode": {"en": "Color Mode", "fr": "Mode de couleur"},
        "fps": {"en": "FPS", "fr": "IPS"},
        "durationSec": {"en": "Duration", "fr": "Durée"},
        "mimeType": {"en": "MIME Type", "fr": "Type MIME"},
        "base64Encoded": {"en": "Base64 Encoded", "fr": "Encodé en Base64"}
    }
)

class ContentItem(BaseModel, ModelMixin):
    """Individual content item from a document"""
    label: str = Field(description="Content label (e.g., tab name, tag name)")
    data: str = Field(description="Extracted text content")
    metadata: ContentMetadata = Field(description="Content metadata")

# Register labels for ContentItem
register_model_labels(
    "ContentItem",
    {"en": "Content Item", "fr": "Élément de contenu"},
    {
        "label": {"en": "Label", "fr": "Étiquette"},
        "data": {"en": "Data", "fr": "Données"},
        "metadata": {"en": "Metadata", "fr": "Métadonnées"}
    }
)

class ChatDocument(BaseModel, ModelMixin):
    """Data model for a chat document"""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    messageId: str = Field(description="Foreign key to message")
    fileId: str = Field(description="Foreign key to file")

    # Direct file attributes (copied from file object)
    fileName: str = Field(description="Name of the file")
    fileSize: int = Field(description="Size of the file")
    mimeType: str = Field(description="MIME type of the file")

    # Workflow context fields
    roundNumber: Optional[int] = Field(None, description="Round number in workflow")
    taskNumber: Optional[int] = Field(None, description="Task number within round")
    actionNumber: Optional[int] = Field(None, description="Action number within task")

    # Reference to action that created this document
    actionId: Optional[str] = Field(None, description="ID of the action that created this document")

# Register labels for ChatDocument
register_model_labels(
    "ChatDocument",
    {"en": "Chat Document", "fr": "Document de chat"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "messageId": {"en": "Message ID", "fr": "ID du message"},
        "fileId": {"en": "File ID", "fr": "ID du fichier"},
        "fileName": {"en": "File Name", "fr": "Nom du fichier"},
        "fileSize": {"en": "File Size", "fr": "Taille du fichier"},
        "mimeType": {"en": "MIME Type", "fr": "Type MIME"},
        "roundNumber": {"en": "Round Number", "fr": "Numéro de tour"},
        "taskNumber": {"en": "Task Number", "fr": "Numéro de tâche"},
        "actionNumber": {"en": "Action Number", "fr": "Numéro d'action"},
        "actionId": {"en": "Action ID", "fr": "ID de l'action"}
    }
)

class DocumentExchange(BaseModel, ModelMixin):
    """Data model for document exchange between AI actions"""
    documentsLabel: str = Field(description="Label for the set of documents")
    documents: List[str] = Field(default_factory=list, description="List of document references")

# Register labels for DocumentExchange
register_model_labels(
    "DocumentExchange",
    {"en": "Document Exchange", "fr": "Échange de documents"},
    {
        "documentsLabel": {"en": "Documents Label", "fr": "Label des documents"},
        "documents": {"en": "Documents", "fr": "Documents"}
    }
)

class ExtractedContent(BaseModel, ModelMixin):
    """Data model for extracted content"""
    id: str = Field(description="Reference to source ChatDocument")
    contents: List[ContentItem] = Field(default_factory=list, description="List of content items")

# Register labels for ExtractedContent
register_model_labels(
    "ExtractedContent",
    {"en": "Extracted Content", "fr": "Contenu extrait"},
    {
        "objectId": {"en": "Object ID", "fr": "ID de l'objet"},
        "objectType": {"en": "Object Type", "fr": "Type d'objet"},
        "contents": {"en": "Contents", "fr": "Contenus"}
    }
)

# ===== Task Models =====

class TaskAction(BaseModel, ModelMixin):
    """Model for task actions"""
    id: str = Field(..., description="Action ID")
    execMethod: str = Field(..., description="Method to execute")
    execAction: str = Field(..., description="Action to perform")
    execParameters: Dict[str, Any] = Field(default_factory=dict, description="Action parameters")
    execResultLabel: Optional[str] = Field(None, description="Label for the set of result documents")
    # NEW: Optional document format specification
    expectedDocumentFormats: Optional[List[Dict[str, str]]] = Field(None, description="Expected document formats (optional)")

    # User message in user's language
    userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")

    status: TaskStatus = Field(default=TaskStatus.PENDING, description="Action status")
    error: Optional[str] = Field(None, description="Error message if action failed")
    retryCount: int = Field(default=0, description="Number of retries attempted")
    retryMax: int = Field(default=3, description="Maximum number of retries")
    processingTime: Optional[float] = Field(None, description="Processing time in seconds")
    timestamp: float = Field(default_factory=get_utc_timestamp, description="When the action was executed (UTC timestamp in seconds)")
    result: Optional[str] = Field(None, description="Result of the action")
    resultDocuments: Optional[List[ChatDocument]] = Field(None, description="Result documents from the action")

    def isSuccessful(self) -> bool:
        """Check if action was successful"""
        return self.status == TaskStatus.COMPLETED

    def hasError(self) -> bool:
        """Check if action has an error"""
        return self.status == TaskStatus.FAILED

    def getErrorMessage(self) -> Optional[str]:
        """Get error message if any"""
        return self.error if self.hasError() else None

    def setError(self, error: str) -> None:
        """Set action error"""
        self.error = error
        self.status = TaskStatus.FAILED

    def setSuccess(self) -> None:
        """Set action as successful"""
        self.status = TaskStatus.COMPLETED
        self.error = None

# Register labels for TaskAction
register_model_labels(
    "TaskAction",
    {"en": "Task Action", "fr": "Action de tâche"},
    {
        "id": {"en": "Action ID", "fr": "ID de l'action"},
        "execMethod": {"en": "Method", "fr": "Méthode"},
        "execAction": {"en": "Action", "fr": "Action"},
        "execParameters": {"en": "Parameters", "fr": "Paramètres"},
        "execResultLabel": {"en": "Result Label", "fr": "Label du résultat"},
        "expectedDocumentFormats": {"en": "Expected Document Formats", "fr": "Formats de documents attendus"},
        "userMessage": {"en": "User Message", "fr": "Message utilisateur"},
        "status": {"en": "Status", "fr": "Statut"},
        "error": {"en": "Error", "fr": "Erreur"},
        "retryCount": {"en": "Retry Count", "fr": "Nombre de tentatives"},
        "retryMax": {"en": "Max Retries", "fr": "Tentatives max"},
        "processingTime": {"en": "Processing Time", "fr": "Temps de traitement"},
        "timestamp": {"en": "Timestamp", "fr": "Horodatage"},
        "result": {"en": "Result", "fr": "Résultat"},
        "resultDocuments": {"en": "Result Documents", "fr": "Documents de résultat"}
    }
)

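# Illustrative only (not part of the original file): a hedged sketch of how an action
# handler might drive the status helpers defined on TaskAction; the id, method, and
# action values are hypothetical.
#
#     action = TaskAction(id="a1", execMethod="extraction", execAction="extractText")
#     try:
#         ...  # execute the action
#         action.setSuccess()
#     except Exception as exc:
#         action.setError(str(exc))       # marks the action FAILED and stores the message
#     if action.hasError() and action.retryCount < action.retryMax:
#         action.retryCount += 1          # the caller decides whether to retry
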
class TaskResult(BaseModel, ModelMixin):
    """Model for task results"""
    taskId: str = Field(..., description="Task ID")
    status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
    success: bool = Field(..., description="Whether the task was successful")
    feedback: Optional[str] = Field(None, description="Task feedback message")
    error: Optional[str] = Field(None, description="Error message if task failed")

# Register labels for TaskResult
register_model_labels(
    "TaskResult",
    {"en": "Task Result", "fr": "Résultat de tâche"},
    {
        "taskId": {"en": "Task ID", "fr": "ID de la tâche"},
        "status": {"en": "Status", "fr": "Statut"},
        "success": {"en": "Success", "fr": "Succès"},
        "feedback": {"en": "Feedback", "fr": "Retour"},
        "error": {"en": "Error", "fr": "Erreur"}
    }
)

class TaskItem(BaseModel, ModelMixin):
    """Model for workflow tasks"""
    id: str = Field(..., description="Task ID")
    workflowId: str = Field(..., description="Workflow ID")
    userInput: str = Field(..., description="User input that triggered the task")
    status: TaskStatus = Field(default=TaskStatus.PENDING, description="Task status")
    error: Optional[str] = Field(None, description="Error message if task failed")
    startedAt: Optional[float] = Field(None, description="When the task started (UTC timestamp in seconds)")
    finishedAt: Optional[float] = Field(None, description="When the task finished (UTC timestamp in seconds)")
    actionList: List[TaskAction] = Field(default_factory=list, description="List of actions to execute")
    retryCount: int = Field(default=0, description="Number of retries attempted")
    retryMax: int = Field(default=3, description="Maximum number of retries")
    rollbackOnFailure: bool = Field(default=True, description="Whether to rollback on failure")
    dependencies: List[str] = Field(default_factory=list, description="List of task IDs this task depends on")
    feedback: Optional[str] = Field(None, description="Task feedback message")
    processingTime: Optional[float] = Field(None, description="Total processing time in seconds")
    resultLabels: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Map of result labels to their values")

    def isSuccessful(self) -> bool:
        """Check if task was successful"""
        return self.status == TaskStatus.COMPLETED

    def hasError(self) -> bool:
        """Check if task has an error"""
        return self.status == TaskStatus.FAILED

    def getErrorMessage(self) -> Optional[str]:
        """Get error message if any"""
        return self.error if self.hasError() else None

    def getResultDocuments(self) -> List[ChatDocument]:
        """Get all documents from all successful actions"""
        documents = []
        for action in self.actionList:
            if action.isSuccessful() and action.resultDocuments:
                documents.extend(action.resultDocuments)
        return documents

    def getResultDocumentLabel(self) -> Optional[str]:
        """Get the label for the result documents"""
        for action in self.actionList:
            if action.isSuccessful() and action.execResultLabel:
                return action.execResultLabel
        return None

    def getResultLabel(self, label: str) -> Optional[Any]:
        """Get value for a specific result label"""
        return self.resultLabels.get(label) if self.resultLabels else None

# Register labels for TaskItem
register_model_labels(
    "TaskItem",
    {"en": "Task", "fr": "Tâche"},
    {
        "id": {"en": "Task ID", "fr": "ID de la tâche"},
        "workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
        "userInput": {"en": "User Input", "fr": "Entrée utilisateur"},
        "status": {"en": "Status", "fr": "Statut"},
        "error": {"en": "Error", "fr": "Erreur"},
        "startedAt": {"en": "Started At", "fr": "Démarré à"},
        "finishedAt": {"en": "Finished At", "fr": "Terminé à"},
        "actionList": {"en": "Actions", "fr": "Actions"},
        "retryCount": {"en": "Retry Count", "fr": "Nombre de tentatives"},
        "retryMax": {"en": "Max Retries", "fr": "Tentatives max"},
        "rollbackOnFailure": {"en": "Rollback On Failure", "fr": "Annuler en cas d'échec"},
        "dependencies": {"en": "Dependencies", "fr": "Dépendances"},
        "feedback": {"en": "Feedback", "fr": "Retour"},
        "processingTime": {"en": "Processing Time", "fr": "Temps de traitement"}
    }
)

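# Illustrative only (not part of the original file): a hedged sketch of collecting the
# outputs of a finished task, assuming its actions were populated as sketched above.
#
#     if task.isSuccessful():
#         documents = task.getResultDocuments()        # ChatDocuments from all successful actions
#         label = task.getResultDocumentLabel()        # first execResultLabel found, if any
#     else:
#         logger.warning(f"Task {task.id} failed: {task.getErrorMessage()}")
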
class ChatStat(BaseModel, ModelMixin):
    """Data model for chat statistics - ONLY statistics, not workflow progress"""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    workflowId: Optional[str] = Field(None, description="Foreign key to workflow (for workflow stats)")
    messageId: Optional[str] = Field(None, description="Foreign key to message (for message stats)")
    processingTime: Optional[float] = Field(None, description="Processing time in seconds")
    tokenCount: Optional[int] = Field(None, description="Number of tokens processed")
    bytesSent: Optional[int] = Field(None, description="Number of bytes sent")
    bytesReceived: Optional[int] = Field(None, description="Number of bytes received")
    successRate: Optional[float] = Field(None, description="Success rate of operations")
    errorCount: Optional[int] = Field(None, description="Number of errors encountered")

# Register labels for ChatStat
register_model_labels(
    "ChatStat",
    {"en": "Chat Statistics", "fr": "Statistiques de chat"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "workflowId": {"en": "Workflow ID", "fr": "ID du workflow"},
        "messageId": {"en": "Message ID", "fr": "ID du message"},
        "processingTime": {"en": "Processing Time", "fr": "Temps de traitement"},
        "tokenCount": {"en": "Token Count", "fr": "Nombre de tokens"},
        "bytesSent": {"en": "Bytes Sent", "fr": "Octets envoyés"},
        "bytesReceived": {"en": "Bytes Received", "fr": "Octets reçus"},
        "successRate": {"en": "Success Rate", "fr": "Taux de succès"},
        "errorCount": {"en": "Error Count", "fr": "Nombre d'erreurs"}
    }
)

class ChatLog(BaseModel, ModelMixin):
    """Data model for chat logs"""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    workflowId: str = Field(description="Foreign key to workflow")
    message: str = Field(description="Log message")
    type: str = Field(description="Log type (info, warning, error, etc.)")
    timestamp: float = Field(default_factory=get_utc_timestamp, description="When the log entry was created (UTC timestamp in seconds)")
    status: Optional[str] = Field(None, description="Status of the log entry")
    progress: Optional[float] = Field(None, description="Progress indicator (0.0 to 1.0)")
    performance: Optional[Dict[str, Any]] = Field(None, description="Performance metrics")

# Register labels for ChatLog
register_model_labels(
    "ChatLog",
    {"en": "Chat Log", "fr": "Journal de chat"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "workflowId": {"en": "Workflow ID", "fr": "ID du flux de travail"},
        "message": {"en": "Message", "fr": "Message"},
        "type": {"en": "Type", "fr": "Type"},
        "timestamp": {"en": "Timestamp", "fr": "Horodatage"},
        "status": {"en": "Status", "fr": "Statut"},
        "progress": {"en": "Progress", "fr": "Progression"},
        "performance": {"en": "Performance", "fr": "Performance"}
    }
)

class ChatMessage(BaseModel, ModelMixin):
    """Data model for a chat message"""
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
    workflowId: str = Field(description="Foreign key to workflow")
    parentMessageId: Optional[str] = Field(None, description="Parent message ID for threading")
    documents: List[ChatDocument] = Field(default_factory=list, description="Associated documents")
    documentsLabel: Optional[str] = Field(None, description="Label for the set of documents")
    message: Optional[str] = Field(None, description="Message content")
    role: str = Field(description="Role of the message sender")
    status: str = Field(description="Status of the message (first, step, last)")
    sequenceNr: int = Field(description="Sequence number of the message (set automatically)")
    publishedAt: float = Field(default_factory=get_utc_timestamp, description="When the message was published (UTC timestamp in seconds)")
    stats: Optional[ChatStat] = Field(None, description="Statistics for this message")
    success: Optional[bool] = Field(None, description="Whether the message processing was successful")
    actionId: Optional[str] = Field(None, description="ID of the action that produced this message")
    actionMethod: Optional[str] = Field(None, description="Method of the action that produced this message")
    actionName: Optional[str] = Field(None, description="Name of the action that produced this message")

    # New workflow context fields:
    roundNumber: Optional[int] = Field(None, description="Round number in workflow")
    taskNumber: Optional[int] = Field(None, description="Task number within round")
    actionNumber: Optional[int] = Field(None, description="Action number within task")

    # New workflow progress fields:
    taskProgress: Optional[str] = Field(
        None,
        description="Task progress status: pending, running, success, fail, retry"
    )

    actionProgress: Optional[str] = Field(
        None,
        description="Action progress status: pending, running, success, fail"
    )

# Register labels for ChatMessage
register_model_labels(
    "ChatMessage",
    {"en": "Chat Message", "fr": "Message de chat"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "workflowId": {"en": "Workflow ID", "fr": "ID du flux de travail"},
        "parentMessageId": {"en": "Parent Message ID", "fr": "ID du message parent"},
        "documents": {"en": "Documents", "fr": "Documents"},
        "documentsLabel": {"en": "Documents Label", "fr": "Label des documents"},
        "message": {"en": "Message", "fr": "Message"},
        "role": {"en": "Role", "fr": "Rôle"},
        "status": {"en": "Status", "fr": "Statut"},
        "sequenceNr": {"en": "Sequence Number", "fr": "Numéro de séquence"},
        "publishedAt": {"en": "Published At", "fr": "Publié le"},
        "stats": {"en": "Statistics", "fr": "Statistiques"},
        "success": {"en": "Success", "fr": "Succès"},
        "actionId": {"en": "Action ID", "fr": "ID de l'action"},
        "actionMethod": {"en": "Action Method", "fr": "Méthode de l'action"},
        "actionName": {"en": "Action Name", "fr": "Nom de l'action"},
        "roundNumber": {"en": "Round Number", "fr": "Numéro de tour"},
        "taskNumber": {"en": "Task Number", "fr": "Numéro de tâche"},
        "actionNumber": {"en": "Action Number", "fr": "Numéro d'action"},
        "taskProgress": {"en": "Task Progress", "fr": "Progression de la tâche"},
        "actionProgress": {"en": "Action Progress", "fr": "Progression de l'action"}
    }
)

class ChatWorkflow(BaseModel, ModelMixin):
    """Data model for a chat workflow"""
    id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Primary key",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    mandateId: str = Field(
        description="ID of the mandate this workflow belongs to",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    status: str = Field(
        description="Current status of the workflow",
        frontend_type="select",
        frontend_readonly=False,
        frontend_required=False,
        frontend_options=[
            {"value": "running", "label": {"en": "Running", "fr": "En cours"}},
            {"value": "completed", "label": {"en": "Completed", "fr": "Terminé"}},
            {"value": "stopped", "label": {"en": "Stopped", "fr": "Arrêté"}},
            {"value": "error", "label": {"en": "Error", "fr": "Erreur"}}
        ]
    )
    name: Optional[str] = Field(
        None,
        description="Name of the workflow",
        frontend_type="text",
        frontend_readonly=False,
        frontend_required=True
    )
    currentRound: int = Field(
        description="Current round number",
        frontend_type="integer",
        frontend_readonly=True,
        frontend_required=False
    )
    currentTask: int = Field(
        default=0,
        description="Current task number",
        frontend_type="integer",
        frontend_readonly=True,
        frontend_required=False
    )
    currentAction: int = Field(
        default=0,
        description="Current action number",
        frontend_type="integer",
        frontend_readonly=True,
        frontend_required=False
    )
    totalTasks: int = Field(
        default=0,
        description="Total number of tasks in the workflow",
        frontend_type="integer",
        frontend_readonly=True,
        frontend_required=False
    )
    totalActions: int = Field(
        default=0,
        description="Total number of actions in the workflow",
        frontend_type="integer",
        frontend_readonly=True,
        frontend_required=False
    )
    lastActivity: float = Field(
        default_factory=get_utc_timestamp,
        description="Timestamp of last activity (UTC timestamp in seconds)",
        frontend_type="timestamp",
        frontend_readonly=True,
        frontend_required=False
    )
    startedAt: float = Field(
        default_factory=get_utc_timestamp,
        description="When the workflow started (UTC timestamp in seconds)",
        frontend_type="timestamp",
        frontend_readonly=True,
        frontend_required=False
    )
    logs: List[ChatLog] = Field(
        default_factory=list,
        description="Workflow logs",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    messages: List[ChatMessage] = Field(
        default_factory=list,
        description="Messages in the workflow",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    stats: Optional[ChatStat] = Field(
        None,
        description="Workflow statistics",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )
    tasks: List[TaskItem] = Field(
        default_factory=list,
        description="List of tasks in the workflow",
        frontend_type="text",
        frontend_readonly=True,
        frontend_required=False
    )

# Register labels for ChatWorkflow
register_model_labels(
    "ChatWorkflow",
    {"en": "Chat Workflow", "fr": "Flux de travail de chat"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
        "status": {"en": "Status", "fr": "Statut"},
        "name": {"en": "Name", "fr": "Nom"},
        "currentRound": {"en": "Current Round", "fr": "Tour actuel"},
        "currentTask": {"en": "Current Task", "fr": "Tâche actuelle"},
        "currentAction": {"en": "Current Action", "fr": "Action actuelle"},
        "totalTasks": {"en": "Total Tasks", "fr": "Total des tâches"},
        "totalActions": {"en": "Total Actions", "fr": "Total des actions"},
        "lastActivity": {"en": "Last Activity", "fr": "Dernière activité"},
        "startedAt": {"en": "Started At", "fr": "Démarré le"},
        "logs": {"en": "Logs", "fr": "Journaux"},
        "messages": {"en": "Messages", "fr": "Messages"},
        "stats": {"en": "Statistics", "fr": "Statistiques"},
        "tasks": {"en": "Tasks", "fr": "Tâches"}
    }
)

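# Illustrative only (not part of the original file): a hedged sketch of creating a
# workflow and appending a log entry as progress is made; the name and message values
# are hypothetical.
#
#     workflow = ChatWorkflow(mandateId=mandate.id, status="running", name="Quarterly summary", currentRound=1)
#     workflow.logs.append(ChatLog(workflowId=workflow.id, message="planning started", type="info", progress=0.1))
#     workflow.lastActivity = get_utc_timestamp()
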
# ====== WORKFLOW SUPPORT MODELS (for managerChat.py compatibility) ======

class TaskStep(BaseModel, ModelMixin):
    id: str
    objective: str
    dependencies: Optional[list[str]] = Field(default_factory=list)
    success_criteria: Optional[list[str]] = Field(default_factory=list)
    estimated_complexity: Optional[str] = None
    userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")

# Register labels for TaskStep
register_model_labels(
    "TaskStep",
    {"en": "Task Step", "fr": "Étape de tâche"},
    {
        "id": {"en": "ID", "fr": "ID"},
        "objective": {"en": "Objective", "fr": "Objectif"},
        "dependencies": {"en": "Dependencies", "fr": "Dépendances"},
        "success_criteria": {"en": "Success Criteria", "fr": "Critères de succès"},
        "estimated_complexity": {"en": "Estimated Complexity", "fr": "Complexité estimée"},
        "userMessage": {"en": "User Message", "fr": "Message utilisateur"}
    }
)

class TaskHandover(BaseModel, ModelMixin):
    """Structured handover between workflow phases and tasks"""
    taskId: str = Field(description="Target task ID")
    sourceTask: Optional[str] = Field(None, description="Source task ID")

    # Document handovers
    inputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Available input documents")
    outputDocuments: List[DocumentExchange] = Field(default_factory=list, description="Produced output documents")

    # Context and state
    context: Dict[str, Any] = Field(default_factory=dict, description="Task context")
    previousResults: List[str] = Field(default_factory=list, description="Previous result summaries")
    improvements: List[str] = Field(default_factory=list, description="Improvement suggestions")

    # Workflow context
    workflowSummary: Optional[str] = Field(None, description="Summarized workflow context")
    messageHistory: List[str] = Field(default_factory=list, description="Key message summaries")

    # Metadata
    timestamp: float = Field(default_factory=get_utc_timestamp, description="When the handover was created (UTC timestamp in seconds)")
    handoverType: str = Field(default="task", description="Type of handover: task, phase, or workflow")

    def addInputDocument(self, documentExchange: DocumentExchange) -> None:
        """Add an input document exchange"""
        self.inputDocuments.append(documentExchange)

    def addOutputDocument(self, documentExchange: DocumentExchange) -> None:
        """Add an output document exchange"""
        self.outputDocuments.append(documentExchange)

    def getDocumentsForAction(self, actionId: str) -> List[DocumentExchange]:
        """Get all document exchanges relevant for a specific action"""
        relevant = []
        for doc_exchange in self.inputDocuments + self.outputDocuments:
            if doc_exchange.isForAction(actionId):
                relevant.append(doc_exchange)
        return relevant

# Register labels for TaskHandover
register_model_labels(
    "TaskHandover",
    {"en": "Task Handover", "fr": "Transfert de tâche"},
    {
        "taskId": {"en": "Task ID", "fr": "ID de la tâche"},
        "sourceTask": {"en": "Source Task", "fr": "Tâche source"},
        "inputDocuments": {"en": "Input Documents", "fr": "Documents d'entrée"},
        "outputDocuments": {"en": "Output Documents", "fr": "Documents de sortie"},
        "context": {"en": "Context", "fr": "Contexte"},
        "previousResults": {"en": "Previous Results", "fr": "Résultats précédents"},
        "improvements": {"en": "Improvements", "fr": "Améliorations"},
        "workflowSummary": {"en": "Workflow Summary", "fr": "Résumé du workflow"},
        "messageHistory": {"en": "Message History", "fr": "Historique des messages"},
        "timestamp": {"en": "Timestamp", "fr": "Horodatage"},
        "handoverType": {"en": "Handover Type", "fr": "Type de transfert"}
    }
)

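# Illustrative only (not part of the original file): a hedged sketch of passing documents
# from one task to the next via a handover; nextTask, nextStep, and the label are hypothetical.
#
#     handover = TaskHandover(taskId=nextTask.id, sourceTask=task.id)
#     handover.addOutputDocument(
#         DocumentExchange(documentsLabel="extractedTables", documents=[doc.id for doc in task.getResultDocuments()])
#     )
#     nextContext = TaskContext(task_step=nextStep, previous_handover=handover)
#
# getDocumentsForAction() relies on an isForAction() helper on DocumentExchange that does
# not appear to be defined in this file.
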
class TaskContext(BaseModel, ModelMixin):
|
||||
task_step: TaskStep
|
||||
workflow: Optional['ChatWorkflow'] = None
|
||||
workflow_id: Optional[str] = None
|
||||
|
||||
# Available resources
|
||||
available_documents: Optional[str] = "No documents available"
|
||||
available_connections: Optional[list[str]] = Field(default_factory=list)
|
||||
|
||||
# Previous execution state
|
||||
previous_results: Optional[list[str]] = Field(default_factory=list)
|
||||
previous_handover: Optional[TaskHandover] = None
|
||||
|
||||
# Current execution state
|
||||
improvements: Optional[list[str]] = Field(default_factory=list)
|
||||
retry_count: Optional[int] = 0
|
||||
previous_action_results: Optional[list] = Field(default_factory=list)
|
||||
previous_review_result: Optional[dict] = None
|
||||
is_regeneration: Optional[bool] = False
|
||||
|
||||
# Failure analysis
|
||||
failure_patterns: Optional[list[str]] = Field(default_factory=list)
|
||||
failed_actions: Optional[list] = Field(default_factory=list)
|
||||
successful_actions: Optional[list] = Field(default_factory=list)
|
||||
|
||||
# Criteria progress tracking for retries
|
||||
criteria_progress: Optional[dict] = None
|
||||
|
||||
def getDocumentReferences(self) -> List[str]:
|
||||
"""Get all available document references from previous handover"""
|
||||
docs = []
|
||||
if self.previous_handover:
|
||||
for doc_exchange in self.previous_handover.inputDocuments:
|
||||
docs.extend(doc_exchange.documents)
|
||||
return list(set(docs)) # Remove duplicates
|
||||
|
||||
def addImprovement(self, improvement: str) -> None:
|
||||
"""Add an improvement suggestion"""
|
||||
if improvement not in (self.improvements or []):
|
||||
if self.improvements is None:
|
||||
self.improvements = []
|
||||
self.improvements.append(improvement)
|
||||
|
||||
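A short sketch of how a TaskContext might be rebuilt for a retry, using only the fields and helpers above (the TaskStep `step` and TaskHandover `prev` are assumed to exist):

ctx = TaskContext(task_step=step, previous_handover=prev, retry_count=1, is_regeneration=True)
ctx.addImprovement("Cite the source document for every claim")  # addImprovement deduplicates
doc_refs = ctx.getDocumentReferences()  # unique document references from the previous handover's inputs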
class ReviewContext(BaseModel, ModelMixin):
|
||||
task_step: TaskStep
|
||||
task_actions: Optional[list] = Field(default_factory=list)
|
||||
action_results: Optional[list] = Field(default_factory=list)
|
||||
step_result: Optional[dict] = Field(default_factory=dict)
|
||||
workflow_id: Optional[str] = None
|
||||
previous_results: Optional[list[str]] = Field(default_factory=list)
|
||||
|
||||
class ReviewResult(BaseModel, ModelMixin):
|
||||
status: str
|
||||
reason: Optional[str] = None
|
||||
improvements: Optional[list[str]] = Field(default_factory=list)
|
||||
quality_score: Optional[int] = 5
|
||||
missing_outputs: Optional[list[str]] = Field(default_factory=list)
|
||||
met_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
unmet_criteria: Optional[list[str]] = Field(default_factory=list)
|
||||
confidence: Optional[float] = 0.5
|
||||
userMessage: Optional[str] = Field(None, description="User-friendly message in user's language")
|
||||
|
||||
# Register labels for ReviewResult
|
||||
register_model_labels(
|
||||
"ReviewResult",
|
||||
{"en": "Review Result", "fr": "Résultat de l'évaluation"},
|
||||
{
|
||||
"status": {"en": "Status", "fr": "Statut"},
|
||||
"reason": {"en": "Reason", "fr": "Raison"},
|
||||
"improvements": {"en": "Improvements", "fr": "Améliorations"},
|
||||
"quality_score": {"en": "Quality Score", "fr": "Score de qualité"},
|
||||
"missing_outputs": {"en": "Missing Outputs", "fr": "Sorties manquantes"},
|
||||
"met_criteria": {"en": "Met Criteria", "fr": "Critères respectés"},
|
||||
"unmet_criteria": {"en": "Unmet Criteria", "fr": "Critères non respectés"},
|
||||
"confidence": {"en": "Confidence", "fr": "Confiance"},
|
||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"}
|
||||
}
|
||||
)
|
||||
|
||||
class TaskPlan(BaseModel, ModelMixin):
|
||||
overview: str
|
||||
tasks: list[TaskStep]
|
||||
userMessage: Optional[str] = Field(None, description="Overall user-friendly message for the task plan")
|
||||
|
||||
# Register labels for TaskPlan
|
||||
register_model_labels(
|
||||
"TaskPlan",
|
||||
{"en": "Task Plan", "fr": "Plan de tâches"},
|
||||
{
|
||||
"overview": {"en": "Overview", "fr": "Aperçu"},
|
||||
"tasks": {"en": "Tasks", "fr": "Tâches"},
|
||||
"userMessage": {"en": "User Message", "fr": "Message utilisateur"}
|
||||
}
|
||||
)
|
||||
|
||||
class WorkflowResult(BaseModel, ModelMixin):
|
||||
status: str
|
||||
completed_tasks: int
|
||||
total_tasks: int
|
||||
execution_time: float
|
||||
final_results_count: int
|
||||
error: Optional[str] = None
|
||||
phase: Optional[str] = None
|
||||
|
||||
# Register labels for WorkflowResult
|
||||
register_model_labels(
|
||||
"WorkflowResult",
|
||||
{"en": "Workflow Result", "fr": "Résultat du workflow"},
|
||||
{
|
||||
"status": {"en": "Status", "fr": "Statut"},
|
||||
"completed_tasks": {"en": "Completed Tasks", "fr": "Tâches terminées"},
|
||||
"total_tasks": {"en": "Total Tasks", "fr": "Total des tâches"},
|
||||
"execution_time": {"en": "Execution Time", "fr": "Temps d'exécution"},
|
||||
"final_results_count": {"en": "Final Results Count", "fr": "Nombre de résultats finaux"},
|
||||
"error": {"en": "Error", "fr": "Erreur"},
|
||||
"phase": {"en": "Phase", "fr": "Phase"}
|
||||
}
|
||||
)
|
||||
|
||||
# ===== Centralized AI Call Response Models =====
|
||||
|
||||
class AiResult(BaseModel, ModelMixin):
|
||||
"""Document result from centralized AI call"""
|
||||
filename: str = Field(description="Name of the result document")
|
||||
mimetype: str = Field(description="MIME type of the result document")
|
||||
content: str = Field(description="Content of the result document")
|
||||
|
||||
# Register labels for AiResult
|
||||
register_model_labels(
|
||||
"AiResult",
|
||||
{"en": "Result Document", "fr": "Document de résultat"},
|
||||
{
|
||||
"filename": {"en": "Filename", "fr": "Nom de fichier"},
|
||||
"mimetype": {"en": "MIME Type", "fr": "Type MIME"},
|
||||
"content": {"en": "Content", "fr": "Contenu"}
|
||||
}
|
||||
)
|
||||
|
||||
class CentralizedAiResponse(BaseModel, ModelMixin):
|
||||
"""Standardized response format from centralized AI calls"""
|
||||
aiResults: List[AiResult] = Field(default_factory=list, description="List of result documents")
|
||||
success: bool = Field(description="Whether the AI call was successful")
|
||||
error: Optional[str] = Field(None, description="Error message if the call failed")
|
||||
|
||||
# Register labels for CentralizedAiResponse
|
||||
register_model_labels(
|
||||
"CentralizedAiResponse",
|
||||
{"en": "Centralized AI Response", "fr": "Réponse IA centralisée"},
|
||||
{
|
||||
"aiResults": {"en": "Result Documents", "fr": "Documents de résultat"},
|
||||
"success": {"en": "Success", "fr": "Succès"},
|
||||
"error": {"en": "Error", "fr": "Erreur"}
|
||||
}
|
||||
)
|
||||
|
||||
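A hedged sketch of a successful centralized AI response built from these two models (all field values are illustrative):

response = CentralizedAiResponse(
    success=True,
    aiResults=[AiResult(filename="summary.md", mimetype="text/markdown", content="# Summary ...")],
)
if response.success:
    for doc in response.aiResults:
        print(doc.filename, doc.mimetype)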
|
||||
File diff suppressed because it is too large
|
|
@ -1,264 +0,0 @@
|
|||
"""
|
||||
Service Management model classes for the service management system.
|
||||
Updated to match the Entity Relation Diagram structure.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
# Import for label registration
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
|
||||
# CORE MODELS
|
||||
|
||||
class FileItem(BaseModel, ModelMixin):
|
||||
"""Data model for a file item"""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
mandateId: str = Field(
|
||||
description="ID of the mandate this file belongs to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
fileName: str = Field(
|
||||
description="Name of the file",
|
||||
frontend_type="text",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True
|
||||
)
|
||||
mimeType: str = Field(
|
||||
description="MIME type of the file",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
fileHash: str = Field(
|
||||
description="Hash of the file",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
fileSize: int = Field(
|
||||
description="Size of the file in bytes",
|
||||
frontend_type="integer",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
creationDate: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Date when the file was created (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert model to dictionary"""
|
||||
return super().to_dict()
|
||||
|
||||
# Register labels for FileItem
|
||||
register_model_labels(
|
||||
"FileItem",
|
||||
{"en": "File Item", "fr": "Élément de fichier"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"fileName": {"en": "fileName", "fr": "Nom de fichier"},
|
||||
"mimeType": {"en": "MIME Type", "fr": "Type MIME"},
|
||||
"fileHash": {"en": "File Hash", "fr": "Hash du fichier"},
|
||||
"fileSize": {"en": "File Size", "fr": "Taille du fichier"},
|
||||
"creationDate": {"en": "Creation Date", "fr": "Date de création"}
|
||||
}
|
||||
)
|
||||
|
||||
class FilePreview(BaseModel, ModelMixin):
|
||||
"""Data model for file preview"""
|
||||
content: Union[str, bytes] = Field(description="File content (text or binary)")
|
||||
mimeType: str = Field(description="MIME type of the file")
|
||||
fileName: str = Field(description="Original fileName")
|
||||
isText: bool = Field(description="Whether the content is text (True) or binary (False)")
|
||||
encoding: Optional[str] = Field(None, description="Text encoding if content is text")
|
||||
size: int = Field(description="Size of the content in bytes")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert model to dictionary with proper content handling"""
|
||||
data = super().to_dict()
|
||||
# Convert bytes to base64 string if content is binary
|
||||
if isinstance(data.get("content"), bytes):
|
||||
import base64
|
||||
data["content"] = base64.b64encode(data["content"]).decode('utf-8')
|
||||
return data
|
||||
|
||||
# Register labels for FilePreview
|
||||
register_model_labels(
|
||||
"FilePreview",
|
||||
{"en": "File Preview", "fr": "Aperçu du fichier"},
|
||||
{
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
"mimeType": {"en": "MIME Type", "fr": "Type MIME"},
|
||||
"fileName": {"en": "fileName", "fr": "Nom de fichier"},
|
||||
"isText": {"en": "Is Text", "fr": "Est du texte"},
|
||||
"encoding": {"en": "Encoding", "fr": "Encodage"},
|
||||
"size": {"en": "Size", "fr": "Taille"}
|
||||
}
|
||||
)
|
||||
|
||||
class FileData(BaseModel, ModelMixin):
|
||||
"""Data model for file data"""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
data: str = Field(description="File data content")
|
||||
base64Encoded: bool = Field(description="Whether the data is base64 encoded")
|
||||
|
||||
# Register labels for FileData
|
||||
register_model_labels(
|
||||
"FileData",
|
||||
{"en": "File Data", "fr": "Données de fichier"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"data": {"en": "Data", "fr": "Données"},
|
||||
"base64Encoded": {"en": "Base64 Encoded", "fr": "Encodé en Base64"}
|
||||
}
|
||||
)
|
||||
|
||||
class Prompt(BaseModel, ModelMixin):
|
||||
"""Data model for a prompt"""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
mandateId: str = Field(
|
||||
description="ID of the mandate this prompt belongs to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
content: str = Field(
|
||||
description="Content of the prompt",
|
||||
frontend_type="textarea",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True
|
||||
)
|
||||
name: str = Field(
|
||||
description="Name of the prompt",
|
||||
frontend_type="text",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True
|
||||
)
|
||||
|
||||
# Register labels for Prompt
|
||||
register_model_labels(
|
||||
"Prompt",
|
||||
{"en": "Prompt", "fr": "Invite"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"content": {"en": "Content", "fr": "Contenu"},
|
||||
"name": {"en": "Name", "fr": "Nom"}
|
||||
}
|
||||
)
|
||||
|
||||
class VoiceSettings(BaseModel, ModelMixin):
|
||||
"""Data model for voice service settings per user"""
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
userId: str = Field(
|
||||
description="ID of the user these settings belong to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True
|
||||
)
|
||||
mandateId: str = Field(
|
||||
description="ID of the mandate these settings belong to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
sttLanguage: str = Field(
|
||||
default="de-DE",
|
||||
description="Speech-to-Text language",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True
|
||||
)
|
||||
ttsLanguage: str = Field(
|
||||
default="de-DE",
|
||||
description="Text-to-Speech language",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True
|
||||
)
|
||||
ttsVoice: str = Field(
|
||||
default="de-DE-KatjaNeural",
|
||||
description="Text-to-Speech voice",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=True
|
||||
)
|
||||
translationEnabled: bool = Field(
|
||||
default=True,
|
||||
description="Whether translation is enabled",
|
||||
frontend_type="checkbox",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False
|
||||
)
|
||||
targetLanguage: str = Field(
|
||||
default="en-US",
|
||||
description="Target language for translation",
|
||||
frontend_type="select",
|
||||
frontend_readonly=False,
|
||||
frontend_required=False
|
||||
)
|
||||
creationDate: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Date when the settings were created (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
lastModified: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Date when the settings were last modified (UTC timestamp in seconds)",
|
||||
frontend_type="timestamp",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert model to dictionary"""
|
||||
return super().to_dict()
|
||||
|
||||
# Register labels for VoiceSettings
|
||||
register_model_labels(
|
||||
"VoiceSettings",
|
||||
{"en": "Voice Settings", "fr": "Paramètres vocaux"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"userId": {"en": "User ID", "fr": "ID utilisateur"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"sttLanguage": {"en": "STT Language", "fr": "Langue STT"},
|
||||
"ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"},
|
||||
"ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"},
|
||||
"translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"},
|
||||
"targetLanguage": {"en": "Target Language", "fr": "Langue cible"},
|
||||
"creationDate": {"en": "Creation Date", "fr": "Date de création"},
|
||||
"lastModified": {"en": "Last Modified", "fr": "Dernière modification"}
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -4,9 +4,8 @@ Access control for the Application.
|
|||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from modules.interfaces.interfaceAppModel import UserPrivilege, User, UserInDB, AuthEvent, Mandate
|
||||
from modules.shared.timezoneUtils import get_utc_now
|
||||
from modules.datamodels.datamodelUam import UserPrivilege, User, UserInDB, Mandate
|
||||
from modules.datamodels.datamodelSecurity import AuthEvent
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
File diff suppressed because it is too large
|
|
@ -4,8 +4,8 @@ Handles user access management and permission checks.
|
|||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.interfaces.interfaceAppModel import User, UserPrivilege
|
||||
from modules.interfaces.interfaceChatModel import ChatWorkflow, ChatMessage, ChatLog, ChatStat, ChatDocument
|
||||
from modules.datamodels.datamodelUam import User, UserPrivilege
|
||||
from modules.datamodels.datamodelChat import ChatWorkflow, AutomationDefinition
|
||||
|
||||
class ChatAccess:
|
||||
"""
|
||||
|
|
@ -41,7 +41,13 @@ class ChatAccess:
|
|||
filtered_records = []
|
||||
|
||||
# Apply filtering based on privilege
|
||||
if userPrivilege == UserPrivilege.SYSADMIN:
|
||||
if table_name == "AutomationDefinition":
|
||||
# Users see only their own automation definitions
|
||||
filtered_records = [
|
||||
r for r in recordset
|
||||
if r.get("mandateId","-") == self.mandateId and r.get("_createdBy") == self.userId
|
||||
]
|
||||
elif userPrivilege == UserPrivilege.SYSADMIN:
|
||||
filtered_records = recordset # System admins see all records
|
||||
elif userPrivilege == UserPrivilege.ADMIN:
|
||||
# Admins see records in their mandate
|
||||
|
|
@ -68,6 +74,10 @@ class ChatAccess:
|
|||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
|
||||
elif table_name == "AutomationDefinition":
|
||||
record["_hideView"] = False # Everyone can view
|
||||
record["_hideEdit"] = not self.canModify(AutomationDefinition, record_id)
|
||||
record["_hideDelete"] = not self.canModify(AutomationDefinition, record_id)
|
||||
else:
|
||||
# Default access control for other tables
|
||||
record["_hideView"] = False
|
||||
1765
modules/interfaces/interfaceDbChatObjects.py
Normal file
File diff suppressed because it is too large
|
|
@ -5,9 +5,10 @@ Handles user access management and permission checks.
|
|||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.interfaces.interfaceAppModel import User, UserInDB
|
||||
from modules.interfaces.interfaceComponentModel import Prompt, FileItem, FileData, VoiceSettings
|
||||
from modules.interfaces.interfaceChatModel import ChatWorkflow, ChatMessage, ChatLog
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelUtils import Prompt
|
||||
from modules.datamodels.datamodelFiles import FileItem
|
||||
from modules.datamodels.datamodelChat import ChatWorkflow
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -5,23 +5,21 @@ Uses the JSON connector for data access with added language support.
|
|||
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime, UTC
|
||||
import base64
|
||||
import hashlib
|
||||
import math
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
import hashlib
|
||||
|
||||
from modules.interfaces.interfaceComponentAccess import ComponentAccess
|
||||
from modules.interfaces.interfaceComponentModel import (
|
||||
FilePreview, Prompt, FileItem, FileData, VoiceSettings
|
||||
)
|
||||
from modules.interfaces.interfaceAppModel import User, Mandate
|
||||
|
||||
# DYNAMIC PART: Connectors to the Interface
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
|
||||
# Basic Configurations
|
||||
from modules.interfaces.interfaceDbComponentAccess import ComponentAccess
|
||||
from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData
|
||||
from modules.datamodels.datamodelUtils import Prompt
|
||||
from modules.datamodels.datamodelVoice import VoiceSettings
|
||||
from modules.datamodels.datamodelUam import User, Mandate
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Singleton factory for Management instances with AI service per context
|
||||
|
|
@ -150,7 +148,7 @@ class ComponentObjects:
|
|||
return
|
||||
|
||||
# Get the root interface to access the initial mandate ID
|
||||
from modules.interfaces.interfaceAppObjects import getRootInterface
|
||||
from modules.interfaces.interfaceDbAppObjects import getRootInterface
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# Get initial mandate ID through the root interface
|
||||
|
|
@ -248,6 +246,56 @@ class ComponentObjects:
|
|||
"""Delegate to access control module."""
|
||||
return self.access.canModify(model_class, recordId)
|
||||
|
||||
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Apply filter criteria to records (implementation for future filtering)."""
|
||||
# TODO: Implement filtering logic when needed
|
||||
return records
|
||||
|
||||
def _applySorting(self, records: List[Dict[str, Any]], sortFields: List[Any]) -> List[Dict[str, Any]]:
|
||||
"""Apply multi-level sorting to records using stable sort (sorts from least to most significant field)."""
|
||||
if not sortFields:
|
||||
return records
|
||||
|
||||
# Start with a copy to avoid modifying original
|
||||
sortedRecords = list(records)
|
||||
|
||||
# Sort from least significant to most significant field (reverse order)
|
||||
# Python's sort is stable, so this creates proper multi-level sorting
|
||||
for sortField in reversed(sortFields):
|
||||
# Handle both dict and object formats
|
||||
if isinstance(sortField, dict):
|
||||
fieldName = sortField.get("field")
|
||||
direction = sortField.get("direction", "asc")
|
||||
else:
|
||||
fieldName = getattr(sortField, "field", None)
|
||||
direction = getattr(sortField, "direction", "asc")
|
||||
|
||||
if not fieldName:
|
||||
continue
|
||||
|
||||
isDesc = (direction == "desc")
|
||||
|
||||
def sortKey(record):
    value = record.get(fieldName)
    # Handle None values - place them at the end regardless of sort direction
    if value is None:
        # Rank None so it sorts last ascending, and still lands last after reverse=True is applied for descending
        return ((-1 if isDesc else 1), "")
    # Wrap real values with a type flag of 0 so they never compare directly against the None sentinel
    if isinstance(value, (int, float)):  # also covers bool, which is a subclass of int
        return (0, value)
    elif isinstance(value, str):
        return (0, value)
    else:
        return (0, str(value))
|
||||
|
||||
# Sort with reverse parameter for descending
|
||||
sortedRecords.sort(key=sortKey, reverse=isDesc)
|
||||
|
||||
return sortedRecords
|
||||
|
||||
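A quick sanity check of the multi-level sort above, assuming an existing ComponentObjects instance `interface` and dict-style sort fields (the data is illustrative):

records = [{"name": "b", "size": 2}, {"name": "a", "size": 2}, {"name": "c", "size": 1}]
sortFields = [{"field": "size", "direction": "asc"}, {"field": "name", "direction": "desc"}]
ordered = interface._applySorting(records, sortFields)
# size ascending first, then name descending within equal sizes -> c, b, a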
# Utilities
|
||||
|
||||
|
|
@ -259,18 +307,57 @@ class ComponentObjects:
|
|||
|
||||
# Prompt methods
|
||||
|
||||
def getAllPrompts(self) -> List[Prompt]:
|
||||
"""Returns prompts based on user access level."""
|
||||
def getAllPrompts(self, pagination: Optional[PaginationParams] = None) -> Union[List[Prompt], PaginatedResult]:
|
||||
"""
|
||||
Returns prompts based on user access level.
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
|
||||
Args:
|
||||
pagination: Optional pagination parameters. If None, returns all items.
|
||||
|
||||
Returns:
|
||||
If pagination is None: List[Prompt]
|
||||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
try:
|
||||
allPrompts = self.db.getRecordset(Prompt)
|
||||
filteredPrompts = self._uam(Prompt, allPrompts)
|
||||
|
||||
# Convert to Prompt objects
|
||||
return [Prompt.from_dict(prompt) for prompt in filteredPrompts]
|
||||
# If no pagination requested, return all items
|
||||
if pagination is None:
|
||||
return [Prompt(**prompt) for prompt in filteredPrompts]
|
||||
|
||||
# Apply filtering (if filters provided)
|
||||
if pagination.filters:
|
||||
filteredPrompts = self._applyFilters(filteredPrompts, pagination.filters)
|
||||
|
||||
# Apply sorting (in order of sortFields)
|
||||
if pagination.sort:
|
||||
filteredPrompts = self._applySorting(filteredPrompts, pagination.sort)
|
||||
|
||||
# Count total items after filters
|
||||
totalItems = len(filteredPrompts)
|
||||
totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0
|
||||
|
||||
# Apply pagination (skip/limit)
|
||||
startIdx = (pagination.page - 1) * pagination.pageSize
|
||||
endIdx = startIdx + pagination.pageSize
|
||||
pagedPrompts = filteredPrompts[startIdx:endIdx]
|
||||
|
||||
# Convert to model objects
|
||||
items = [Prompt(**prompt) for prompt in pagedPrompts]
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prompts: {str(e)}")
|
||||
return []
|
||||
if pagination is None:
|
||||
return []
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
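A minimal sketch of the paginated call, assuming PaginationParams exposes the page, pageSize, sort and filters attributes used above (the `componentInterface` instance is illustrative):

params = PaginationParams(page=1, pageSize=25, sort=[{"field": "name", "direction": "asc"}])
result = componentInterface.getAllPrompts(pagination=params)
print(result.totalItems, result.totalPages, len(result.items))  # metadata plus one page of Prompt objects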
def getPrompt(self, promptId: str) -> Optional[Prompt]:
|
||||
"""Returns a prompt by ID if user has access."""
|
||||
|
|
@ -279,7 +366,7 @@ class ComponentObjects:
|
|||
return None
|
||||
|
||||
filteredPrompts = self._uam(Prompt, prompts)
|
||||
return Prompt.from_dict(filteredPrompts[0]) if filteredPrompts else None
|
||||
return Prompt(**filteredPrompts[0]) if filteredPrompts else None
|
||||
|
||||
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Creates a new prompt if user has permission."""
|
||||
|
|
@ -311,7 +398,7 @@ class ComponentObjects:
|
|||
if not updatedPrompt:
|
||||
raise ValueError("Failed to retrieve updated prompt")
|
||||
|
||||
return updatedPrompt.to_dict()
|
||||
return updatedPrompt.model_dump()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating prompt: {str(e)}")
|
||||
|
|
@ -388,7 +475,6 @@ class ComponentObjects:
|
|||
|
||||
def getMimeType(self, fileName: str) -> str:
|
||||
"""Determines the MIME type based on the file extension."""
|
||||
import os
|
||||
ext = os.path.splitext(fileName)[1].lower()[1:]
|
||||
extensionToMime = {
|
||||
"pdf": "application/pdf",
|
||||
|
|
@ -459,39 +545,79 @@ class ComponentObjects:
|
|||
|
||||
# File methods - metadata-based operations
|
||||
|
||||
def getAllFiles(self) -> List[FileItem]:
|
||||
"""Returns files based on user access level."""
|
||||
def getAllFiles(self, pagination: Optional[PaginationParams] = None) -> Union[List[FileItem], PaginatedResult]:
|
||||
"""
|
||||
Returns files based on user access level.
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
|
||||
Args:
|
||||
pagination: Optional pagination parameters. If None, returns all items.
|
||||
|
||||
Returns:
|
||||
If pagination is None: List[FileItem]
|
||||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
allFiles = self.db.getRecordset(FileItem)
|
||||
filteredFiles = self._uam(FileItem, allFiles)
|
||||
|
||||
# Convert database records to FileItem instances
|
||||
fileItems = []
|
||||
for file in filteredFiles:
|
||||
try:
|
||||
# Ensure proper values, use defaults for invalid data
|
||||
creationDate = file.get("creationDate")
|
||||
if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0:
|
||||
creationDate = get_utc_timestamp()
|
||||
|
||||
fileName = file.get("fileName")
|
||||
if not fileName or fileName == "None":
|
||||
continue # Skip records with invalid fileName
|
||||
|
||||
fileItem = FileItem(
|
||||
id=file.get("id"),
|
||||
mandateId=file.get("mandateId"),
|
||||
fileName=fileName,
|
||||
mimeType=file.get("mimeType"),
|
||||
fileHash=file.get("fileHash"),
|
||||
fileSize=file.get("fileSize"),
|
||||
creationDate=creationDate
|
||||
)
|
||||
fileItems.append(fileItem)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping invalid file record: {str(e)}")
|
||||
continue
|
||||
|
||||
return fileItems
|
||||
# Convert database records to FileItem instances (for both paginated and non-paginated)
|
||||
def convertFileItems(files):
|
||||
fileItems = []
|
||||
for file in files:
|
||||
try:
|
||||
# Ensure proper values, use defaults for invalid data
|
||||
creationDate = file.get("creationDate")
|
||||
if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0:
|
||||
creationDate = getUtcTimestamp()
|
||||
|
||||
fileName = file.get("fileName")
|
||||
if not fileName or fileName == "None":
|
||||
continue # Skip records with invalid fileName
|
||||
|
||||
fileItem = FileItem(
|
||||
id=file.get("id"),
|
||||
mandateId=file.get("mandateId"),
|
||||
fileName=fileName,
|
||||
mimeType=file.get("mimeType"),
|
||||
fileHash=file.get("fileHash"),
|
||||
fileSize=file.get("fileSize"),
|
||||
creationDate=creationDate
|
||||
)
|
||||
fileItems.append(fileItem)
|
||||
except Exception as e:
|
||||
logger.warning(f"Skipping invalid file record: {str(e)}")
|
||||
continue
|
||||
return fileItems
|
||||
|
||||
# If no pagination requested, return all items
|
||||
if pagination is None:
|
||||
return convertFileItems(filteredFiles)
|
||||
|
||||
# Apply filtering (if filters provided)
|
||||
if pagination.filters:
|
||||
filteredFiles = self._applyFilters(filteredFiles, pagination.filters)
|
||||
|
||||
# Apply sorting (in order of sortFields)
|
||||
if pagination.sort:
|
||||
filteredFiles = self._applySorting(filteredFiles, pagination.sort)
|
||||
|
||||
# Count total items after filters
|
||||
totalItems = len(filteredFiles)
|
||||
totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0
|
||||
|
||||
# Apply pagination (skip/limit)
|
||||
startIdx = (pagination.page - 1) * pagination.pageSize
|
||||
endIdx = startIdx + pagination.pageSize
|
||||
pagedFiles = filteredFiles[startIdx:endIdx]
|
||||
|
||||
# Convert to model objects
|
||||
items = convertFileItems(pagedFiles)
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
)
|
||||
|
||||
def getFile(self, fileId: str) -> Optional[FileItem]:
|
||||
"""Returns a file by ID if user has access."""
|
||||
|
|
@ -508,7 +634,7 @@ class ComponentObjects:
|
|||
# Get creation date from record or use current time
|
||||
creationDate = file.get("creationDate")
|
||||
if not creationDate:
|
||||
creationDate = get_utc_timestamp()
|
||||
creationDate = getUtcTimestamp()
|
||||
|
||||
return FileItem(
|
||||
id=file.get("id"),
|
||||
|
|
@ -558,7 +684,6 @@ class ComponentObjects:
|
|||
|
||||
def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem:
|
||||
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content."""
|
||||
import hashlib
|
||||
if not self._canModify(FileItem):
|
||||
raise PermissionError("No permission to create files")
|
||||
|
||||
|
|
@ -655,8 +780,6 @@ class ComponentObjects:
|
|||
def createFileData(self, fileId: str, data: bytes) -> bool:
|
||||
"""Stores the binary data of a file in the database."""
|
||||
try:
|
||||
import base64
|
||||
|
||||
# Check file access
|
||||
file = self.getFile(fileId)
|
||||
if not file:
|
||||
|
|
@ -715,8 +838,6 @@ class ComponentObjects:
|
|||
logger.warning(f"No access to file ID {fileId}")
|
||||
return None
|
||||
|
||||
import base64
|
||||
|
||||
fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
|
||||
if not fileDataEntries:
|
||||
logger.warning(f"No data found for file ID {fileId}")
|
||||
|
|
@ -791,12 +912,10 @@ class ComponentObjects:
|
|||
encoding = 'latin-1'
|
||||
elif file.mimeType.startswith("image/"):
|
||||
# For images, return base64
|
||||
import base64
|
||||
content = base64.b64encode(fileContent).decode('utf-8')
|
||||
isText = False
|
||||
else:
|
||||
# For other files, return as base64
|
||||
import base64
|
||||
content = base64.b64encode(fileContent).decode('utf-8')
|
||||
isText = False
|
||||
|
||||
|
|
@ -827,7 +946,6 @@ class ComponentObjects:
|
|||
raise ValueError(f"fileContent must be bytes, got {type(fileContent)}")
|
||||
|
||||
# Compute file hash first to check for duplicates
|
||||
import hashlib
|
||||
fileHash = hashlib.sha256(fileContent).hexdigest()
|
||||
|
||||
# Check for exact name+hash match first (same name + same content)
|
||||
|
|
@ -894,13 +1012,11 @@ class ComponentObjects:
|
|||
# Ensure timestamps are set for validation
|
||||
settings_data = filteredSettings[0]
|
||||
if not settings_data.get("creationDate"):
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
settings_data["creationDate"] = get_utc_timestamp()
|
||||
settings_data["creationDate"] = getUtcTimestamp()
|
||||
if not settings_data.get("lastModified"):
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
settings_data["lastModified"] = get_utc_timestamp()
|
||||
settings_data["lastModified"] = getUtcTimestamp()
|
||||
|
||||
return VoiceSettings.from_dict(settings_data)
|
||||
return VoiceSettings(**settings_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting voice settings: {str(e)}")
|
||||
|
|
@ -946,8 +1062,7 @@ class ComponentObjects:
|
|||
raise ValueError(f"Voice settings not found for user {userId}")
|
||||
|
||||
# Update lastModified timestamp
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
updateData["lastModified"] = get_utc_timestamp()
|
||||
updateData["lastModified"] = getUtcTimestamp()
|
||||
|
||||
# Update voice settings record
|
||||
success = self.db.recordModify(VoiceSettings, existingSettings.id, updateData)
|
||||
|
|
@ -960,7 +1075,7 @@ class ComponentObjects:
|
|||
raise ValueError("Failed to retrieve updated voice settings")
|
||||
|
||||
logger.info(f"Updated voice settings for user {userId}")
|
||||
return updatedSettings.to_dict()
|
||||
return updatedSettings.model_dump()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating voice settings: {str(e)}")
|
||||
|
|
@ -1012,7 +1127,7 @@ class ComponentObjects:
|
|||
}
|
||||
|
||||
createdRecord = self.createVoiceSettings(defaultSettings)
|
||||
return VoiceSettings.from_dict(createdRecord)
|
||||
return VoiceSettings(**createdRecord)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting or creating voice settings: {str(e)}")
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
"""Base class for ticket classes."""
|
||||
|
||||
from typing import Any, Dict
|
||||
from pydantic import BaseModel, Field
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class TicketFieldAttribute(BaseModel):
|
||||
field_name: str = Field(description="Human-readable field name")
|
||||
field: str = Field(description="JIRA field ID/key")
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
# A very flexible approach for now. Might want to be more strict in the future.
|
||||
data: Dict[str, Any] = Field(default_factory=dict, description="Task data")
|
||||
|
||||
|
||||
class TicketBase(ABC):
|
||||
@abstractmethod
|
||||
async def read_attributes(self) -> list[TicketFieldAttribute]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def read_tasks(self, limit: int = 0) -> list[Task]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def write_tasks(self, tasklist: list[Task]) -> None: ...
|
||||
File diff suppressed because it is too large
504
modules/interfaces/interfaceVoiceObjects.py
Normal file
|
|
@ -0,0 +1,504 @@
|
|||
"""
|
||||
Interface for Voice Services
|
||||
Provides a generic interface layer between routes and voice connectors.
|
||||
Handles voice operations including speech-to-text, text-to-speech, and translation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from modules.connectors.connectorVoiceGoogle import ConnectorGoogleSpeech
|
||||
from modules.datamodels.datamodelVoice import VoiceSettings
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Singleton factory for Voice instances
|
||||
_instancesVoice = {}
|
||||
|
||||
class VoiceObjects:
|
||||
"""
|
||||
Interface for Voice Services.
|
||||
Provides a generic interface layer between routes and voice connectors.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Voice Interface."""
|
||||
self.currentUser: Optional[User] = None
|
||||
self.userId: Optional[str] = None
|
||||
self._google_speech_connector: Optional[ConnectorGoogleSpeech] = None
|
||||
|
||||
def setUserContext(self, currentUser: User):
|
||||
"""Set the user context for the interface."""
|
||||
if not currentUser:
|
||||
logger.info("Initializing voice interface without user context")
|
||||
return
|
||||
|
||||
self.currentUser = currentUser
|
||||
self.userId = currentUser.id
|
||||
|
||||
if not self.userId:
|
||||
raise ValueError("Invalid user context: id is required")
|
||||
|
||||
logger.debug(f"Voice interface user context set: userId={self.userId}")
|
||||
|
||||
def _getGoogleSpeechConnector(self) -> ConnectorGoogleSpeech:
|
||||
"""Get or create Google Cloud Speech connector instance."""
|
||||
if self._google_speech_connector is None:
|
||||
try:
|
||||
self._google_speech_connector = ConnectorGoogleSpeech()
|
||||
logger.info("✅ Google Cloud Speech connector initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to initialize Google Cloud Speech connector: {e}")
|
||||
raise
|
||||
return self._google_speech_connector
|
||||
|
||||
# Speech-to-Text Operations
|
||||
|
||||
async def speechToText(self, audioContent: bytes, language: str = "de-DE",
|
||||
sampleRate: int = None, channels: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert speech to text using Google Cloud Speech-to-Text API.
|
||||
|
||||
Args:
|
||||
audioContent: Raw audio data
|
||||
language: Language code (e.g., 'de-DE', 'en-US')
|
||||
sampleRate: Audio sample rate (auto-detected if None)
|
||||
channels: Number of audio channels (auto-detected if None)
|
||||
|
||||
Returns:
|
||||
Dict containing transcribed text, confidence, and metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🎤 Speech-to-text request: {len(audioContent)} bytes, language: {language}")
|
||||
|
||||
connector = self._getGoogleSpeechConnector()
|
||||
result = await connector.speechToText(
|
||||
audioContent=audioContent,
|
||||
language=language,
|
||||
sampleRate=sampleRate,
|
||||
channels=channels
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
logger.info(f"✅ Speech-to-text successful: '{result['text']}' (confidence: {result['confidence']:.2f})")
|
||||
else:
|
||||
logger.warning(f"⚠️ Speech-to-text failed: {result.get('error', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Speech-to-text error: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"text": "",
|
||||
"confidence": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Translation Operations
|
||||
|
||||
async def translateText(self, text: str, sourceLanguage: str = "de",
|
||||
targetLanguage: str = "en") -> Dict[str, Any]:
|
||||
"""
|
||||
Translate text using Google Cloud Translation API.
|
||||
|
||||
Args:
|
||||
text: Text to translate
|
||||
sourceLanguage: Source language code (e.g., 'de', 'en')
|
||||
targetLanguage: Target language code (e.g., 'en', 'de')
|
||||
|
||||
Returns:
|
||||
Dict containing translated text and metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🌐 Translation request: '{text}' ({sourceLanguage} -> {targetLanguage})")
|
||||
|
||||
if not text.strip():
|
||||
return {
|
||||
"success": False,
|
||||
"translated_text": "",
|
||||
"error": "Empty text provided"
|
||||
}
|
||||
|
||||
connector = self._getGoogleSpeechConnector()
|
||||
result = await connector.translateText(
|
||||
text=text,
|
||||
sourceLanguage=sourceLanguage,
|
||||
targetLanguage=targetLanguage
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
logger.info(f"✅ Translation successful: '{result['translated_text']}'")
|
||||
else:
|
||||
logger.warning(f"⚠️ Translation failed: {result.get('error', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Translation error: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"translated_text": "",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Combined Operations
|
||||
|
||||
async def speechToTranslatedText(self, audioContent: bytes,
|
||||
fromLanguage: str = "de-DE",
|
||||
toLanguage: str = "en") -> Dict[str, Any]:
|
||||
"""
|
||||
Complete pipeline: Speech-to-Text + Translation.
|
||||
|
||||
Args:
|
||||
audioContent: Raw audio data
|
||||
fromLanguage: Source language for speech recognition
|
||||
toLanguage: Target language for translation
|
||||
|
||||
Returns:
|
||||
Dict containing original text, translated text, and metadata
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🔄 Speech-to-translation pipeline: {fromLanguage} -> {toLanguage}")
|
||||
|
||||
connector = self._getGoogleSpeechConnector()
|
||||
result = await connector.speechToTranslatedText(
|
||||
audioContent=audioContent,
|
||||
fromLanguage=fromLanguage,
|
||||
toLanguage=toLanguage
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
logger.info(f"✅ Complete pipeline successful:")
|
||||
logger.info(f" Original: '{result['original_text']}'")
|
||||
logger.info(f" Translated: '{result['translated_text']}'")
|
||||
else:
|
||||
logger.warning(f"⚠️ Speech-to-translation pipeline failed: {result.get('error', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Speech-to-translation pipeline error: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"original_text": "",
|
||||
"translated_text": "",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Text-to-Speech Operations
|
||||
|
||||
async def textToSpeech(self, text: str, languageCode: str = "de-DE",
|
||||
voiceName: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert text to speech using Google Cloud Text-to-Speech.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
languageCode: Language code (e.g., 'de-DE', 'en-US')
|
||||
voiceName: Specific voice name (optional)
|
||||
|
||||
Returns:
|
||||
Dict with success status and audio data
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🔊 Text-to-Speech request: '{text[:50]}...' in {languageCode}")
|
||||
|
||||
if not text.strip():
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Empty text provided for text-to-speech"
|
||||
}
|
||||
|
||||
connector = self._getGoogleSpeechConnector()
|
||||
result = await connector.textToSpeech(
|
||||
text=text,
|
||||
languageCode=languageCode,
|
||||
voiceName=voiceName
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
logger.info(f"✅ Text-to-Speech successful: {len(result['audio_content'])} bytes")
|
||||
else:
|
||||
logger.warning(f"⚠️ Text-to-Speech failed: {result.get('error', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Text-to-Speech error: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Voice Settings Management
|
||||
|
||||
def getVoiceSettings(self, userId: str) -> Optional[VoiceSettings]:
|
||||
"""
|
||||
Get voice settings for a user.
|
||||
|
||||
Args:
|
||||
userId: User ID to get settings for
|
||||
|
||||
Returns:
|
||||
VoiceSettings object or None if not found
|
||||
"""
|
||||
try:
|
||||
# This would typically query the database
|
||||
# For now, return None as this is handled by the database interface
|
||||
logger.debug(f"Getting voice settings for user: {userId}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting voice settings: {e}")
|
||||
return None
|
||||
|
||||
def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Optional[VoiceSettings]:
|
||||
"""
|
||||
Create new voice settings.
|
||||
|
||||
Args:
|
||||
settingsData: Dictionary containing voice settings data
|
||||
|
||||
Returns:
|
||||
Created VoiceSettings object or None if failed
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Creating voice settings: {settingsData}")
|
||||
|
||||
# Ensure mandateId is set from user context if not provided
|
||||
if "mandateId" not in settingsData or not settingsData["mandateId"]:
|
||||
if not self.currentUser or not self.currentUser.mandateId:
|
||||
raise ValueError("mandateId is required but not provided and user context has no mandateId")
|
||||
settingsData["mandateId"] = self.currentUser.mandateId
|
||||
|
||||
# Add timestamps
|
||||
currentTime = getUtcTimestamp()
|
||||
settingsData["creationDate"] = currentTime
|
||||
settingsData["lastModified"] = currentTime
|
||||
|
||||
# Create VoiceSettings object
|
||||
voiceSettings = VoiceSettings(**settingsData)
|
||||
|
||||
logger.info(f"✅ Voice settings created: {voiceSettings.id}")
|
||||
return voiceSettings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating voice settings: {e}")
|
||||
return None
|
||||
|
||||
def updateVoiceSettings(self, userId: str, settingsData: Dict[str, Any]) -> Optional[VoiceSettings]:
|
||||
"""
|
||||
Update existing voice settings.
|
||||
|
||||
Args:
|
||||
userId: User ID to update settings for
|
||||
settingsData: Dictionary containing updated voice settings data
|
||||
|
||||
Returns:
|
||||
Updated VoiceSettings object or None if failed
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Updating voice settings for user {userId}: {settingsData}")
|
||||
|
||||
# Add last modified timestamp
|
||||
settingsData["lastModified"] = getUtcTimestamp()
|
||||
|
||||
# Create updated VoiceSettings object
|
||||
voiceSettings = VoiceSettings(**settingsData)
|
||||
|
||||
logger.info(f"✅ Voice settings updated: {voiceSettings.id}")
|
||||
return voiceSettings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error updating voice settings: {e}")
|
||||
return None
|
||||
|
||||
def getOrCreateVoiceSettings(self, userId: str) -> Optional[VoiceSettings]:
|
||||
"""
|
||||
Get existing voice settings or create default ones.
|
||||
|
||||
Args:
|
||||
userId: User ID to get/create settings for
|
||||
|
||||
Returns:
|
||||
VoiceSettings object
|
||||
"""
|
||||
try:
|
||||
# Try to get existing settings
|
||||
existingSettings = self.getVoiceSettings(userId)
|
||||
|
||||
if existingSettings:
|
||||
return existingSettings
|
||||
|
||||
# Create default settings if none exist
|
||||
defaultSettings = {
|
||||
"userId": userId,
|
||||
"mandateId": self.currentUser.mandateId,
|
||||
"sttLanguage": "de-DE",
|
||||
"ttsLanguage": "de-DE",
|
||||
"ttsVoice": "de-DE-Wavenet-A",
|
||||
"translationEnabled": True,
|
||||
"targetLanguage": "en-US"
|
||||
}
|
||||
|
||||
return self.createVoiceSettings(defaultSettings)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting or creating voice settings: {e}")
|
||||
return None
|
||||
|
||||
# Language and Voice Information
|
||||
|
||||
async def getAvailableLanguages(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get available languages from Google Cloud Text-to-Speech.
|
||||
|
||||
Returns:
|
||||
Dict containing success status and list of available languages
|
||||
"""
|
||||
try:
|
||||
logger.info("🌐 Getting available languages from Google Cloud TTS")
|
||||
|
||||
connector = self._getGoogleSpeechConnector()
|
||||
result = await connector.getAvailableLanguages()
|
||||
|
||||
if result["success"]:
|
||||
logger.info(f"✅ Found {len(result['languages'])} available languages")
|
||||
else:
|
||||
logger.warning(f"⚠️ Failed to get languages: {result.get('error', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting available languages: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"languages": []
|
||||
}
|
||||
|
||||
async def getAvailableVoices(self, languageCode: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get available voices from Google Cloud Text-to-Speech.
|
||||
|
||||
Args:
|
||||
languageCode: Optional language code to filter voices
|
||||
|
||||
Returns:
|
||||
Dict containing success status and list of available voices
|
||||
"""
|
||||
try:
|
||||
logger.info(f"🎤 Getting available voices, language filter: {languageCode}")
|
||||
|
||||
connector = self._getGoogleSpeechConnector()
|
||||
result = await connector.getAvailableVoices(languageCode=languageCode)
|
||||
|
||||
if result["success"]:
|
||||
logger.info(f"✅ Found {len(result['voices'])} voices for language filter: {languageCode}")
|
||||
else:
|
||||
logger.warning(f"⚠️ Failed to get voices: {result.get('error', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting available voices: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"voices": []
|
||||
}
|
||||
|
||||
# Audio Validation
|
||||
|
||||
def validateAudioFormat(self, audioContent: bytes) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate audio format for Google Cloud Speech-to-Text.
|
||||
|
||||
Args:
|
||||
audioContent: Raw audio data
|
||||
|
||||
Returns:
|
||||
Dict containing validation results
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Validating audio format: {len(audioContent)} bytes")
|
||||
|
||||
connector = self._getGoogleSpeechConnector()
|
||||
result = connector.validateAudioFormat(audioContent)
|
||||
|
||||
if result["valid"]:
|
||||
logger.debug(f"✅ Audio validation successful: {result['format']}, {result['sample_rate']}Hz, {result['channels']}ch")
|
||||
else:
|
||||
logger.warning(f"⚠️ Audio validation failed: {result.get('error', 'Unknown error')}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Audio validation error: {e}")
|
||||
return {
|
||||
"valid": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Health Check
|
||||
|
||||
async def healthCheck(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform health check for voice services.
|
||||
|
||||
Returns:
|
||||
Dict containing health status and test results
|
||||
"""
|
||||
try:
|
||||
logger.info("🏥 Performing voice services health check")
|
||||
|
||||
connector = self._getGoogleSpeechConnector()
|
||||
|
||||
# Test with a simple translation
|
||||
testResult = await connector.translateText(
|
||||
text="Hello",
|
||||
sourceLanguage="en",
|
||||
targetLanguage="de"
|
||||
)
|
||||
|
||||
if testResult["success"]:
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "Google Cloud Speech-to-Text & Translation",
|
||||
"test_translation": testResult["translated_text"]
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": testResult.get("error", "Unknown error")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Health check failed: {e}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
def getVoiceInterface(currentUser: User = None) -> VoiceObjects:
|
||||
"""
|
||||
Factory function to get or create Voice interface instance.
|
||||
|
||||
Args:
|
||||
currentUser: User object for context (optional)
|
||||
|
||||
Returns:
|
||||
VoiceObjects instance
|
||||
"""
|
||||
# For now, create a new instance each time
|
||||
# In the future, this could be enhanced with singleton pattern per user
|
||||
voiceInterface = VoiceObjects()
|
||||
|
||||
if currentUser:
|
||||
voiceInterface.setUserContext(currentUser)
|
||||
|
||||
return voiceInterface
|
||||
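A hedged end-to-end sketch using the factory above; it assumes an authenticated `currentUser`, raw WAV bytes in `audio`, and an async caller:

voice = getVoiceInterface(currentUser)
check = voice.validateAudioFormat(audio)
if check.get("valid"):
    stt = await voice.speechToText(audio, language="de-DE")
    if stt["success"]:
        tts = await voice.textToSpeech(stt["text"], languageCode="de-DE")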
|
|
@ -1,140 +0,0 @@
|
|||
"""Base class for web classes."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from modules.interfaces.interfaceChatModel import ActionDocument, ActionResult
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
from typing import List
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
|
||||
# Configuration loading functions
|
||||
def get_web_search_max_query_length() -> int:
|
||||
"""Get maximum query length from configuration"""
|
||||
return int(APP_CONFIG.get("Web_Search_MAX_QUERY_LENGTH", "400"))
|
||||
|
||||
|
||||
def get_web_search_max_results() -> int:
|
||||
"""Get maximum search results from configuration"""
|
||||
return int(APP_CONFIG.get("Web_Search_MAX_RESULTS", "20"))
|
||||
|
||||
|
||||
def get_web_search_min_results() -> int:
|
||||
"""Get minimum search results from configuration"""
|
||||
return int(APP_CONFIG.get("Web_Search_MIN_RESULTS", "1"))
|
||||
|
||||
|
||||
# --- Web search ---
|
||||
|
||||
# query -> list of URLs
|
||||
|
||||
|
||||
class WebSearchRequest(BaseModel):
|
||||
query: str = Field(min_length=1, max_length=get_web_search_max_query_length())
|
||||
max_results: int = Field(ge=get_web_search_min_results(), le=get_web_search_max_results())
|
||||
|
||||
|
||||
class WebSearchResultItem(BaseModel):
|
||||
"""Individual search result"""
|
||||
|
||||
title: str
|
||||
url: HttpUrl
|
||||
|
||||
|
||||
class WebSearchDocumentData(BaseModel):
|
||||
"""Complete search results document"""
|
||||
|
||||
query: str = Field(min_length=1, max_length=get_web_search_max_query_length())
|
||||
results: List[WebSearchResultItem]
|
||||
total_count: int
|
||||
|
||||
|
||||
class WebSearchActionDocument(ActionDocument):
|
||||
documentData: WebSearchDocumentData
|
||||
|
||||
|
||||
class WebSearchActionResult(ActionResult):
|
||||
documents: List[WebSearchActionDocument] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WebSearchBase(ABC):
|
||||
@abstractmethod
|
||||
async def search_urls(self, request: WebSearchRequest) -> WebSearchActionResult: ...
|
||||
|
||||
|
||||
# --- Web crawl ---
|
||||
|
||||
# list of URLs -> list of extracted HTML content
|
||||
|
||||
|
||||
class WebCrawlRequest(BaseModel):
|
||||
urls: List[HttpUrl]
|
||||
|
||||
|
||||
class WebCrawlResultItem(BaseModel):
|
||||
"""Individual crawl result"""
|
||||
|
||||
url: HttpUrl
|
||||
content: str
|
||||
|
||||
|
||||
class WebCrawlDocumentData(BaseModel):
|
||||
"""Complete crawl results document"""
|
||||
|
||||
urls: List[HttpUrl]
|
||||
results: List[WebCrawlResultItem]
|
||||
total_count: int
|
||||
|
||||
|
||||
class WebCrawlActionDocument(ActionDocument):
|
||||
documentData: WebCrawlDocumentData = Field(
|
||||
description="The data extracted from crawled URLs"
|
||||
)
|
||||
|
||||
|
||||
class WebCrawlActionResult(ActionResult):
|
||||
documents: List[WebCrawlActionDocument] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WebCrawlBase(ABC):
|
||||
@abstractmethod
|
||||
async def crawl_urls(self, request: WebCrawlRequest) -> WebCrawlActionResult: ...
|
||||
|
||||
|
||||
# --- Web scrape ---
|
||||
|
||||
# scrape -> list of extracted text; combines web search and crawl in one step
|
||||
|
||||
|
||||
class WebScrapeRequest(BaseModel):
|
||||
query: str = Field(min_length=1, max_length=get_web_search_max_query_length())
|
||||
max_results: int = Field(ge=get_web_search_min_results(), le=get_web_search_max_results())
|
||||
|
||||
|
||||
class WebScrapeResultItem(BaseModel):
|
||||
"""Individual scrape result"""
|
||||
|
||||
url: HttpUrl
|
||||
content: str
|
||||
|
||||
|
||||
class WebScrapeDocumentData(BaseModel):
|
||||
"""Complete scrape results document"""
|
||||
|
||||
query: str = Field(min_length=1, max_length=get_web_search_max_query_length())
|
||||
results: List[WebScrapeResultItem]
|
||||
total_count: int
|
||||
|
||||
|
||||
class WebScrapeActionDocument(ActionDocument):
|
||||
documentData: WebScrapeDocumentData = Field(
|
||||
description="The data extracted from scraped URLs"
|
||||
)
|
||||
|
||||
|
||||
class WebScrapeActionResult(ActionResult):
|
||||
documents: List[WebScrapeActionDocument] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WebScrapeBase(ABC):
|
||||
@abstractmethod
|
||||
async def scrape(self, request: WebScrapeRequest) -> WebScrapeActionResult: ...
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
from typing import Optional
|
||||
import json
|
||||
import csv
|
||||
import io
|
||||
from modules.interfaces.interfaceWebModel import (
|
||||
WebCrawlActionResult,
|
||||
WebSearchActionResult,
|
||||
WebSearchRequest,
|
||||
WebCrawlRequest,
|
||||
WebScrapeActionResult,
|
||||
WebScrapeRequest,
|
||||
WebCrawlDocumentData,
|
||||
WebScrapeDocumentData,
|
||||
WebSearchDocumentData,
|
||||
)
|
||||
|
||||
from dataclasses import dataclass
|
||||
from modules.connectors.connectorWebTavily import ConnectorTavily
|
||||
from modules.interfaces.interfaceChatModel import ActionDocument
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class WebInterface:
|
||||
connectorWebTavily: ConnectorTavily
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.connectorWebTavily is None:
|
||||
raise TypeError(
|
||||
"connectorWebTavily must be provided. "
|
||||
"Use `await WebInterface.create()` or pass a ConnectorTavily."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def create(cls) -> "WebInterface":
|
||||
connectorWebTavily = await ConnectorTavily.create()
|
||||
|
||||
return WebInterface(connectorWebTavily=connectorWebTavily)
|
||||
|
||||
async def search(
|
||||
self, web_search_request: WebSearchRequest
|
||||
) -> WebSearchActionResult:
|
||||
# NOTE: Add connectors here
|
||||
return await self.connectorWebTavily.search_urls(web_search_request)
|
||||
|
||||
async def crawl(self, web_crawl_request: WebCrawlRequest) -> WebCrawlActionResult:
|
||||
# NOTE: Add connectors here
|
||||
return await self.connectorWebTavily.crawl_urls(web_crawl_request)
|
||||
|
||||
async def scrape(
|
||||
self, web_scrape_request: WebScrapeRequest
|
||||
) -> WebScrapeActionResult:
|
||||
# NOTE: Add connectors here
|
||||
return await self.connectorWebTavily.scrape(web_scrape_request)
|
||||
|
||||
def convert_web_result_to_json(self, web_result) -> str:
|
||||
"""Convert WebCrawlActionResult or WebScrapeActionResult to proper JSON format"""
|
||||
if not web_result.success or not web_result.documents:
|
||||
return json.dumps({"success": web_result.success, "error": web_result.error})
|
||||
|
||||
# Extract the document data and convert to dict
|
||||
document_data = web_result.documents[0].documentData
|
||||
|
||||
# Convert Pydantic model to dict
|
||||
result_dict = {
|
||||
"success": web_result.success,
|
||||
"results": [
|
||||
{
|
||||
"url": str(result.url),
|
||||
"content": result.content
|
||||
}
|
||||
for result in document_data.results
|
||||
],
|
||||
"total_count": document_data.total_count
|
||||
}
|
||||
|
||||
# Add type-specific fields
|
||||
if hasattr(document_data, 'urls'):
|
||||
# WebCrawlDocumentData has urls field
|
||||
result_dict["urls"] = [str(url) for url in document_data.urls]
|
||||
elif hasattr(document_data, 'query'):
|
||||
# WebScrapeDocumentData has query field
|
||||
result_dict["query"] = document_data.query
|
||||
|
||||
return json.dumps(result_dict, indent=2, ensure_ascii=False)
|
||||
|
||||
def convert_web_search_result_to_csv(self, web_search_result: WebSearchActionResult) -> str:
|
||||
"""Convert WebSearchActionResult to CSV format with url and title columns"""
|
||||
if not web_search_result.success or not web_search_result.documents:
|
||||
return ""
|
||||
|
||||
output = io.StringIO()
|
||||
writer = csv.writer(output, delimiter=';')
|
||||
|
||||
# Write header
|
||||
writer.writerow(['url', 'title'])
|
||||
|
||||
# Write data rows
|
||||
document_data = web_search_result.documents[0].documentData
|
||||
for result in document_data.results:
|
||||
writer.writerow([str(result.url), result.title])
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
def create_json_action_document(self, json_content: str, document_name: str) -> ActionDocument:
|
||||
"""Create an ActionDocument with JSON content"""
|
||||
return ActionDocument(
|
||||
documentName=document_name,
|
||||
documentData=json_content,
|
||||
mimeType="application/json"
|
||||
)
|
||||
|
||||
def create_csv_action_document(self, csv_content: str, document_name: str) -> ActionDocument:
|
||||
"""Create an ActionDocument with CSV content"""
|
||||
return ActionDocument(
|
||||
documentName=document_name,
|
||||
documentData=csv_content,
|
||||
mimeType="text/csv"
|
||||
)
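# Illustrative end-to-end sketch (never called): create the interface, run a
# search and package the hits as a CSV ActionDocument. The query string is an
# invented example, max_results=5 assumes the configured limit allows it, and
# error handling is intentionally omitted.
async def _example_search_to_csv() -> ActionDocument:
    web = await WebInterface.create()
    result = await web.search(WebSearchRequest(query="example query", max_results=5))
    csv_text = web.convert_web_search_result_to_csv(result)
    return web.create_csv_action_document(csv_text, "web_search_results.csv")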
|
||||
|
|
@ -1,201 +0,0 @@
|
|||
"""
|
||||
AI processing method module.
|
||||
Handles direct AI calls for any type of task.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from modules.chat.methodBase import MethodBase, action
|
||||
from modules.interfaces.interfaceChatModel import ActionResult
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MethodAi(MethodBase):
|
||||
"""AI processing methods."""
|
||||
|
||||
def __init__(self, service):
|
||||
super().__init__(service)
|
||||
self.name = "ai"
|
||||
self.description = "AI processing methods"
|
||||
|
||||
def _format_timestamp_for_filename(self) -> str:
|
||||
"""Format current timestamp as YYYYMMDD-hhmmss for filenames."""
|
||||
return datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
|
||||
|
||||
@action
|
||||
async def process(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
"""
|
||||
Perform an AI call for any type of task with optional document references
|
||||
|
||||
Parameters:
|
||||
aiPrompt (str): The AI prompt for processing
|
||||
documentList (list, optional): List of document references to include in context
|
||||
expectedDocumentFormats (list, optional): Expected output formats with extension, mimeType, description
|
||||
processingMode (str, optional): Processing mode ('basic', 'advanced', 'detailed') - defaults to 'basic'
|
||||
includeMetadata (bool, optional): Whether to include metadata (default: True)
|
||||
customInstructions (str, optional): Additional custom instructions for the AI
|
||||
"""
|
||||
try:
|
||||
aiPrompt = parameters.get("aiPrompt")
|
||||
documentList = parameters.get("documentList", [])
|
||||
expectedDocumentFormats = parameters.get("expectedDocumentFormats", [])
|
||||
processingMode = parameters.get("processingMode", "basic")
|
||||
includeMetadata = parameters.get("includeMetadata", True)
|
||||
customInstructions = parameters.get("customInstructions", "")
|
||||
|
||||
if not aiPrompt:
|
||||
return ActionResult.isFailure(
|
||||
error="AI prompt is required"
|
||||
)
|
||||
|
||||
# Determine output format first (needed for context building)
|
||||
output_extension = ".txt" # Default
|
||||
output_mime_type = "text/plain" # Default
|
||||
|
||||
if expectedDocumentFormats and len(expectedDocumentFormats) > 0:
|
||||
expected_format = expectedDocumentFormats[0]
|
||||
output_extension = expected_format.get("extension", ".txt")
|
||||
output_mime_type = expected_format.get("mimeType", "text/plain")
|
||||
logger.info(f"Using expected format: {output_extension} ({output_mime_type})")
|
||||
|
||||
# Build context from documents if provided
|
||||
context = ""
|
||||
if documentList:
|
||||
chatDocuments = self.service.getChatDocumentsFromDocumentList(documentList)
|
||||
if chatDocuments:
|
||||
context_parts = []
|
||||
for doc in chatDocuments:
|
||||
file_info = self.service.getFileInfo(doc.fileId)
|
||||
|
||||
try:
|
||||
# Use the document content extraction service with the specific AI prompt context
|
||||
# This tells the extraction engine exactly what and how to extract
|
||||
extraction_prompt = f"""
|
||||
Extract content from this document for AI processing context.
|
||||
|
||||
AI Task: {aiPrompt}
|
||||
Processing Mode: {processingMode}
|
||||
Expected Output: {output_extension.upper()} format
|
||||
|
||||
Requirements:
|
||||
1. Extract the most relevant text content that would be useful for the AI task
|
||||
2. Focus on content that directly relates to: {aiPrompt}
|
||||
3. Include key information, data, and insights that the AI needs
|
||||
4. Provide clean, readable text without formatting artifacts
|
||||
|
||||
Document: {doc.fileName}
|
||||
"""
|
||||
|
||||
logger.debug(f"Extracting content from {doc.fileName} with task-specific prompt: {extraction_prompt[:100]}...")
|
||||
|
||||
extracted_content = await self.service.extractContentFromDocument(
|
||||
prompt=extraction_prompt.strip(),
|
||||
document=doc
|
||||
)
|
||||
|
||||
if extracted_content and extracted_content.contents:
|
||||
# Get the first content item's data
|
||||
content = ""
|
||||
for content_item in extracted_content.contents:
|
||||
if hasattr(content_item, 'data') and content_item.data:
|
||||
content += content_item.data + " "
|
||||
|
||||
|
||||
if content.strip():
|
||||
metadata_info = ""
|
||||
if file_info and includeMetadata:
|
||||
metadata_info = f" (Size: {file_info.get('fileSize', 'unknown')}, Type: {file_info.get('mimeType', 'unknown')})"
|
||||
|
||||
# Adjust context length based on processing mode and AI task relevance
|
||||
base_length = 5000 if processingMode == "detailed" else 3000 if processingMode == "advanced" else 2000
|
||||
|
||||
# For detailed mode, include more context
|
||||
if processingMode == "detailed":
|
||||
context_parts.append(f"Document: {doc.fileName}{metadata_info}\nRelevance to AI Task: This document contains content directly related to '{aiPrompt[:100]}...'\nContent:\n{content[:base_length]}...")
|
||||
else:
|
||||
context_parts.append(f"Document: {doc.fileName}{metadata_info}\nContent:\n{content[:base_length]}...")
|
||||
else:
|
||||
context_parts.append(f"Document: {doc.fileName} [No readable text content - binary file]")
|
||||
else:
|
||||
context_parts.append(f"Document: {doc.fileName} [No readable text content - binary file]")
|
||||
|
||||
except Exception as extract_error:
|
||||
context_parts.append(f"Document: {doc.fileName} [Could not extract content - binary file]")
|
||||
|
||||
if context_parts:
|
||||
# Add a summary header to help the AI understand the context
|
||||
context_header = f"""
|
||||
=== DOCUMENT CONTEXT FOR AI PROCESSING ===
|
||||
AI Task: {aiPrompt[:100]}...
|
||||
Processing Mode: {processingMode}
|
||||
Expected Output Format: {output_extension.upper()}
|
||||
Total Documents: {len(chatDocuments)}
|
||||
|
||||
The following documents contain content relevant to your task.
|
||||
Use this information to provide the most accurate and helpful response.
|
||||
================================================
|
||||
"""
|
||||
|
||||
context = context_header + "\n\n" + "\n\n".join(context_parts)
|
||||
logger.info(f"Included {len(chatDocuments)} documents in AI context with task-specific extraction")
|
||||
|
||||
# Build enhanced prompt
|
||||
enhanced_prompt = aiPrompt
|
||||
|
||||
# Add processing mode instructions if specified (generic, not analysis-specific)
|
||||
if processingMode == "detailed":
|
||||
enhanced_prompt += "\n\nPlease provide a detailed response with comprehensive information."
|
||||
elif processingMode == "advanced":
|
||||
enhanced_prompt += "\n\nPlease provide an advanced response with deep insights."
|
||||
|
||||
# Add custom instructions if provided
|
||||
if customInstructions:
|
||||
enhanced_prompt += f"\n\nAdditional Instructions: {customInstructions}"
|
||||
|
||||
# Add format-specific instructions only if non-text format is requested
|
||||
if output_extension != ".txt":
|
||||
if output_extension == ".csv":
|
||||
enhanced_prompt += f"\n\nCRITICAL: Deliver the result as pure CSV data without any markdown formatting, code blocks, or additional text. Output only the CSV content with proper headers and data rows."
|
||||
elif output_extension == ".json":
|
||||
enhanced_prompt += f"\n\nCRITICAL: Deliver the result as pure JSON data without any markdown formatting, code blocks, or additional text. Output only the JSON content."
|
||||
elif output_extension == ".xml":
|
||||
enhanced_prompt += f"\n\nCRITICAL: Deliver the result as pure XML data without any markdown formatting, code blocks, or additional text. Output only the XML content."
|
||||
else:
|
||||
enhanced_prompt += f"\n\nCRITICAL: Deliver the result as pure {output_extension.upper()} data without any markdown formatting, code blocks, or additional text."
|
||||
|
||||
# Call appropriate AI service based on processing mode
|
||||
logger.info(f"Executing AI call with mode: {processingMode}, prompt length: {len(enhanced_prompt)}")
|
||||
if context:
|
||||
logger.info(f"Including context from {len(documentList)} documents")
|
||||
|
||||
if processingMode in ["advanced", "detailed"]:
|
||||
result = await self.service.callAiTextAdvanced(enhanced_prompt, context)
|
||||
else:
|
||||
result = await self.service.callAiTextBasic(enhanced_prompt, context)
|
||||
|
||||
# Create result document
|
||||
fileName = f"ai_{processingMode}_{self._format_timestamp_for_filename()}{output_extension}"
|
||||
|
||||
|
||||
|
||||
# Return result in the standard ActionResult format
|
||||
return ActionResult.isSuccess(
|
||||
documents=[{
|
||||
"documentName": fileName,
|
||||
"documentData": {
|
||||
"result": result,
|
||||
"fileName": fileName,
|
||||
"processedDocuments": len(documentList) if documentList else 0
|
||||
},
|
||||
"mimeType": output_mime_type
|
||||
}]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in AI processing: {str(e)}")
|
||||
return ActionResult.isFailure(
|
||||
error=str(e)
|
||||
)
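# Example parameter payload for MethodAi.process, for illustration only.
# The prompt and format entries are invented; "documentList" would normally be
# a reference produced earlier in the chat pipeline, so it is left empty here.
_EXAMPLE_PROCESS_PARAMETERS = {
    "aiPrompt": "Summarise the attached quarterly notes",
    "documentList": [],
    "expectedDocumentFormats": [{"extension": ".md", "mimeType": "text/markdown"}],
    "processingMode": "basic",  # other accepted values: "advanced", "detailed"
    "includeMetadata": True,
}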
|
||||
|
|
@ -1,789 +0,0 @@
|
|||
"""
|
||||
Document processing method module.
|
||||
Handles document operations using the document service.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from modules.chat.methodBase import MethodBase, action
|
||||
from modules.interfaces.interfaceChatModel import ActionResult
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MethodDocument(MethodBase):
|
||||
"""Document method implementation for document operations"""
|
||||
|
||||
def __init__(self, serviceCenter: Any):
|
||||
"""Initialize the document method"""
|
||||
super().__init__(serviceCenter)
|
||||
self.name = "document"
|
||||
self.description = "Handle document operations like extraction and analysis"
|
||||
|
||||
def _format_timestamp_for_filename(self) -> str:
|
||||
"""Format current timestamp as YYYYMMDD-hhmmss for filenames."""
|
||||
return datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
|
||||
|
||||
@action
|
||||
async def extract(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
"""
|
||||
Extract content from any document using AI prompt.
|
||||
|
||||
Parameters:
|
||||
documentList (str): Document list reference
|
||||
aiPrompt (str): AI prompt for extraction
|
||||
expectedDocumentFormats (list, optional): Output formats
|
||||
includeMetadata (bool, optional): Include metadata (default: True)
|
||||
"""
|
||||
try:
|
||||
documentList = parameters.get("documentList")
|
||||
aiPrompt = parameters.get("aiPrompt")
|
||||
expectedDocumentFormats = parameters.get("expectedDocumentFormats", [])
|
||||
includeMetadata = parameters.get("includeMetadata", True)
|
||||
|
||||
if not documentList:
|
||||
return ActionResult.isFailure(
|
||||
error="Document list reference is required"
|
||||
)
|
||||
|
||||
if not aiPrompt:
|
||||
return ActionResult.isFailure(
|
||||
error="AI prompt is required"
|
||||
)
|
||||
|
||||
chatDocuments = self.service.getChatDocumentsFromDocumentList(documentList)
|
||||
if not chatDocuments:
|
||||
return ActionResult.isFailure(
|
||||
error="No documents found for the provided reference"
|
||||
)
|
||||
|
||||
# Extract content from all documents using AI
|
||||
all_extracted_content = []
|
||||
file_infos = []
|
||||
|
||||
for chatDocument in chatDocuments:
|
||||
file_info = self.service.getFileInfo(chatDocument.fileId)
|
||||
|
||||
try:
|
||||
# Use the document content extraction service with the specific AI prompt
|
||||
# This handles all document types (text, binary, image, etc.) intelligently
|
||||
extracted_content = await self.service.extractContentFromDocument(
|
||||
prompt=aiPrompt,
|
||||
document=chatDocument
|
||||
)
|
||||
|
||||
if extracted_content and extracted_content.contents:
|
||||
all_extracted_content.append(extracted_content)
|
||||
if includeMetadata:
|
||||
file_infos.append(file_info)
|
||||
logger.info(f"Successfully extracted content from {chatDocument.fileName}")
|
||||
else:
|
||||
logger.warning(f"No content extracted from {chatDocument.fileName}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content from {chatDocument.fileName}: {str(e)}")
|
||||
continue
|
||||
|
||||
if not all_extracted_content:
|
||||
return ActionResult.isFailure(
|
||||
error="No content could be extracted from any documents"
|
||||
)
|
||||
|
||||
# Process each document individually with its own format conversion
|
||||
output_documents = []
|
||||
|
||||
for i, (chatDocument, extracted_content) in enumerate(zip(chatDocuments, all_extracted_content)):
|
||||
# Extract text content from this document
|
||||
text_content = ""
|
||||
if hasattr(extracted_content, 'contents') and extracted_content.contents:
|
||||
# Extract text from ContentItem objects
|
||||
text_parts = []
|
||||
for content_item in extracted_content.contents:
|
||||
if hasattr(content_item, 'data') and content_item.data:
|
||||
text_parts.append(content_item.data)
|
||||
text_content = "\n".join(text_parts)
|
||||
elif isinstance(extracted_content, str):
|
||||
text_content = extracted_content
|
||||
else:
|
||||
text_content = str(extracted_content)
|
||||
|
||||
# Get the expected format for this document (or use default)
|
||||
target_format = None
|
||||
if expectedDocumentFormats and i < len(expectedDocumentFormats):
|
||||
target_format = expectedDocumentFormats[i]
|
||||
elif expectedDocumentFormats and len(expectedDocumentFormats) > 0:
|
||||
# If fewer formats than documents, use the last format for remaining documents
|
||||
target_format = expectedDocumentFormats[-1]
|
||||
|
||||
# Determine output format and fileName
|
||||
if target_format:
|
||||
target_extension = target_format.get("extension", ".txt")
|
||||
target_mime_type = target_format.get("mimeType", "text/plain")
|
||||
|
||||
# Check if format conversion is needed
|
||||
if target_extension not in [".txt", ".text"] or target_mime_type != "text/plain":
|
||||
logger.info(f"Converting document {i+1} to format: {target_extension} ({target_mime_type})")
|
||||
# Use AI to convert format
|
||||
formatted_content = await self._convertContentToFormat(text_content, target_format)
|
||||
final_content = formatted_content
|
||||
final_mime_type = target_mime_type
|
||||
final_extension = target_extension
|
||||
else:
|
||||
logger.info(f"Document {i+1}: No format conversion needed, using plain text")
|
||||
final_content = text_content
|
||||
final_mime_type = "text/plain"
|
||||
final_extension = ".txt"
|
||||
else:
|
||||
logger.info(f"Document {i+1}: No expected format specified, using plain text")
|
||||
final_content = text_content
|
||||
final_mime_type = "text/plain"
|
||||
final_extension = ".txt"
|
||||
|
||||
# Create output fileName based on original fileName and target format
|
||||
original_fileName = chatDocument.fileName
|
||||
base_name = original_fileName.rsplit('.', 1)[0] if '.' in original_fileName else original_fileName
|
||||
output_fileName = f"{base_name}_extracted_{self._format_timestamp_for_filename()}{final_extension}"
|
||||
|
||||
# Create result data for this document
|
||||
result_data = {
|
||||
"documentCount": 1,
|
||||
"content": final_content,
|
||||
"originalfileName": original_fileName,
|
||||
"fileInfos": [file_infos[i]] if includeMetadata and i < len(file_infos) else None,
|
||||
"timestamp": get_utc_timestamp()
|
||||
}
|
||||
|
||||
logger.info(f"Created output document: {output_fileName} with {len(final_content)} characters")
|
||||
|
||||
output_documents.append({
|
||||
"documentName": output_fileName,
|
||||
"documentData": result_data,
|
||||
"mimeType": final_mime_type
|
||||
})
|
||||
|
||||
return ActionResult.isSuccess(
|
||||
documents=output_documents
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content: {str(e)}")
|
||||
return ActionResult.isFailure(
|
||||
error=str(e)
|
||||
)
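# Sketch of a direct call (shown as a comment because `extract` is normally
# reached through the action dispatcher; all values below are invented):
#
#     result = await method_document.extract({
#         "documentList": document_list_reference,
#         "aiPrompt": "List every invoice number together with its total",
#         "expectedDocumentFormats": [{"extension": ".csv", "mimeType": "text/csv"}],
#     })
#     # on success, result carries one converted output document per input document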
|
||||
|
||||
@action
|
||||
async def generate(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
"""
|
||||
Convert TEXT-ONLY documents to target formats (NO AI usage).
|
||||
|
||||
Parameters:
|
||||
documentList (list): TEXT-ONLY documents only
|
||||
expectedDocumentFormats (list): Target formats
|
||||
originalDocuments (list, optional): Original names
|
||||
includeMetadata (bool, optional): Include metadata (default: True)
|
||||
mergeDocuments (bool, optional): Merge all documents into single output (default: False)
|
||||
"""
|
||||
try:
|
||||
document_list = parameters.get("documentList", [])
|
||||
expected_document_formats = parameters.get("expectedDocumentFormats", [])
|
||||
original_documents = parameters.get("originalDocuments", [])
|
||||
include_metadata = parameters.get("includeMetadata", True)
|
||||
merge_documents = parameters.get("mergeDocuments", False)
|
||||
|
||||
if not document_list:
|
||||
return ActionResult.isFailure(
|
||||
error="Document list is required for generation"
|
||||
)
|
||||
|
||||
if not expected_document_formats or len(expected_document_formats) == 0:
|
||||
return ActionResult.isFailure(
|
||||
error="Expected document formats specification is required"
|
||||
)
|
||||
|
||||
# Get chat documents for original documents list
|
||||
chat_documents = self.service.getChatDocumentsFromDocumentList(document_list)
|
||||
logger.info(f"Found {len(chat_documents)} chat documents")
|
||||
|
||||
if not chat_documents:
|
||||
return ActionResult.isFailure(
|
||||
error="No documents found for the provided documentList reference"
|
||||
)
|
||||
|
||||
# Update original documents list if not provided
|
||||
if not original_documents:
|
||||
original_documents = [doc.fileName if hasattr(doc, 'fileName') else str(doc.id) for doc in chat_documents]
|
||||
|
||||
# Extract content from all documents first
|
||||
document_contents = []
|
||||
for i, chat_document in enumerate(chat_documents):
|
||||
# Extract content from this document directly - NO AI, just read the data as-is
|
||||
# This ensures we get the original text content for format conversion
|
||||
content = ""
|
||||
if hasattr(chat_document, 'fileId') and chat_document.fileId:
|
||||
try:
|
||||
# Get file data directly without AI processing
|
||||
file_data = self.service.getFileData(chat_document.fileId)
|
||||
if file_data:
|
||||
# Check if it's text data and convert to string
|
||||
if isinstance(file_data, bytes):
|
||||
try:
|
||||
# Try to decode as UTF-8 to check if it's text
|
||||
content = file_data.decode('utf-8')
|
||||
logger.info(f"Document {i+1} ({chat_document.fileName}): Successfully decoded as UTF-8 text")
|
||||
except UnicodeDecodeError:
|
||||
logger.info(f"Document {i+1} ({chat_document.fileName}): Binary data, not text - skipping")
|
||||
continue
|
||||
else:
|
||||
# Already a string
|
||||
content = str(file_data)
|
||||
logger.info(f"Document {i+1} ({chat_document.fileName}): Already text data")
|
||||
else:
|
||||
logger.warning(f"Document {i+1} ({chat_document.fileName}): No file data found")
|
||||
continue
|
||||
|
||||
if not content.strip():
|
||||
logger.info(f"Document {i+1} ({chat_document.fileName}): Empty text content, skipping")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading document {i+1} ({chat_document.fileName}): {str(e)}")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"Document {i+1} has no fileId, skipping")
|
||||
continue
|
||||
|
||||
logger.info(f"Extracted content from document {i+1}: {len(content)} characters")
|
||||
|
||||
document_contents.append({
|
||||
"document": chat_document,
|
||||
"content": content,
|
||||
"index": i,
|
||||
"original_name": original_documents[i] if i < len(original_documents) else f"document_{i+1}"
|
||||
})
|
||||
|
||||
if not document_contents:
|
||||
return ActionResult.isFailure(
|
||||
error="No valid text content could be extracted from any documents"
|
||||
)
|
||||
|
||||
if merge_documents and len(document_contents) > 1:
|
||||
# Merge all documents into single output
|
||||
logger.info("Merging all documents into single output")
|
||||
return await self._mergeDocuments(document_contents, expected_document_formats, include_metadata)
|
||||
else:
|
||||
# Process each document individually with its own format conversion
|
||||
logger.info("Processing documents individually")
|
||||
output_documents = []
|
||||
|
||||
for item in document_contents:
|
||||
chat_document = item["document"]
|
||||
content = item["content"]
|
||||
i = item["index"]
|
||||
original_name = item["original_name"]
|
||||
|
||||
# Get the expected format for this document (or use default)
|
||||
target_format = None
|
||||
if i < len(expected_document_formats):
|
||||
target_format = expected_document_formats[i]
|
||||
elif len(expected_document_formats) > 0:
|
||||
# If fewer formats than documents, use the last format for remaining documents
|
||||
target_format = expected_document_formats[-1]
|
||||
|
||||
if not target_format:
|
||||
logger.warning(f"No expected format for document {i+1}, skipping")
|
||||
continue
|
||||
|
||||
# Use AI to convert format
|
||||
formatted_content = await self._convertContentToFormat(content, target_format)
|
||||
if not formatted_content:
|
||||
logger.warning(f"Failed to format document {i+1}, skipping")
|
||||
continue
|
||||
|
||||
target_extension = target_format.get("extension", ".txt")
|
||||
target_mime_type = target_format.get("mimeType", "text/plain")
|
||||
|
||||
# Create output fileName
|
||||
base_name = original_name.rsplit('.', 1)[0] if '.' in original_name else original_name
|
||||
output_fileName = f"{base_name}_generated_{self._format_timestamp_for_filename()}{target_extension}"
|
||||
|
||||
# Create result data
|
||||
result_data = {
|
||||
"documentCount": 1,
|
||||
"content": formatted_content,
|
||||
"outputFormat": target_format,
|
||||
"originalDocument": original_name,
|
||||
"timestamp": get_utc_timestamp()
|
||||
}
|
||||
|
||||
logger.info(f"Generated document: {output_fileName} with {len(formatted_content)} characters")
|
||||
|
||||
output_documents.append({
|
||||
"documentName": output_fileName,
|
||||
"documentData": result_data,
|
||||
"mimeType": target_mime_type
|
||||
})
|
||||
|
||||
if not output_documents:
|
||||
return ActionResult.isFailure(
|
||||
error="No documents could be generated"
|
||||
)
|
||||
|
||||
return ActionResult.isSuccess(
|
||||
documents=output_documents
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating document: {str(e)}")
|
||||
return ActionResult.isFailure(
|
||||
error=str(e)
|
||||
)
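# Sketch of a `generate` call that merges several text documents into a single
# Markdown file (values invented; documentList comes from the chat layer):
#
#     await method_document.generate({
#         "documentList": document_list_reference,
#         "expectedDocumentFormats": [{"extension": ".md", "mimeType": "text/markdown"}],
#         "mergeDocuments": True,
#     })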
|
||||
|
||||
async def _mergeDocuments(self, document_contents: List[Dict[str, Any]],
|
||||
expected_document_formats: List[Dict[str, Any]],
|
||||
include_metadata: bool) -> ActionResult:
|
||||
"""
|
||||
Merge all documents into a single output document.
|
||||
"""
|
||||
try:
|
||||
# Combine all document content
|
||||
combined_content_parts = []
|
||||
original_file_names = []
|
||||
|
||||
for item in document_contents:
|
||||
chat_document = item["document"]
|
||||
content = item["content"]
|
||||
original_name = item["original_name"]
|
||||
|
||||
if content.strip():
|
||||
combined_content_parts.append(f"=== Document: {original_name} ===\n{content}\n")
|
||||
original_file_names.append(original_name)
|
||||
|
||||
if not combined_content_parts:
|
||||
return ActionResult.isFailure(
|
||||
error="No content could be extracted from any documents for merging"
|
||||
)
|
||||
|
||||
# Combine all content
|
||||
combined_content = "\n".join(combined_content_parts)
|
||||
logger.info(f"Combined content from {len(original_file_names)} documents: {len(combined_content)} characters")
|
||||
|
||||
# Get the expected format for the merged output
|
||||
target_format = None
|
||||
if expected_document_formats and len(expected_document_formats) > 0:
|
||||
target_format = expected_document_formats[0] # Use first format for merged output
|
||||
|
||||
if not target_format:
|
||||
logger.warning("No expected format specified for merged output, using plain text")
|
||||
target_format = {"extension": ".txt", "mimeType": "text/plain"}
|
||||
|
||||
# Use AI to convert format
|
||||
formatted_content = await self._convertContentToFormat(combined_content, target_format)
|
||||
if not formatted_content:
|
||||
logger.warning("Failed to format merged content, using raw content")
|
||||
formatted_content = combined_content
|
||||
|
||||
target_extension = target_format.get("extension", ".txt")
|
||||
target_mime_type = target_format.get("mimeType", "text/plain")
|
||||
|
||||
# Create output fileName for merged document
|
||||
timestamp = self._format_timestamp_for_filename()
|
||||
output_fileName = f"merged_documents_{timestamp}{target_extension}"
|
||||
|
||||
# Create result data for merged document
|
||||
result_data = {
|
||||
"documentCount": len(document_contents),
|
||||
"content": formatted_content,
|
||||
"outputFormat": target_format,
|
||||
"originalDocuments": original_file_names,
|
||||
"timestamp": get_utc_timestamp(),
|
||||
"merged": True
|
||||
}
|
||||
|
||||
logger.info(f"Created merged document: {output_fileName} with {len(formatted_content)} characters")
|
||||
|
||||
return ActionResult.isSuccess(
|
||||
documents=[{
|
||||
"documentName": output_fileName,
|
||||
"documentData": result_data,
|
||||
"mimeType": target_mime_type
|
||||
}]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error merging documents: {str(e)}")
|
||||
return ActionResult.isFailure(
|
||||
error=f"Failed to merge documents: {str(e)}"
|
||||
)
|
||||
|
||||
async def _convertContentToFormat(self, content: str, target_format: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Helper function to convert content to the specified format using AI.
|
||||
"""
|
||||
try:
|
||||
extension = target_format.get("extension", ".txt")
|
||||
mime_type = target_format.get("mimeType", "text/plain")
|
||||
|
||||
logger.info(f"Converting content to format: {extension} ({mime_type})")
|
||||
|
||||
# Create AI prompt for format conversion
|
||||
format_prompts = {
|
||||
".csv": f"""
|
||||
Convert the following content into a proper CSV format.
|
||||
|
||||
Requirements:
|
||||
1. Output ONLY the CSV data without any markdown, code blocks, or additional text
|
||||
2. Use appropriate headers based on the content
|
||||
3. Ensure proper CSV formatting with commas and quotes where needed
|
||||
4. Make the data easily readable and importable into spreadsheet applications
|
||||
|
||||
Content to convert:
|
||||
{content}
|
||||
|
||||
Generate ONLY the CSV data:
|
||||
""",
|
||||
|
||||
".json": f"""
|
||||
Convert the following content into a proper JSON format.
|
||||
|
||||
Requirements:
|
||||
1. Output ONLY the JSON data without any markdown, code blocks, or additional text
|
||||
2. Structure the data logically with appropriate keys and values
|
||||
3. Ensure valid JSON syntax
|
||||
4. Make the data easily parseable and readable
|
||||
|
||||
Content to convert:
|
||||
{content}
|
||||
|
||||
Generate ONLY the JSON data:
|
||||
""",
|
||||
|
||||
".xml": f"""
|
||||
Convert the following content into a proper XML format.
|
||||
|
||||
Requirements:
|
||||
1. Output ONLY the XML data without any markdown, code blocks, or additional text
|
||||
2. Use appropriate XML tags and structure
|
||||
3. Ensure valid XML syntax
|
||||
4. Make the data easily parseable and readable
|
||||
|
||||
Content to convert:
|
||||
{content}
|
||||
|
||||
Generate ONLY the XML data:
|
||||
""",
|
||||
|
||||
".html": f"""
|
||||
Convert the following content into a proper HTML format.
|
||||
|
||||
Requirements:
|
||||
1. Output ONLY the HTML data without any markdown, code blocks, or additional text
|
||||
2. Use appropriate HTML tags and structure
|
||||
3. Ensure valid HTML syntax
|
||||
4. Make the data easily readable in web browsers
|
||||
|
||||
Content to convert:
|
||||
{content}
|
||||
|
||||
Generate ONLY the HTML data:
|
||||
""",
|
||||
|
||||
".md": f"""
|
||||
Convert the following content into a proper Markdown format.
|
||||
|
||||
Requirements:
|
||||
1. Output ONLY the Markdown data without any code blocks or additional text
|
||||
2. Use appropriate Markdown syntax for headers, lists, emphasis, etc.
|
||||
3. Structure the content logically
|
||||
4. Make the data easily readable and convertible to other formats
|
||||
|
||||
Content to convert:
|
||||
{content}
|
||||
|
||||
Generate ONLY the Markdown data:
|
||||
"""
|
||||
}
|
||||
|
||||
# Get the appropriate prompt for the target format
|
||||
if extension in format_prompts:
|
||||
ai_prompt = format_prompts[extension]
|
||||
else:
|
||||
# Generic format conversion
|
||||
ai_prompt = f"""
|
||||
Convert the following content into {extension.upper()} format.
|
||||
|
||||
Requirements:
|
||||
1. Output ONLY the {extension.upper()} data without any markdown, code blocks, or additional text
|
||||
2. Use appropriate formatting for {extension.upper()} files
|
||||
3. Ensure the output is valid and usable
|
||||
4. Make the data easily readable and importable
|
||||
|
||||
Content to convert:
|
||||
{content}
|
||||
|
||||
Generate ONLY the {extension.upper()} data:
|
||||
"""
|
||||
|
||||
# Call AI to generate the formatted content
|
||||
logger.info(f"Calling AI for {extension} format conversion")
|
||||
formatted_content = await self.service.callAiTextBasic(ai_prompt, content)
|
||||
|
||||
if not formatted_content or formatted_content.strip() == "":
|
||||
logger.warning("AI format conversion failed, using fallback")
|
||||
return self._generateFallbackFormattedContent(content, extension, mime_type)
|
||||
|
||||
# Clean up the AI response
|
||||
formatted_content = formatted_content.strip()
|
||||
|
||||
# Remove markdown code blocks if present
|
||||
if formatted_content.startswith("```") and formatted_content.endswith("```"):
|
||||
lines = formatted_content.split('\n')
|
||||
if len(lines) > 2:
|
||||
formatted_content = '\n'.join(lines[1:-1])
|
||||
|
||||
# For HTML format, check if AI returned complete HTML document
|
||||
if extension == ".html" and (formatted_content.startswith('<!DOCTYPE') or formatted_content.startswith('<html')):
|
||||
return formatted_content
|
||||
|
||||
return formatted_content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in AI format conversion: {str(e)}")
|
||||
return self._generateFallbackFormattedContent(content, extension, mime_type)
|
||||
|
||||
def _generateFallbackFormattedContent(self, content: str, extension: str, mime_type: str) -> str:
|
||||
"""
|
||||
Generate fallback formatted content when AI conversion fails.
|
||||
"""
|
||||
try:
|
||||
if extension == ".csv":
|
||||
# Simple CSV fallback - split by lines and create basic CSV
|
||||
lines = content.strip().split('\n')
|
||||
if lines:
|
||||
# Create a simple CSV with line numbers and content
|
||||
csv_lines = ["Line,Content"]
|
||||
for i, line in enumerate(lines, 1):
|
||||
# Escape quotes and wrap in quotes if comma present
|
||||
if ',' in line:
|
||||
line = f'"{line.replace(chr(34), chr(34) + chr(34))}"'
|
||||
csv_lines.append(f"{i},{line}")
|
||||
return '\n'.join(csv_lines)
|
||||
return "Line,Content\n1,No content available"
|
||||
|
||||
elif extension == ".json":
|
||||
# Simple JSON fallback
|
||||
# Build the JSON with json.dumps so quotes, backslashes and newlines in the content are escaped correctly
import json
return json.dumps({"content": content, "format": "json", "timestamp": get_utc_timestamp()})
|
||||
|
||||
elif extension == ".xml":
|
||||
# Simple XML fallback
|
||||
timestamp = get_utc_timestamp()
|
||||
return f'<?xml version="1.0" encoding="UTF-8"?>\n<document>\n<content>{content}</content>\n<format>xml</format>\n<timestamp>{timestamp}</timestamp>\n</document>'
|
||||
|
||||
elif extension == ".html":
|
||||
# Simple HTML fallback
|
||||
timestamp = int(get_utc_timestamp())
|
||||
return f'<!DOCTYPE html>\n<html>\n<head><meta charset="UTF-8"><title>Generated Document</title></head>\n<body>\n<pre>{content}</pre>\n<p><em>Generated on {timestamp}</em></p>\n</body>\n</html>'
|
||||
|
||||
elif extension == ".md":
|
||||
# Simple Markdown fallback
|
||||
timestamp = int(get_utc_timestamp())
|
||||
return f"# Generated Document\n\n{content}\n\n---\n*Generated on {timestamp}*"
|
||||
|
||||
else:
|
||||
# Generic fallback - return content as-is
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in fallback format conversion: {str(e)}")
|
||||
return content
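# Rough sketch of the CSV fallback above for a two-line input (illustrative
# values; the quoting follows the escaping logic in the ".csv" branch):
#
#     self._generateFallbackFormattedContent("alpha\nbeta, gamma", ".csv", "text/csv")
#     # -> 'Line,Content\n1,alpha\n2,"beta, gamma"'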
|
||||
|
||||
@action
|
||||
async def generateReport(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
"""
|
||||
Generate HTML report from multiple documents using AI.
|
||||
|
||||
Parameters:
|
||||
documentList (str): Document list reference
|
||||
prompt (str): AI prompt for report generation
|
||||
title (str, optional): Report title (default: "Summary Report")
|
||||
includeMetadata (bool, optional): Include metadata (default: True)
|
||||
"""
|
||||
try:
|
||||
documentList = parameters.get("documentList")
|
||||
prompt = parameters.get("prompt")
|
||||
title = parameters.get("title", "Summary Report")
|
||||
includeMetadata = parameters.get("includeMetadata", True)
|
||||
|
||||
if not documentList:
|
||||
return ActionResult.isFailure(
|
||||
error="Document list reference is required"
|
||||
)
|
||||
|
||||
if not prompt:
|
||||
return ActionResult.isFailure(
|
||||
error="Prompt is required to specify what kind of report to generate"
|
||||
)
|
||||
|
||||
chatDocuments = self.service.getChatDocumentsFromDocumentList(documentList)
|
||||
logger.info(f"Retrieved {len(chatDocuments)} chat documents for report generation")
|
||||
|
||||
if not chatDocuments:
|
||||
return ActionResult.isFailure(
|
||||
error="No documents found for the provided reference"
|
||||
)
|
||||
|
||||
# Generate HTML report
|
||||
html_content = await self._generateHtmlReport(chatDocuments, title, includeMetadata, prompt)
|
||||
|
||||
# Create output fileName
|
||||
timestamp = int(get_utc_timestamp())
|
||||
output_fileName = f"report_{self._format_timestamp_for_filename()}.html"
|
||||
|
||||
result_data = {
|
||||
"documentCount": len(chatDocuments),
|
||||
"content": html_content,
|
||||
"title": title,
|
||||
"timestamp": get_utc_timestamp()
|
||||
}
|
||||
|
||||
logger.info(f"Generated HTML report: {output_fileName} with {len(html_content)} characters")
|
||||
|
||||
return ActionResult.isSuccess(
|
||||
documents=[{
|
||||
"documentName": output_fileName,
|
||||
"documentData": result_data,
|
||||
"mimeType": "text/html"
|
||||
}]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating report: {str(e)}")
|
||||
return ActionResult.isFailure(
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def _generateHtmlReport(self, chatDocuments: List[Any], title: str, includeMetadata: bool, prompt: str) -> str:
|
||||
"""
|
||||
Generate a comprehensive HTML report using AI from all input documents.
|
||||
"""
|
||||
try:
|
||||
# Filter out empty documents and collect content
|
||||
validDocuments = []
|
||||
allContent = []
|
||||
|
||||
for doc in chatDocuments:
|
||||
content = ""
|
||||
logger.info(f"Processing document: type={type(doc)}")
|
||||
|
||||
# Get actual file content using the document content extraction service
|
||||
try:
|
||||
extracted_content = await self.service.extractContentFromDocument(
|
||||
prompt="Extract readable text content for HTML report generation",
|
||||
document=doc
|
||||
)
|
||||
|
||||
if extracted_content and extracted_content.contents:
|
||||
# Get the first content item's data
|
||||
for content_item in extracted_content.contents:
|
||||
if hasattr(content_item, 'data') and content_item.data:
|
||||
content += content_item.data + " "
|
||||
|
||||
if content.strip():
|
||||
logger.info(f" Retrieved content from file: {len(content)} characters")
|
||||
else:
|
||||
logger.info(f" No readable text content found (binary file)")
|
||||
else:
|
||||
logger.info(f" No content extracted (binary file)")
|
||||
except Exception as e:
|
||||
logger.info(f" Could not extract content (binary file): {str(e)}")
|
||||
|
||||
# Skip empty documents
|
||||
if content and content.strip():
|
||||
validDocuments.append(doc)
|
||||
allContent.append(f"Document: {doc.fileName}\n{content}\n")
|
||||
logger.info(f" Added document to valid documents list")
|
||||
else:
|
||||
logger.info(f" Skipping document with no readable text content")
|
||||
|
||||
if not validDocuments:
|
||||
# If no valid documents, create a simple report
|
||||
html = ["<html><head><meta charset='utf-8'><title>" + title + "</title></head><body>"]
|
||||
html.append(f"<h1>{title}</h1>")
|
||||
html.append(f"<p><b>Generated:</b> {int(get_utc_timestamp())}</p>")
|
||||
html.append("<p><em>No content available in the provided documents.</em></p>")
|
||||
html.append("</body></html>")
|
||||
return '\n'.join(html)
|
||||
|
||||
# Create AI prompt for comprehensive report generation using user's prompt
|
||||
combinedContent = "\n\n".join(allContent)
|
||||
aiPrompt = f"""
|
||||
{prompt}
|
||||
|
||||
Report Title: {title}
|
||||
|
||||
Additional Requirements:
|
||||
1. Create a professional, well-formatted HTML report
|
||||
2. Include an executive summary at the beginning
|
||||
3. Organize information logically with clear sections
|
||||
4. Highlight key findings and insights
|
||||
5. Include relevant data, statistics, and conclusions
|
||||
6. Use proper HTML formatting with headers, lists, and styling
|
||||
7. Make it readable and professional
|
||||
|
||||
Document Content:
|
||||
---START OF DOCUMENT CONTENT-----------------------------------------------
|
||||
{combinedContent}
|
||||
---END OF DOCUMENT CONTENT-----------------------------------------------
|
||||
Generate a complete HTML report that addresses the user's specific requirements and integrates all the information into a cohesive, professional document.
|
||||
"""
|
||||
|
||||
# Call AI to generate the report
|
||||
logger.info(f"Generating AI report for {len(validDocuments)} documents")
|
||||
aiReport = await self.service.callAiTextBasic(aiPrompt, combinedContent)
|
||||
|
||||
# If AI call fails, return error - AI is crucial for report generation
|
||||
if not aiReport or aiReport.strip() == "":
|
||||
logger.error("AI report generation failed - AI is crucial for this action")
|
||||
raise Exception("AI report generation failed - AI is required for report generation")
|
||||
|
||||
# Clean up the AI response and ensure it's valid HTML
|
||||
aiReport = aiReport.strip()
|
||||
|
||||
# Strip fenced code blocks like ```html ... ``` if present
|
||||
if aiReport.startswith("```") and aiReport.endswith("```"):
|
||||
lines = aiReport.split('\n')
|
||||
if len(lines) >= 2:
|
||||
# remove first and last fence lines (language tag allowed on first)
|
||||
aiReport = '\n'.join(lines[1:-1]).strip()
|
||||
|
||||
# Check if AI response starts with DOCTYPE or html tag (complete HTML document)
|
||||
if aiReport.startswith('<!DOCTYPE') or aiReport.startswith('<html'):
|
||||
# AI returned complete HTML document, use it directly
|
||||
return aiReport
|
||||
else:
|
||||
# AI returned HTML content without document structure, wrap it
|
||||
|
||||
# Check if AI response already contains a title/header
|
||||
has_title = any(keyword.lower() in aiReport.lower() for keyword in [title, "outlook", "report", "status"])
|
||||
|
||||
# Wrap the AI content in proper HTML structure
|
||||
html = ["<html><head><meta charset='utf-8'><title>" + title + "</title></head><body>"]
|
||||
|
||||
# Only add the title if the AI response doesn't already have one
|
||||
if not has_title:
|
||||
html.append(f"<h1>{title}</h1>")
|
||||
|
||||
html.append(f"<p><b>Generated:</b> {int(get_utc_timestamp())}</p>")
|
||||
html.append(f"<p><b>Total Documents Analyzed:</b> {len(validDocuments)}</p>")
|
||||
html.append("<hr>")
|
||||
html.append(aiReport)
|
||||
html.append("</body></html>")
|
||||
return '\n'.join(html)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating AI report: {str(e)}")
|
||||
# Re-raise the error - AI is crucial for report generation
|
||||
raise
|
||||
|
||||
File diff suppressed because it is too large
|
|
@ -1,284 +0,0 @@
|
|||
import logging
|
||||
import csv
|
||||
import io
|
||||
from typing import Any, Dict
|
||||
from modules.chat.methodBase import MethodBase, action
|
||||
from modules.interfaces.interfaceChatModel import ActionResult, ActionDocument
|
||||
from modules.interfaces.interfaceWebObjects import WebInterface
|
||||
from modules.interfaces.interfaceWebModel import (
|
||||
WebSearchRequest,
|
||||
WebCrawlRequest,
|
||||
WebScrapeRequest,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MethodWeb(MethodBase):
|
||||
"""Web method implementation for web operations."""
|
||||
|
||||
def __init__(self, serviceCenter: Any):
|
||||
super().__init__(serviceCenter)
|
||||
self.name = "web"
|
||||
self.description = "Web search, crawling, and scraping operations using Tavily"
|
||||
|
||||
@action
|
||||
async def search(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
"""Perform a web search and outputs a csv file with a list of found URLs
|
||||
|
||||
Each result contains columns "url" and "title".
|
||||
|
||||
Parameters:
|
||||
query (str): Search query to perform
|
||||
maxResults (int, optional): Maximum number of results (default: 10)
|
||||
"""
|
||||
|
||||
try:
|
||||
# Prepare request data
|
||||
web_search_request = WebSearchRequest(
|
||||
query=parameters.get("query"),
|
||||
max_results=parameters.get("maxResults", 10),
|
||||
)
|
||||
|
||||
# Perform request
|
||||
web_interface = await WebInterface.create()
|
||||
web_search_result = await web_interface.search(web_search_request)
|
||||
|
||||
# Convert search results to CSV format
|
||||
if web_search_result.success and web_search_result.documents:
|
||||
csv_content = web_interface.convert_web_search_result_to_csv(web_search_result)
|
||||
|
||||
# Create CSV document
|
||||
csv_document = web_interface.create_csv_action_document(
|
||||
csv_content,
|
||||
f"web_search_results.csv"
|
||||
)
|
||||
|
||||
return ActionResult(
|
||||
success=True,
|
||||
documents=[csv_document]
|
||||
)
|
||||
else:
|
||||
return web_search_result
|
||||
|
||||
except Exception as e:
|
||||
return ActionResult(success=False, error=str(e))
|
||||
|
||||
|
||||
|
||||
def _read_csv_with_urls(self, csv_content: str) -> list:
|
||||
"""Read CSV content and extract URLs from url,title or title,url format (both ; and , delimiters)"""
|
||||
urls = []
|
||||
|
||||
# Try both semicolon and comma delimiters
|
||||
for delimiter in [';', ',']:
|
||||
try:
|
||||
reader = csv.DictReader(io.StringIO(csv_content), delimiter=delimiter)
|
||||
for row in reader:
|
||||
# Look for url column (case insensitive)
|
||||
url = None
|
||||
for key in row.keys():
|
||||
if key.lower() == 'url':
|
||||
url = row[key].strip()
|
||||
break
|
||||
|
||||
if url and (url.startswith('http://') or url.startswith('https://')):
|
||||
urls.append(url)
|
||||
|
||||
# If we found URLs with this delimiter, return them
|
||||
if urls:
|
||||
return urls
|
||||
|
||||
except Exception:
|
||||
# Try next delimiter
|
||||
continue
|
||||
|
||||
# If no valid CSV found, try simple text parsing as fallback
|
||||
lines = csv_content.split('\n')
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line and (line.startswith('http://') or line.startswith('https://')):
|
||||
urls.append(line)
|
||||
|
||||
return urls
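# Example of what `_read_csv_with_urls` extracts (invented values):
#
#     csv_text = "url;title\nhttps://example.org/a;First hit\nhttps://example.org/b;Second hit"
#     self._read_csv_with_urls(csv_text)
#     # -> ["https://example.org/a", "https://example.org/b"]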
|
||||
|
||||
@action
|
||||
async def crawl(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
"""Crawls a list of URLs and extracts information from them.
|
||||
|
||||
Parameters:
|
||||
documentList (str): Document list reference containing URL lists from search results
|
||||
expectedDocumentFormats (list, optional): Expected document formats with extension, mimeType, description
|
||||
"""
|
||||
try:
|
||||
document_list = parameters.get("documentList")
|
||||
|
||||
if not document_list:
|
||||
return ActionResult(
|
||||
success=False, error="No document list reference provided."
|
||||
)
|
||||
|
||||
# Resolve document list reference to ChatDocument objects
|
||||
chat_documents = self.service.getChatDocumentsFromDocumentList(document_list)
|
||||
|
||||
if not chat_documents:
|
||||
return ActionResult(
|
||||
success=False,
|
||||
error=f"No documents found for reference: {document_list}",
|
||||
)
|
||||
|
||||
# Extract URLs from all documents and combine them
|
||||
all_urls = []
|
||||
import json
|
||||
import re
|
||||
|
||||
for i, doc in enumerate(chat_documents):
|
||||
logger.info(f"Processing document {i+1}/{len(chat_documents)}: {doc.fileName}")
|
||||
|
||||
# Get file data using the service center
|
||||
file_data = self.service.getFileData(doc.fileId)
|
||||
if not file_data:
|
||||
logger.warning(f"Could not retrieve file data for document: {doc.fileName}")
|
||||
continue
|
||||
|
||||
content = file_data.decode("utf-8")
|
||||
|
||||
# Try to parse as CSV first (for new CSV format)
|
||||
if doc.fileName.lower().endswith('.csv') or 'csv' in doc.mimeType.lower():
|
||||
logger.info(f"Processing CSV file: {doc.fileName}")
|
||||
doc_urls = self._read_csv_with_urls(content)
|
||||
else:
|
||||
# Parse JSON to extract URLs from search results
|
||||
try:
|
||||
# The document structure from WebSearchActionResult
|
||||
search_data = json.loads(content)
|
||||
|
||||
# Extract URLs from the search results structure
|
||||
doc_urls = []
|
||||
if isinstance(search_data, dict):
|
||||
# Handle the document structure: documentData contains the actual search results
|
||||
doc_data = search_data.get("documentData", search_data)
|
||||
if "results" in doc_data and isinstance(doc_data["results"], list):
|
||||
doc_urls = [
|
||||
result["url"]
|
||||
for result in doc_data["results"]
|
||||
if isinstance(result, dict) and "url" in result
|
||||
]
|
||||
elif "urls" in doc_data and isinstance(doc_data["urls"], list):
|
||||
# Fallback: if URLs are stored directly in a 'urls' field
|
||||
doc_urls = [url for url in doc_data["urls"] if isinstance(url, str)]
|
||||
|
||||
# Fallback: try to parse as plain text with regex (for backward compatibility)
|
||||
if not doc_urls:
|
||||
logger.warning(
|
||||
f"Could not extract URLs from JSON structure in {doc.fileName}, trying plain text parsing"
|
||||
)
|
||||
doc_urls = re.split(r"[\n,;]+", content)
|
||||
doc_urls = [
|
||||
u.strip()
|
||||
for u in doc_urls
|
||||
if u.strip()
|
||||
and (
|
||||
u.strip().startswith("http://")
|
||||
or u.strip().startswith("https://")
|
||||
)
|
||||
]
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# Fallback to plain text parsing if JSON parsing fails
|
||||
logger.warning(f"Document {doc.fileName} is not valid JSON, trying plain text parsing")
|
||||
doc_urls = re.split(r"[\n,;]+", content)
|
||||
doc_urls = [
|
||||
u.strip()
|
||||
for u in doc_urls
|
||||
if u.strip()
|
||||
and (
|
||||
u.strip().startswith("http://")
|
||||
or u.strip().startswith("https://")
|
||||
)
|
||||
]
|
||||
|
||||
if doc_urls:
|
||||
all_urls.extend(doc_urls)
|
||||
logger.info(f"Extracted {len(doc_urls)} URLs from {doc.fileName}")
|
||||
else:
|
||||
logger.warning(f"No valid URLs found in document: {doc.fileName}")
|
||||
|
||||
if not all_urls:
|
||||
return ActionResult(
|
||||
success=False, error="No valid URLs found in any of the documents."
|
||||
)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
unique_urls = list(dict.fromkeys(all_urls))
|
||||
logger.info(f"Extracted {len(unique_urls)} unique URLs from {len(chat_documents)} documents")
|
||||
|
||||
# Prepare request data
|
||||
web_crawl_request = WebCrawlRequest(urls=unique_urls)
|
||||
|
||||
# Perform request
|
||||
web_interface = await WebInterface.create()
|
||||
web_crawl_result = await web_interface.crawl(web_crawl_request)
|
||||
|
||||
# Convert to proper JSON format
|
||||
if web_crawl_result.success:
|
||||
json_content = web_interface.convert_web_result_to_json(web_crawl_result)
|
||||
json_document = web_interface.create_json_action_document(
|
||||
json_content,
|
||||
f"web_crawl_results.json"
|
||||
)
|
||||
return ActionResult(
|
||||
success=True,
|
||||
documents=[json_document]
|
||||
)
|
||||
else:
|
||||
return web_crawl_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in crawl method: {str(e)}")
|
||||
return ActionResult(success=False, error=str(e))
|
||||
|
||||
@action
|
||||
async def scrape(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
"""Scrapes web content by searching for URLs and then extracting their content.
|
||||
|
||||
Combines search and crawl operations in one step.
|
||||
|
||||
Parameters:
|
||||
query (str): Search query to perform
|
||||
maxResults (int, optional): Maximum number of results (default: 10)
|
||||
"""
|
||||
try:
|
||||
query = parameters.get("query")
|
||||
max_results = parameters.get("maxResults", 10)
|
||||
|
||||
if not query:
|
||||
return ActionResult(success=False, error="Search query is required")
|
||||
|
||||
# Prepare request data
|
||||
web_scrape_request = WebScrapeRequest(
|
||||
query=query,
|
||||
max_results=max_results,
|
||||
)
|
||||
|
||||
# Perform request
|
||||
web_interface = await WebInterface.create()
|
||||
web_scrape_result = await web_interface.scrape(web_scrape_request)
|
||||
|
||||
# Convert to proper JSON format
|
||||
if web_scrape_result.success:
|
||||
json_content = web_interface.convert_web_result_to_json(web_scrape_result)
|
||||
json_document = web_interface.create_json_action_document(
|
||||
json_content,
|
||||
f"web_scrape_results.json"
|
||||
)
|
||||
return ActionResult(
|
||||
success=True,
|
||||
documents=[json_document]
|
||||
)
|
||||
else:
|
||||
return web_scrape_result
|
||||
|
||||
except Exception as e:
|
||||
return ActionResult(success=False, error=str(e))
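# Example parameter payload for the `scrape` action, for illustration only
# (the query is invented; maxResults assumes the configured search limits allow it):
_EXAMPLE_SCRAPE_PARAMETERS = {
    "query": "recent developments in retrieval augmented generation",
    "maxResults": 5,
}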
|
||||
|
|
@ -1,484 +0,0 @@
|
|||
"""
|
||||
GDPR-compliant data neutralizer for AI agent systems
Supports TXT, JSON, CSV, Excel, and Word files
Multilingual: DE, EN, FR, IT
|
||||
"""
|
||||
|
||||
import re
|
||||
import json
|
||||
import pandas as pd
|
||||
import docx
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any, Union, Optional
|
||||
from dataclasses import dataclass
|
||||
import uuid
|
||||
import logging
|
||||
import traceback
|
||||
import csv
|
||||
from datetime import datetime
|
||||
import xml.etree.ElementTree as ET
|
||||
import os
|
||||
import random
|
||||
from io import StringIO
|
||||
from modules.neutralizer.patterns import Pattern, HeaderPatterns, DataPatterns, get_pattern_for_header, find_patterns_in_text, TextTablePatterns
|
||||
import base64
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class TableData:
|
||||
"""Repräsentiert Tabellendaten"""
|
||||
headers: List[str]
|
||||
rows: List[List[str]]
|
||||
source_type: str # 'csv', 'json', 'xml', 'text_table'
|
||||
|
||||
@dataclass
|
||||
class PlainText:
|
||||
"""Repräsentiert normalen Text"""
|
||||
content: str
|
||||
source_type: str # 'txt', 'docx', 'text_plain'
|
||||
|
||||
@dataclass
|
||||
class ProcessResult:
|
||||
"""Result of content processing"""
|
||||
data: Any
|
||||
mapping: Dict[str, str]
|
||||
replaced_fields: List[str]
|
||||
processed_info: Dict[str, Any] # Additional processing information
|
||||
|
||||
class DataAnonymizer:
|
||||
"""Hauptklasse für die Datenanonymisierung"""
|
||||
|
||||
def __init__(self, names_to_parse: List[str] = None):
|
||||
"""Initialize the anonymizer with patterns and custom names
|
||||
|
||||
Args:
|
||||
names_to_parse: List of names to parse and replace (case-insensitive)
|
||||
"""
|
||||
self.header_patterns = HeaderPatterns.patterns
|
||||
self.data_patterns = DataPatterns.patterns
|
||||
self.names_to_parse = names_to_parse or []
|
||||
self.replaced_fields = set()
|
||||
self.mapping = {}
|
||||
self.processing_info = []
|
||||
|
||||
def _normalize_whitespace(self, text: str) -> str:
|
||||
"""Normalize whitespace in text"""
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _is_table_line(self, line: str) -> bool:
|
||||
"""Check if a line represents a table row"""
|
||||
return bool(re.match(r'^\s*[^:]+:\s*[^:]+$', line) or
|
||||
re.match(r'^\s*[^\t]+\t[^\t]+$', line))
|
||||
|
||||
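To make the two patterns above concrete, a self-contained sketch of the same check; the regexes are copied verbatim from _is_table_line, the sample lines are invented:

import re

def is_table_line(line: str) -> bool:
    # a "key: value" pair or a single tab-separated pair counts as a table row
    return bool(re.match(r'^\s*[^:]+:\s*[^:]+$', line) or
                re.match(r'^\s*[^\t]+\t[^\t]+$', line))

print(is_table_line("Name: Alice Example"))     # True  - colon-separated pair
print(is_table_line("Alice Example\t12345"))    # True  - tab-separated pair
print(is_table_line("Just a normal sentence"))  # False - no recognised delimiter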
def _extract_tables_from_text(self, content: str) -> Tuple[List[TableData], List[PlainText]]:
|
||||
"""
|
||||
Extract tables and plain text from content
|
||||
|
||||
Args:
|
||||
content: Content to process
|
||||
|
||||
Returns:
|
||||
Tuple of (list of tables, list of plain text sections)
|
||||
"""
|
||||
tables = []
|
||||
plain_texts = []
|
||||
|
||||
# Process the entire content as plain text
|
||||
plain_texts.append(PlainText(content=content, source_type='text_plain'))
|
||||
|
||||
return tables, plain_texts
|
||||
|
||||
def _anonymize_table(self, table: TableData) -> TableData:
|
||||
"""Anonymize table data"""
|
||||
try:
|
||||
anonymized_table = TableData(
|
||||
headers=table.headers.copy(),
|
||||
rows=[row.copy() for row in table.rows],
|
||||
source_type=table.source_type
|
||||
)
|
||||
|
||||
for i, header in enumerate(anonymized_table.headers):
|
||||
pattern = get_pattern_for_header(header, self.header_patterns)
|
||||
if pattern:
|
||||
for row in anonymized_table.rows:
|
||||
if row[i] is not None:
|
||||
original = str(row[i])
|
||||
if original not in self.mapping:
|
||||
self.mapping[original] = pattern.replacement_template.format(len(self.mapping) + 1)
|
||||
row[i] = self.mapping[original]
|
||||
|
||||
return anonymized_table
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error anonymizing table: {str(e)}")
|
||||
raise
|
||||
|
||||
def _anonymize_plain_text(self, text: PlainText) -> PlainText:
|
||||
"""Anonymize plain text content using simple search-and-replace approach"""
|
||||
try:
|
||||
current_text = text.content
|
||||
|
||||
# Step 1: Replace custom names first (simple regex search-and-replace)
|
||||
for name in self.names_to_parse:
|
||||
if not name.strip():
|
||||
continue
|
||||
|
||||
# Create case-insensitive regex pattern with word boundaries
|
||||
pattern = re.compile(r'\b' + re.escape(name.strip()) + r'\b', re.IGNORECASE)
|
||||
|
||||
# Find all matches for this name
|
||||
matches = list(pattern.finditer(current_text))
|
||||
|
||||
# Replace each match with a placeholder
|
||||
for match in reversed(matches): # Process from right to left to avoid position shifts
|
||||
matched_text = match.group()
|
||||
if matched_text not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
self.mapping[matched_text] = f"[name.{placeholder_id}]"
|
||||
|
||||
replacement = self.mapping[matched_text]
|
||||
start, end = match.span()
|
||||
current_text = current_text[:start] + replacement + current_text[end:]
|
||||
|
||||
# Step 2: Replace pattern-based matches (emails, phones, etc.)
|
||||
# Use the same simple approach for patterns
|
||||
pattern_matches = find_patterns_in_text(current_text, self.data_patterns)
|
||||
|
||||
# Process pattern matches from right to left to avoid position shifts
|
||||
for pattern_name, matched_text, start, end in reversed(pattern_matches):
|
||||
# Skip if already a placeholder
|
||||
if re.match(r'\[[a-z]+\.[a-f0-9-]+\]', matched_text):
|
||||
continue
|
||||
|
||||
# Skip if contains placeholder characters
|
||||
if '[' in matched_text or ']' in matched_text:
|
||||
continue
|
||||
|
||||
if matched_text not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
# Create placeholder in format [type.uuid]
|
||||
type_mapping = {
|
||||
'email': 'email',
|
||||
'phone': 'phone',
|
||||
'address': 'address',
|
||||
'id': 'id'
|
||||
}
|
||||
placeholder_type = type_mapping.get(pattern_name, 'data')
|
||||
self.mapping[matched_text] = f"[{placeholder_type}.{placeholder_id}]"
|
||||
|
||||
replacement = self.mapping[matched_text]
|
||||
current_text = current_text[:start] + replacement + current_text[end:]
|
||||
|
||||
return PlainText(content=current_text, source_type=text.source_type)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error anonymizing plain text: {str(e)}")
|
||||
raise
|
||||
|
||||
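A standalone sketch of the replacement scheme used above: matched custom names are swapped for [name.<uuid>] placeholders and the original-to-placeholder pairs are kept in a mapping. The helper below is illustrative and not part of the module; it mirrors the right-to-left replacement and the exact-string keying of self.mapping.

import re
import uuid

def anonymize_names(text, names):
    mapping = {}
    for name in names:
        if not name.strip():
            continue
        pattern = re.compile(r'\b' + re.escape(name.strip()) + r'\b', re.IGNORECASE)
        # replace right-to-left so earlier match positions stay valid
        for match in reversed(list(pattern.finditer(text))):
            matched = match.group()
            if matched not in mapping:
                mapping[matched] = f"[name.{uuid.uuid4()}]"
            start, end = match.span()
            text = text[:start] + mapping[matched] + text[end:]
    return text, mapping

anonymized, mapping = anonymize_names("Contact Alice or alice for details.", ["Alice"])
print(anonymized)  # both spellings replaced; each casing gets its own placeholder
print(mapping)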
def _anonymize_json_value(self, value: Any, key: str = None) -> Any:
|
||||
"""
|
||||
Recursively anonymize JSON values based on their keys and content
|
||||
|
||||
Args:
|
||||
value: Value to anonymize
|
||||
key: Key name (if part of a key-value pair)
|
||||
|
||||
Returns:
|
||||
Anonymized value
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return {k: self._anonymize_json_value(v, k) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [self._anonymize_json_value(item) for item in value]
|
||||
elif isinstance(value, str):
|
||||
# Check if this is a key we should process
|
||||
if key:
|
||||
pattern = get_pattern_for_header(key, self.header_patterns)
|
||||
if pattern:
|
||||
if value not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
# Create placeholder in format [type.uuid]
|
||||
type_mapping = {
|
||||
'email': 'email',
|
||||
'phone': 'phone',
|
||||
'name': 'name',
|
||||
'address': 'address',
|
||||
'id': 'id'
|
||||
}
|
||||
placeholder_type = type_mapping.get(pattern.name, 'data')
|
||||
self.mapping[value] = f"[{placeholder_type}.{placeholder_id}]"
|
||||
return self.mapping[value]
|
||||
|
||||
# Check if the value itself matches any patterns
|
||||
pattern_matches = find_patterns_in_text(value, self.data_patterns)
|
||||
custom_name_matches = self._find_custom_names(value)
|
||||
|
||||
if pattern_matches or custom_name_matches:
|
||||
# Use the first match's pattern or custom name
|
||||
if pattern_matches:
|
||||
pattern_name = pattern_matches[0][0]
|
||||
if value not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
# Create placeholder in format [type.uuid]
|
||||
type_mapping = {
|
||||
'email': 'email',
|
||||
'phone': 'phone',
|
||||
'name': 'name',
|
||||
'address': 'address',
|
||||
'id': 'id'
|
||||
}
|
||||
placeholder_type = type_mapping.get(pattern_name, 'data')
|
||||
self.mapping[value] = f"[{placeholder_type}.{placeholder_id}]"
|
||||
elif custom_name_matches:
|
||||
if value not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
self.mapping[value] = f"[name.{placeholder_id}]"
|
||||
return self.mapping[value]
|
||||
|
||||
return value
|
||||
else:
|
||||
return value
|
||||
|
||||
def _anonymize_xml_element(self, element: ET.Element, indent: str = '') -> str:
|
||||
"""
|
||||
Recursively process XML element and return formatted string
|
||||
|
||||
Args:
|
||||
element: XML element to process
|
||||
indent: Current indentation level
|
||||
|
||||
Returns:
|
||||
Formatted XML string
|
||||
"""
|
||||
# Process attributes
|
||||
processed_attrs = {}
|
||||
for attr_name, attr_value in element.attrib.items():
|
||||
# Check if attribute name matches any header patterns
|
||||
pattern = get_pattern_for_header(attr_name, self.header_patterns)
|
||||
if pattern:
|
||||
if attr_value not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
# Create placeholder in format [type.uuid]
|
||||
type_mapping = {
|
||||
'email': 'email',
|
||||
'phone': 'phone',
|
||||
'name': 'name',
|
||||
'address': 'address',
|
||||
'id': 'id'
|
||||
}
|
||||
placeholder_type = type_mapping.get(pattern.name, 'data')
|
||||
self.mapping[attr_value] = f"[{placeholder_type}.{placeholder_id}]"
|
||||
processed_attrs[attr_name] = self.mapping[attr_value]
|
||||
else:
|
||||
# Check if attribute value matches any data patterns
|
||||
matches = find_patterns_in_text(attr_value, self.data_patterns)
|
||||
if matches:
|
||||
pattern_name = matches[0][0]
|
||||
pattern = next((p for p in self.data_patterns if p.name == pattern_name), None)
|
||||
if pattern:
|
||||
if attr_value not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
# Create placeholder in format [type.uuid]
|
||||
type_mapping = {
|
||||
'email': 'email',
|
||||
'phone': 'phone',
|
||||
'name': 'name',
|
||||
'address': 'address',
|
||||
'id': 'id'
|
||||
}
|
||||
placeholder_type = type_mapping.get(pattern_name, 'data')
|
||||
self.mapping[attr_value] = f"[{placeholder_type}.{placeholder_id}]"
|
||||
processed_attrs[attr_name] = self.mapping[attr_value]
|
||||
else:
|
||||
processed_attrs[attr_name] = attr_value
|
||||
else:
|
||||
processed_attrs[attr_name] = attr_value
|
||||
|
||||
attrs = ' '.join(f'{k}="{v}"' for k, v in processed_attrs.items())
|
||||
attrs = f' {attrs}' if attrs else ''
|
||||
|
||||
# Process text content
|
||||
text = element.text.strip() if element.text and element.text.strip() else ''
|
||||
if text:
|
||||
# Check if text matches any patterns or custom names
|
||||
pattern_matches = find_patterns_in_text(text, self.data_patterns)
|
||||
custom_name_matches = self._find_custom_names(text)
|
||||
|
||||
if pattern_matches or custom_name_matches:
|
||||
if pattern_matches:
|
||||
pattern_name = pattern_matches[0][0]
|
||||
pattern = next((p for p in self.data_patterns if p.name == pattern_name), None)
|
||||
if pattern:
|
||||
if text not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
# Create placeholder in format [type.uuid]
|
||||
type_mapping = {
|
||||
'email': 'email',
|
||||
'phone': 'phone',
|
||||
'name': 'name',
|
||||
'address': 'address',
|
||||
'id': 'id'
|
||||
}
|
||||
placeholder_type = type_mapping.get(pattern_name, 'data')
|
||||
self.mapping[text] = f"[{placeholder_type}.{placeholder_id}]"
|
||||
text = self.mapping[text]
|
||||
elif custom_name_matches:
|
||||
if text not in self.mapping:
|
||||
# Generate a UUID for the placeholder
|
||||
import uuid
|
||||
placeholder_id = str(uuid.uuid4())
|
||||
self.mapping[text] = f"[name.{placeholder_id}]"
|
||||
text = self.mapping[text]
|
||||
|
||||
# Process child elements
|
||||
children = []
|
||||
for child in element:
|
||||
child_str = self._anonymize_xml_element(child, indent + ' ')
|
||||
children.append(child_str)
|
||||
|
||||
# Build element string
|
||||
if not children and not text:
|
||||
return f"{indent}<{element.tag}{attrs}/>"
|
||||
elif not children:
|
||||
return f"{indent}<{element.tag}{attrs}>{text}</{element.tag}>"
|
||||
else:
|
||||
result = [f"{indent}<{element.tag}{attrs}>"]
|
||||
if text:
|
||||
result.append(f"{indent} {text}")
|
||||
result.extend(children)
|
||||
result.append(f"{indent}</{element.tag}>")
|
||||
return '\n'.join(result)
|
||||
|
||||
def process_content(self, content: str, content_type: str) -> ProcessResult:
|
||||
"""
|
||||
Process content and return anonymized data
|
||||
|
||||
Args:
|
||||
content: Content to process
|
||||
content_type: Type of content ('csv', 'json', 'xml', 'text')
|
||||
|
||||
Returns:
|
||||
ProcessResult: Contains anonymized data, mapping, replaced fields and processing info
|
||||
"""
|
||||
try:
|
||||
|
||||
# Check if content is binary data
|
||||
is_binary = False
|
||||
try:
|
||||
# First, check if content looks like base64 (contains only base64 characters)
|
||||
if re.match(r'^[A-Za-z0-9+/]*={0,2}$', content.strip()):
|
||||
# Try to decode base64 if it looks like base64
|
||||
try:
|
||||
decoded = base64.b64decode(content)
|
||||
# If it's not valid text, consider it binary
|
||||
decoded.decode('utf-8')
|
||||
is_binary = True
|
||||
except (base64.binascii.Error, UnicodeDecodeError):
|
||||
is_binary = False
|
||||
else:
|
||||
is_binary = False
|
||||
except Exception as e:
|
||||
is_binary = False
|
||||
|
||||
if is_binary:
|
||||
# TODO: Implement binary data neutralization
|
||||
# This would require:
|
||||
# 1. Detecting binary data types (images, audio, video, etc.)
|
||||
# 2. Implementing specific neutralization for each type
|
||||
# 3. Handling metadata and embedded content
|
||||
# 4. Preserving binary integrity while removing sensitive data
|
||||
return ProcessResult(content, self.mapping, [], {'type': 'binary', 'status': 'not_implemented'})
|
||||
|
||||
replaced_fields = []
|
||||
processed_info = {}
|
||||
|
||||
if content_type in ['csv', 'json', 'xml']:
|
||||
# Handle as table
|
||||
if content_type == 'csv':
|
||||
df = pd.read_csv(StringIO(content), encoding='utf-8')
|
||||
table = TableData(
|
||||
headers=df.columns.tolist(),
|
||||
rows=df.values.tolist(),
|
||||
source_type='csv'
|
||||
)
|
||||
processed_info['type'] = 'table'
|
||||
processed_info['headers'] = table.headers
|
||||
processed_info['row_count'] = len(table.rows)
|
||||
elif content_type == 'json':
|
||||
data = json.loads(content)
|
||||
# Process JSON recursively
|
||||
result = self._anonymize_json_value(data)
|
||||
processed_info['type'] = 'json'
|
||||
return ProcessResult(result, self.mapping, replaced_fields, processed_info)
|
||||
else: # xml
|
||||
root = ET.fromstring(content)
|
||||
# Process XML recursively with proper formatting
|
||||
result = self._anonymize_xml_element(root)
|
||||
processed_info['type'] = 'xml'
|
||||
return ProcessResult(result, self.mapping, replaced_fields, processed_info)
|
||||
|
||||
if not table.rows:
|
||||
return ProcessResult(None, self.mapping, [], processed_info)
|
||||
|
||||
anonymized_table = self._anonymize_table(table)
|
||||
|
||||
# Track replaced fields
|
||||
for i, header in enumerate(anonymized_table.headers):
|
||||
for orig_row, anon_row in zip(table.rows, anonymized_table.rows):
|
||||
if anon_row[i] != orig_row[i]:
|
||||
replaced_fields.append(header)
|
||||
|
||||
# Convert back to original format
|
||||
if content_type == 'csv':
|
||||
result = pd.DataFrame(anonymized_table.rows, columns=anonymized_table.headers)
|
||||
elif content_type == 'json':
|
||||
if len(anonymized_table.headers) == 1 and anonymized_table.headers[0] == 'value':
|
||||
result = anonymized_table.rows[0][0]
|
||||
else:
|
||||
result = dict(zip(anonymized_table.headers, anonymized_table.rows[0]))
|
||||
else: # xml
|
||||
result = ET.tostring(root, encoding='unicode')
|
||||
|
||||
return ProcessResult(result, self.mapping, replaced_fields, processed_info)
|
||||
else:
|
||||
# Handle as text
|
||||
# First, identify what needs to be replaced using table detection
|
||||
tables, plain_texts = self._extract_tables_from_text(content)
|
||||
processed_info['type'] = 'text'
|
||||
processed_info['tables'] = [{'headers': t.headers, 'row_count': len(t.rows)} for t in tables]
|
||||
|
||||
# Process plain text sections
|
||||
anonymized_texts = [self._anonymize_plain_text(text) for text in plain_texts]
|
||||
|
||||
# Combine all processed content
|
||||
result = content
|
||||
for i, (text, anonymized_text) in enumerate(zip(plain_texts, anonymized_texts)):
|
||||
if text.content != anonymized_text.content:
|
||||
result = result.replace(text.content, anonymized_text.content)
|
||||
|
||||
return ProcessResult(result, self.mapping, replaced_fields, processed_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing content: {str(e)}")
|
||||
return ProcessResult(None, self.mapping, [], {'type': 'error', 'error': str(e)})
|
||||
|
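For reference, a minimal usage sketch of the removed class. It assumes modules.neutralizer.patterns is importable and that the CSV headers match some entry in HeaderPatterns; the sample content is invented.

# Illustrative only - DataAnonymizer comes from the removed module above.
anonymizer = DataAnonymizer(names_to_parse=["Max Mustermann"])
result = anonymizer.process_content(
    "name,email\nMax Mustermann,max@example.com",
    content_type="csv",
)
print(result.data)             # pandas DataFrame with placeholders where headers matched
print(result.mapping)          # original value -> placeholder
print(result.replaced_fields)  # headers whose values were replaced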
|
@ -1,4 +1,4 @@
|
|||
from fastapi import APIRouter, Response, Depends, Request
|
||||
from fastapi import APIRouter, Response, Depends, Request, Body
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import os
|
||||
|
|
@ -9,13 +9,8 @@ from fastapi import HTTPException, status
|
|||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
from modules.interfaces.interfaceAppModel import User
|
||||
|
||||
router = APIRouter(
|
||||
prefix="",
|
||||
tags=["General"],
|
||||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.interfaces.interfaceDbAppObjects import getRootInterface
|
||||
|
||||
# Static folder setup - using absolute path from app root
|
||||
baseDir = FilePath(__file__).parent.parent.parent # Go up to gateway root
|
||||
|
|
@ -25,41 +20,74 @@ os.makedirs(staticFolder, exist_ok=True)
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="",
|
||||
tags=["Administration"],
|
||||
responses={404: {"description": "Not found"}}
|
||||
prefix="", tags=["Administration"], responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
# Mount static files
|
||||
router.mount("/static", StaticFiles(directory=str(staticFolder), html=True), name="static")
|
||||
router.mount(
|
||||
"/static", StaticFiles(directory=str(staticFolder), html=True), name="static"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@limiter.limit("30/minute")
|
||||
async def root(request: Request) -> Dict[str, str]:
|
||||
"""API status endpoint"""
|
||||
# Validate required configuration values
|
||||
allowedOrigins = APP_CONFIG.get("APP_ALLOWED_ORIGINS")
|
||||
if not allowedOrigins:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_ALLOWED_ORIGINS configuration is required"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "online",
|
||||
"message": "Data Platform API is active",
|
||||
"allowedOrigins": f"Allowed origins are {APP_CONFIG.get('APP_ALLOWED_ORIGINS')}"
|
||||
"allowedOrigins": f"Allowed origins are {allowedOrigins}",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/environment")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_environment(request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, str]:
|
||||
async def get_environment(
|
||||
request: Request, currentUser: Dict[str, Any] = Depends(getCurrentUser)
|
||||
) -> Dict[str, str]:
|
||||
"""Get environment configuration for frontend"""
|
||||
# Validate required configuration values
|
||||
apiBaseUrl = APP_CONFIG.get("APP_API_URL")
|
||||
if not apiBaseUrl:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_API_URL configuration is required"
|
||||
)
|
||||
|
||||
environment = APP_CONFIG.get("APP_ENV")
|
||||
if not environment:
|
||||
raise HTTPException(status_code=500, detail="APP_ENV configuration is required")
|
||||
|
||||
instanceLabel = APP_CONFIG.get("APP_ENV_LABEL")
|
||||
if not instanceLabel:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="APP_ENV_LABEL configuration is required"
|
||||
)
|
||||
|
||||
return {
|
||||
"apiBaseUrl": APP_CONFIG.get("APP_API_URL", ""),
|
||||
"environment": APP_CONFIG.get("APP_ENV", "development"),
|
||||
"instanceLabel": APP_CONFIG.get("APP_ENV_LABEL", "Development"),
|
||||
"apiBaseUrl": apiBaseUrl,
|
||||
"environment": environment,
|
||||
"instanceLabel": instanceLabel,
|
||||
# Add other environment variables the frontend might need
|
||||
}
|
||||
|
||||
|
||||
@router.options("/{fullPath:path}")
|
||||
@limiter.limit("60/minute")
|
||||
async def options_route(request: Request, fullPath: str) -> Response:
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/favicon.ico")
|
||||
@limiter.limit("30/minute")
|
||||
async def favicon(request: Request) -> FileResponse:
|
||||
return FileResponse(str(staticFolder / "favicon.ico"), media_type="image/x-icon")
|
||||
favicon_path = staticFolder / "favicon.ico"
|
||||
if not favicon_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Favicon not found")
|
||||
return FileResponse(str(favicon_path), media_type="image/x-icon")
|
||||
|
|
|
|||
150
modules/routes/routeAdminAutomationEvents.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
"""
|
||||
Admin automation events routes for the backend API.
|
||||
Sysadmin-only endpoints for viewing and controlling automation events.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Path, Request, Response
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import status
|
||||
import logging
|
||||
|
||||
# Import interfaces and models
|
||||
import modules.interfaces.interfaceDbChatObjects as interfaceDbChatObjects
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
from modules.datamodels.datamodelUam import User, UserPrivilege
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router for admin automation events endpoints
|
||||
router = APIRouter(
|
||||
prefix="/api/admin/automation-events",
|
||||
tags=["Admin Automation Events"],
|
||||
responses={
|
||||
404: {"description": "Not found"},
|
||||
400: {"description": "Bad request"},
|
||||
401: {"description": "Unauthorized"},
|
||||
403: {"description": "Forbidden - Sysadmin only"},
|
||||
500: {"description": "Internal server error"}
|
||||
}
|
||||
)
|
||||
|
||||
def requireSysadmin(currentUser: User):
|
||||
"""Require sysadmin privilege"""
|
||||
if currentUser.privilege != UserPrivilege.SYSADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Sysadmin privilege required"
|
||||
)
|
||||
|
||||
@router.get("")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_all_automation_events(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all automation events across all mandates (sysadmin only).
|
||||
Returns list of all registered events with their automation IDs and schedules.
|
||||
"""
|
||||
requireSysadmin(currentUser)
|
||||
|
||||
try:
|
||||
from modules.shared.eventManagement import eventManager
|
||||
|
||||
# Get all jobs from scheduler
|
||||
jobs = []
|
||||
if eventManager.scheduler:
|
||||
for job in eventManager.scheduler.get_jobs():
|
||||
if job.id.startswith("automation."):
|
||||
automation_id = job.id.replace("automation.", "")
|
||||
jobs.append({
|
||||
"eventId": job.id,
|
||||
"automationId": automation_id,
|
||||
"nextRunTime": str(job.next_run_time) if job.next_run_time else None,
|
||||
"trigger": str(job.trigger) if job.trigger else None
|
||||
})
|
||||
|
||||
return jobs
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automation events: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting automation events: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/sync")
|
||||
@limiter.limit("5/minute")
|
||||
async def sync_all_automation_events(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Manually trigger sync for all automations (sysadmin only).
|
||||
This will register/remove events based on active flags.
|
||||
"""
|
||||
requireSysadmin(currentUser)
|
||||
|
||||
try:
|
||||
chatInterface = interfaceDbChatObjects.getInterface(currentUser)
|
||||
|
||||
if not hasattr(chatInterface, 'syncAutomationEvents'):
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="Automation methods not available"
|
||||
)
|
||||
|
||||
result = await chatInterface.syncAutomationEvents()
|
||||
return {
|
||||
"success": True,
|
||||
"synced": result.get("synced", 0),
|
||||
"events": result.get("events", {})
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error syncing automation events: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error syncing automation events: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/{eventId}/remove")
|
||||
@limiter.limit("10/minute")
|
||||
async def remove_event(
|
||||
request: Request,
|
||||
eventId: str = Path(..., description="Event ID to remove"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Manually remove a specific event from scheduler (sysadmin only).
|
||||
Used for debugging and manual event cleanup.
|
||||
"""
|
||||
requireSysadmin(currentUser)
|
||||
|
||||
try:
|
||||
from modules.shared.eventManagement import eventManager
|
||||
|
||||
# Remove event
|
||||
eventManager.remove(eventId)
|
||||
|
||||
# Update automation's eventId if it exists
|
||||
if eventId.startswith("automation."):
|
||||
automation_id = eventId.replace("automation.", "")
|
||||
chatInterface = interfaceDbChatObjects.getInterface(currentUser)
|
||||
automation = chatInterface.getAutomationDefinition(automation_id)
|
||||
if automation and automation.get("eventId") == eventId:
|
||||
chatInterface.updateAutomationDefinition(automation_id, {"eventId": None})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"eventId": eventId,
|
||||
"message": f"Event {eventId} removed successfully"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing event: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error removing event: {str(e)}"
|
||||
)
|
||||
|
||||
|
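A hedged client-side sketch of these sysadmin endpoints; the base URL and bearer token are placeholders, and the response fields follow the dicts built in the handlers above.

import httpx

BASE_URL = "http://localhost:8000"                      # placeholder
HEADERS = {"Authorization": "Bearer <sysadmin-token>"}  # placeholder

with httpx.Client(base_url=BASE_URL, headers=HEADERS) as client:
    # List all registered automation events
    for event in client.get("/api/admin/automation-events").json():
        print(event["eventId"], event["nextRunTime"], event["trigger"])

    # Re-sync scheduler jobs with the stored automation definitions
    sync_result = client.post("/api/admin/automation-events/sync").json()
    print(sync_result["synced"], sync_result["events"])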
|
@ -1,17 +1,11 @@
|
|||
from fastapi import APIRouter, HTTPException, Depends, Path, Response, Request
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Path, Response, Request
|
||||
from fastapi import status
|
||||
import inspect
|
||||
import importlib
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
# Import auth module
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
from modules.security.auth import limiter
|
||||
|
||||
# Import the attribute definition and helper functions
|
||||
from modules.interfaces.interfaceAppModel import User
|
||||
from modules.shared.attributeUtils import getModelClasses, getModelAttributeDefinitions, AttributeResponse, AttributeDefinition
|
||||
|
||||
# Configure logger
|
||||
|
|
|
|||
125
modules/routes/routeChatPlayground.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""
|
||||
Chat Playground routes for the backend API.
|
||||
Implements the endpoints for chat playground workflow management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Request
|
||||
|
||||
# Import auth modules
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
|
||||
# Import interfaces
|
||||
import modules.interfaces.interfaceDbChatObjects as interfaceDbChatObjects
|
||||
|
||||
# Import models
|
||||
from modules.datamodels.datamodelChat import ChatWorkflow, UserInputRequest, WorkflowModeEnum
|
||||
from modules.datamodels.datamodelUam import User
|
||||
|
||||
# Import workflow control functions
|
||||
from modules.features.chatPlayground.mainChatPlayground import chatStart, chatStop
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router for chat playground endpoints
|
||||
router = APIRouter(
|
||||
prefix="/api/chat/playground",
|
||||
tags=["Chat Playground"],
|
||||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
def getServiceChat(currentUser: User):
|
||||
return interfaceDbChatObjects.getInterface(currentUser)
|
||||
|
||||
# Workflow start endpoint
|
||||
@router.post("/start", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def start_workflow(
|
||||
request: Request,
|
||||
workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
|
||||
workflowMode: WorkflowModeEnum = Query(..., description="Workflow mode: 'Actionplan', 'Dynamic', or 'Template' (mandatory)"),
|
||||
userInput: UserInputRequest = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""
|
||||
Starts a new workflow or continues an existing one.
|
||||
Corresponds to State 1 in the state machine documentation.
|
||||
|
||||
Args:
|
||||
workflowMode: "Actionplan" for traditional task planning, "Dynamic" for iterative dynamic-style processing, "Template" for template-based processing
|
||||
"""
|
||||
try:
|
||||
# Start or continue workflow using playground controller
|
||||
workflow = await chatStart(currentUser, userInput, workflowMode, workflowId)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in start_workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# State 8: Workflow Stopped endpoint
|
||||
@router.post("/{workflowId}/stop", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def stop_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow to stop"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""Stops a running workflow."""
|
||||
try:
|
||||
# Stop workflow using playground controller
|
||||
workflow = await chatStop(currentUser, workflowId)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stop_workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# Unified Chat Data Endpoint for Polling
|
||||
@router.get("/{workflowId}/chatData")
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow_chat_data(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
afterTimestamp: Optional[float] = Query(None, description="Unix timestamp to get data after"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get unified chat data (messages, logs, stats) for a workflow with timestamp-based selective data transfer.
|
||||
Returns all data types in chronological order based on _createdAt timestamp.
|
||||
"""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceDbChat = getServiceChat(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceDbChat.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Workflow with ID {workflowId} not found"
|
||||
)
|
||||
|
||||
# Get unified chat data using the new method
|
||||
chatData = interfaceDbChat.getUnifiedChatData(workflowId, afterTimestamp)
|
||||
|
||||
return chatData
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unified chat data: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting unified chat data: {str(e)}"
|
||||
)
|
||||
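A sketch of the intended start / poll / stop sequence from a client's perspective. The URLs, query parameters, and workflow modes come from the routes above; the UserInputRequest body and the workflow's "id" field are assumptions, since neither model is shown in this diff.

import httpx

client = httpx.Client(
    base_url="http://localhost:8000",             # placeholder
    headers={"Authorization": "Bearer <token>"},  # placeholder
)

# Start a workflow in Dynamic mode (the JSON body shape is assumed)
workflow = client.post(
    "/api/chat/playground/start",
    params={"workflowMode": "Dynamic"},
    json={"message": "Summarize the latest uploaded report"},
).json()
workflow_id = workflow["id"]  # assumed ChatWorkflow field name

# Poll for messages/logs/stats created after a given timestamp
chat_data = client.get(
    f"/api/chat/playground/{workflow_id}/chatData",
    params={"afterTimestamp": 0},
).json()

# Stop the workflow
client.post(f"/api/chat/playground/{workflow_id}/stop")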
236
modules/routes/routeDataAutomation.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
"""
|
||||
Automation routes for the backend API.
|
||||
Implements the endpoints for automation definition management.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
import logging
|
||||
import json
|
||||
|
||||
# Import interfaces and models
|
||||
from modules.interfaces.interfaceDbChatObjects import getInterface as getChatInterface
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
from modules.datamodels.datamodelChat import AutomationDefinition, ChatWorkflow
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Model attributes for AutomationDefinition
|
||||
automationAttributes = getModelAttributeDefinitions(AutomationDefinition)
|
||||
|
||||
# Create router for automation endpoints
|
||||
router = APIRouter(
|
||||
prefix="/api/automations",
|
||||
tags=["Manage Automations"],
|
||||
responses={
|
||||
404: {"description": "Not found"},
|
||||
400: {"description": "Bad request"},
|
||||
401: {"description": "Unauthorized"},
|
||||
403: {"description": "Forbidden"},
|
||||
500: {"description": "Internal server error"}
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[AutomationDefinition])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_automations(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser = Depends(getCurrentUser)
|
||||
) -> PaginatedResponse[AutomationDefinition]:
|
||||
"""
|
||||
Get automation definitions with optional pagination, sorting, and filtering.
|
||||
|
||||
Query Parameters:
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
"""
|
||||
try:
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
chatInterface = getChatInterface(currentUser)
|
||||
result = chatInterface.getAllAutomationDefinitions(pagination=paginationParams)
|
||||
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
# If no pagination, result is List[Dict]
|
||||
if paginationParams:
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=result,
|
||||
pagination=None
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automations: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting automations: {str(e)}"
|
||||
)
|
||||
|
||||
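Because pagination is passed as a single JSON-encoded query parameter rather than separate fields, a short sketch of how a client would build it. The field names mirror PaginationParams as used above; the sort field value is illustrative.

import json
import httpx

pagination = {
    "page": 1,
    "pageSize": 20,
    "sort": [{"field": "name", "direction": "asc"}],  # illustrative sort field
}
response = httpx.get(
    "http://localhost:8000/api/automations",          # placeholder base URL
    params={"pagination": json.dumps(pagination)},
    headers={"Authorization": "Bearer <token>"},      # placeholder
)
page = response.json()
print(page["pagination"]["totalItems"], len(page["items"]))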
@router.post("", response_model=AutomationDefinition)
|
||||
@limiter.limit("10/minute")
|
||||
async def create_automation(
|
||||
request: Request,
|
||||
automation: AutomationDefinition,
|
||||
currentUser = Depends(getCurrentUser)
|
||||
) -> AutomationDefinition:
|
||||
"""Create a new automation definition"""
|
||||
try:
|
||||
chatInterface = getChatInterface(currentUser)
|
||||
automationData = automation.model_dump()
|
||||
created = chatInterface.createAutomationDefinition(automationData)
|
||||
return AutomationDefinition(**created)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating automation: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error creating automation: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/{automationId}", response_model=AutomationDefinition)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_automation(
|
||||
request: Request,
|
||||
automationId: str = Path(..., description="Automation ID"),
|
||||
currentUser = Depends(getCurrentUser)
|
||||
) -> AutomationDefinition:
|
||||
"""Get a single automation definition by ID"""
|
||||
try:
|
||||
chatInterface = getChatInterface(currentUser)
|
||||
automation = chatInterface.getAutomationDefinition(automationId)
|
||||
if not automation:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Automation {automationId} not found"
|
||||
)
|
||||
|
||||
return AutomationDefinition(**automation)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automation: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting automation: {str(e)}"
|
||||
)
|
||||
|
||||
@router.put("/{automationId}", response_model=AutomationDefinition)
|
||||
@limiter.limit("10/minute")
|
||||
async def update_automation(
|
||||
request: Request,
|
||||
automationId: str = Path(..., description="Automation ID"),
|
||||
automation: AutomationDefinition = Body(...),
|
||||
currentUser = Depends(getCurrentUser)
|
||||
) -> AutomationDefinition:
|
||||
"""Update an automation definition"""
|
||||
try:
|
||||
chatInterface = getChatInterface(currentUser)
|
||||
automationData = automation.model_dump()
|
||||
updated = chatInterface.updateAutomationDefinition(automationId, automationData)
|
||||
return AutomationDefinition(**updated)
|
||||
except HTTPException:
|
||||
raise
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating automation: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error updating automation: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/{automationId}")
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_automation(
|
||||
request: Request,
|
||||
automationId: str = Path(..., description="Automation ID"),
|
||||
currentUser = Depends(getCurrentUser)
|
||||
) -> Response:
|
||||
"""Delete an automation definition"""
|
||||
try:
|
||||
chatInterface = getChatInterface(currentUser)
|
||||
success = chatInterface.deleteAutomationDefinition(automationId)
|
||||
if success:
|
||||
return Response(status_code=204)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to delete automation"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except PermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting automation: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error deleting automation: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/{automationId}/execute", response_model=ChatWorkflow)
|
||||
@limiter.limit("5/minute")
|
||||
async def execute_automation(
|
||||
request: Request,
|
||||
automationId: str = Path(..., description="Automation ID"),
|
||||
currentUser = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""Execute an automation immediately (test mode)"""
|
||||
try:
|
||||
chatInterface = getChatInterface(currentUser)
|
||||
workflow = await chatInterface.executeAutomation(automationId)
|
||||
return workflow
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=str(e)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing automation: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error executing automation: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/attributes", response_model=Dict[str, Any])
|
||||
async def get_automation_attributes(
|
||||
request: Request
|
||||
) -> Dict[str, Any]:
|
||||
"""Get attribute definitions for AutomationDefinition model"""
|
||||
return {"attributes": automationAttributes}
|
||||
|
||||
|
|
@ -11,67 +11,74 @@ SECURITY NOTE:
|
|||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import json
|
||||
|
||||
from modules.interfaces.interfaceAppModel import User, UserConnection, AuthAuthority, ConnectionStatus, Token
|
||||
from modules.datamodels.datamodelUam import User, UserConnection, AuthAuthority, ConnectionStatus
|
||||
from modules.datamodels.datamodelSecurity import Token
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
from modules.interfaces.interfaceAppObjects import getInterface, getRootInterface
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from modules.interfaces.interfaceDbAppObjects import getInterface
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_token_status_for_connection(interface, connection_id: str) -> tuple[str, Optional[float]]:
|
||||
def getTokenStatusForConnection(interface, connectionId: str) -> tuple[str, Optional[float]]:
|
||||
"""
|
||||
Get token status and expiration for a connection.
|
||||
|
||||
Args:
|
||||
interface: The database interface
|
||||
connection_id: The connection ID to check
|
||||
connectionId: The connection ID to check
|
||||
|
||||
Returns:
|
||||
tuple: (token_status, token_expires_at)
|
||||
- token_status: 'active', 'expired', or 'none'
|
||||
- token_expires_at: UTC timestamp or None
|
||||
tuple: (tokenStatus, tokenExpiresAt)
|
||||
- tokenStatus: 'active', 'expired', or 'none'
|
||||
- tokenExpiresAt: UTC timestamp or None
|
||||
"""
|
||||
try:
|
||||
# Query tokens table for the latest token for this connection
|
||||
tokens = interface.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={"connectionId": connection_id}
|
||||
recordFilter={"connectionId": connectionId}
|
||||
)
|
||||
|
||||
if not tokens:
|
||||
return "none", None
|
||||
|
||||
# Find the most recent token (highest createdAt timestamp)
|
||||
latest_token = None
|
||||
latest_created_at = 0
|
||||
latestToken = None
|
||||
latestCreatedAt = 0
|
||||
|
||||
for token_data in tokens:
|
||||
created_at = token_data.get("createdAt", 0)
|
||||
if created_at > latest_created_at:
|
||||
latest_created_at = created_at
|
||||
latest_token = token_data
|
||||
for tokenData in tokens:
|
||||
createdAt = tokenData.get("createdAt", 0)
|
||||
if createdAt > latestCreatedAt:
|
||||
latestCreatedAt = createdAt
|
||||
latestToken = tokenData
|
||||
|
||||
if not latest_token:
|
||||
if not latestToken:
|
||||
return "none", None
|
||||
|
||||
# Check if token is expired
|
||||
expires_at = latest_token.get("expiresAt")
|
||||
if not expires_at:
|
||||
expiresAt = latestToken.get("expiresAt")
|
||||
if not expiresAt:
|
||||
return "none", None
|
||||
|
||||
current_time = get_utc_timestamp()
|
||||
if expires_at <= current_time:
|
||||
return "expired", expires_at
|
||||
currentTime = getUtcTimestamp()
|
||||
|
||||
# Add 5 minute buffer for proactive refresh
|
||||
bufferTime = 5 * 60 # 5 minutes in seconds
|
||||
if expiresAt <= currentTime:
|
||||
return "expired", expiresAt
|
||||
elif expiresAt <= (currentTime + bufferTime):
|
||||
# Token expires soon - mark as active but log for proactive refresh
|
||||
logger.debug(f"Token for connection {connectionId} expires soon (in {expiresAt - currentTime} seconds)")
|
||||
return "active", expiresAt
|
||||
else:
|
||||
return "active", expires_at
|
||||
return "active", expiresAt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting token status for connection {connection_id}: {str(e)}")
|
||||
logger.error(f"Error getting token status for connection {connectionId}: {str(e)}")
|
||||
return "none", None
|
||||
|
||||
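A self-contained sketch of the expiry classification implemented above, including the 5-minute proactive-refresh buffer; timestamps are plain UTC epoch seconds and the helper name is illustrative.

import time

REFRESH_BUFFER_SECONDS = 5 * 60  # refresh tokens that expire within 5 minutes

def classify_token(expires_at, now=None):
    if not expires_at:
        return "none"
    now = time.time() if now is None else now
    if expires_at <= now:
        return "expired"
    if expires_at <= now + REFRESH_BUFFER_SECONDS:
        # still reported as active, but a proactive refresh should be scheduled
        return "active"
    return "active"

print(classify_token(None))                # "none"
print(classify_token(time.time() - 10))    # "expired"
print(classify_token(time.time() + 60))    # "active" (inside refresh buffer)
print(classify_token(time.time() + 3600))  # "active"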
router = APIRouter(
|
||||
|
|
@ -89,6 +96,7 @@ async def get_connections(
|
|||
"""Get all connections for the current user
|
||||
|
||||
SECURITY: This endpoint is secure - users can only see their own connections.
|
||||
Automatically refreshes expired OAuth tokens in the background.
|
||||
"""
|
||||
try:
|
||||
interface = getInterface(currentUser)
|
||||
|
|
@ -97,11 +105,23 @@ async def get_connections(
|
|||
# This prevents admin from seeing other users' connections and causing confusion
|
||||
connections = interface.getUserConnections(currentUser.id)
|
||||
|
||||
# Perform silent token refresh for expired OAuth connections
|
||||
try:
|
||||
from modules.security.tokenRefreshService import token_refresh_service
|
||||
refresh_result = await token_refresh_service.refresh_expired_tokens(currentUser.id)
|
||||
if refresh_result.get("refreshed", 0) > 0:
|
||||
logger.info(f"Silently refreshed {refresh_result['refreshed']} tokens for user {currentUser.id}")
|
||||
# Re-fetch connections to get updated token status
|
||||
connections = interface.getUserConnections(currentUser.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Silent token refresh failed for user {currentUser.id}: {str(e)}")
|
||||
# Continue with original connections even if refresh fails
|
||||
|
||||
# Enhance each connection with token status information
|
||||
enhanced_connections = []
|
||||
for connection in connections:
|
||||
# Get token status for this connection
|
||||
token_status, token_expires_at = get_token_status_for_connection(interface, connection.id)
|
||||
tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)
|
||||
|
||||
# Create enhanced connection with token status
|
||||
enhanced_connection = UserConnection(
|
||||
|
|
@ -115,8 +135,8 @@ async def get_connections(
|
|||
connectedAt=connection.connectedAt,
|
||||
lastChecked=connection.lastChecked,
|
||||
expiresAt=connection.expiresAt,
|
||||
tokenStatus=token_status,
|
||||
tokenExpiresAt=token_expires_at
|
||||
tokenStatus=tokenStatus,
|
||||
tokenExpiresAt=tokenExpiresAt
|
||||
)
|
||||
enhanced_connections.append(enhanced_connection)
|
||||
|
||||
|
|
@ -176,7 +196,7 @@ async def create_connection(
|
|||
)
|
||||
|
||||
# Save connection record - models now handle timestamp serialization automatically
|
||||
interface.db.recordModify(UserConnection, connection.id, connection.to_dict())
|
||||
interface.db.recordModify(UserConnection, connection.id, connection.model_dump())
|
||||
|
||||
|
||||
return connection
|
||||
|
|
@ -227,14 +247,13 @@ async def update_connection(
|
|||
setattr(connection, field, value)
|
||||
|
||||
# Update lastChecked timestamp using UTC timestamp
|
||||
connection.lastChecked = get_utc_timestamp()
|
||||
connection.lastChecked = getUtcTimestamp()
|
||||
|
||||
# Update connection - models now handle timestamp serialization automatically
|
||||
interface.db.recordModify(UserConnection, connectionId, connection.to_dict())
|
||||
|
||||
interface.db.recordModify(UserConnection, connectionId, connection.model_dump())
|
||||
|
||||
# Get token status for the updated connection
|
||||
token_status, token_expires_at = get_token_status_for_connection(interface, connectionId)
|
||||
tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connectionId)
|
||||
|
||||
# Create enhanced connection with token status
|
||||
enhanced_connection = UserConnection(
|
||||
|
|
@ -248,8 +267,8 @@ async def update_connection(
|
|||
connectedAt=connection.connectedAt,
|
||||
lastChecked=connection.lastChecked,
|
||||
expiresAt=connection.expiresAt,
|
||||
tokenStatus=token_status,
|
||||
tokenExpiresAt=token_expires_at
|
||||
tokenStatus=tokenStatus,
|
||||
tokenExpiresAt=tokenExpiresAt
|
||||
)
|
||||
|
||||
return enhanced_connection
|
||||
|
|
@ -362,10 +381,10 @@ async def disconnect_service(
|
|||
|
||||
# Update connection status
|
||||
connection.status = ConnectionStatus.INACTIVE
|
||||
connection.lastChecked = get_utc_timestamp()
|
||||
connection.lastChecked = getUtcTimestamp()
|
||||
|
||||
# Update connection record - models now handle timestamp serialization automatically
|
||||
interface.db.recordModify(UserConnection, connectionId, connection.to_dict())
|
||||
interface.db.recordModify(UserConnection, connectionId, connection.model_dump())
|
||||
|
||||
|
||||
return {"message": "Service disconnected successfully"}
|
||||
|
|
|
|||
|
|
@ -1,24 +1,18 @@
|
|||
from fastapi import APIRouter, HTTPException, Depends, File, UploadFile, Form, Path, Request, status, Query, Response, Body
|
||||
from fastapi.responses import JSONResponse, FileResponse
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from dataclasses import dataclass
|
||||
import io
|
||||
import inspect
|
||||
import importlib
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
# Import auth module
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
|
||||
# Import interfaces
|
||||
import modules.interfaces.interfaceComponentObjects as interfaceComponentObjects
|
||||
from modules.interfaces.interfaceComponentModel import FileItem, FilePreview
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse, AttributeDefinition
|
||||
from modules.interfaces.interfaceAppModel import User, DataNeutraliserConfig, DataNeutralizerAttributes
|
||||
from modules.services.serviceNeutralization import NeutralizationService
|
||||
import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObjects
|
||||
from modules.datamodels.datamodelFiles import FileItem, FilePreview
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -39,21 +33,61 @@ router = APIRouter(
|
|||
}
|
||||
)
|
||||
|
||||
@router.get("/list", response_model=List[FileItem])
|
||||
@router.get("/list", response_model=PaginatedResponse[FileItem])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_files(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[FileItem]:
|
||||
"""Get all files"""
|
||||
) -> PaginatedResponse[FileItem]:
|
||||
"""
|
||||
Get files with optional pagination, sorting, and filtering.
|
||||
|
||||
Query Parameters:
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
|
||||
Examples:
|
||||
- GET /api/files/list (no pagination - returns all items)
|
||||
- GET /api/files/list?pagination={"page":1,"pageSize":10,"sort":[]}
|
||||
- GET /api/files/list?pagination={"page":2,"pageSize":20,"sort":[{"field":"fileName","direction":"asc"}]}
|
||||
"""
|
||||
try:
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
# Get all files generically - only metadata, no binary data
|
||||
files = managementInterface.getAllFiles()
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
result = managementInterface.getAllFiles(pagination=paginationParams)
|
||||
|
||||
# Return files directly since they are already FileItem objects
|
||||
return files
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
# If no pagination, result is List[FileItem]
|
||||
if paginationParams:
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=result,
|
||||
pagination=None
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting files: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -73,17 +107,17 @@ async def upload_file(
|
|||
file.fileName = file.filename
|
||||
"""Upload a file"""
|
||||
try:
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Read file
|
||||
fileContent = await file.read()
|
||||
|
||||
# Check size limits
|
||||
maxSize = int(interfaceComponentObjects.APP_CONFIG.get("File_Management_MAX_UPLOAD_SIZE_MB")) * 1024 * 1024 # in bytes
|
||||
maxSize = int(interfaceDbComponentObjects.APP_CONFIG.get("File_Management_MAX_UPLOAD_SIZE_MB")) * 1024 * 1024 # in bytes
|
||||
if len(fileContent) > maxSize:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
detail=f"File too large. Maximum size: {interfaceComponentObjects.APP_CONFIG.get('File_Management_MAX_UPLOAD_SIZE_MB')}MB"
|
||||
detail=f"File too large. Maximum size: {interfaceDbComponentObjects.APP_CONFIG.get('File_Management_MAX_UPLOAD_SIZE_MB')}MB"
|
||||
)
|
||||
|
||||
# Save file via LucyDOM interface in the database
|
||||
|
|
@ -104,7 +138,7 @@ async def upload_file(
|
|||
fileItem.workflowId = workflowId
|
||||
|
||||
# Convert FileItem to dictionary for JSON response
|
||||
fileMeta = fileItem.to_dict()
|
||||
fileMeta = fileItem.model_dump()
|
||||
|
||||
# Response with duplicate information
|
||||
return JSONResponse({
|
||||
|
|
@ -116,7 +150,7 @@ async def upload_file(
|
|||
"isDuplicate": duplicateType != "new_file"
|
||||
})
|
||||
|
||||
except interfaceComponentObjects.FileStorageError as e:
|
||||
except interfaceDbComponentObjects.FileStorageError as e:
|
||||
logger.error(f"Error during file upload (storage): {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -138,7 +172,7 @@ async def get_file(
|
|||
) -> FileItem:
|
||||
"""Get a file"""
|
||||
try:
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Get file via LucyDOM interface from the database
|
||||
fileData = managementInterface.getFile(fileId)
|
||||
|
|
@ -150,19 +184,19 @@ async def get_file(
|
|||
|
||||
return fileData
|
||||
|
||||
except interfaceComponentObjects.FileNotFoundError as e:
|
||||
except interfaceDbComponentObjects.FileNotFoundError as e:
|
||||
logger.warning(f"File not found: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e)
|
||||
)
|
||||
except interfaceComponentObjects.FilePermissionError as e:
|
||||
except interfaceDbComponentObjects.FilePermissionError as e:
|
||||
logger.warning(f"No permission for file: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e)
|
||||
)
|
||||
except interfaceComponentObjects.FileError as e:
|
||||
except interfaceDbComponentObjects.FileError as e:
|
||||
logger.error(f"Error retrieving file: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
|
|
@ -185,7 +219,7 @@ async def update_file(
|
|||
) -> FileItem:
|
||||
"""Update file info"""
|
||||
try:
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Get the file from the database
|
||||
file = managementInterface.getFile(fileId)
|
||||
|
|
@ -231,7 +265,7 @@ async def delete_file(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a file"""
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Check if the file exists
|
||||
existingFile = managementInterface.getFile(fileId)
|
||||
|
|
@ -258,7 +292,7 @@ async def get_file_stats(
|
|||
) -> Dict[str, Any]:
|
||||
"""Returns statistics about the stored files"""
|
||||
try:
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Get all files - metadata only
|
||||
allFiles = managementInterface.getAllFiles()
|
||||
|
|
@ -297,7 +331,7 @@ async def download_file(
|
|||
) -> Response:
|
||||
"""Download a file"""
|
||||
try:
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Get file data
|
||||
fileData = managementInterface.getFile(fileId)
|
||||
|
|
@ -345,7 +379,7 @@ async def preview_file(
|
|||
) -> FilePreview:
|
||||
"""Preview a file's content"""
|
||||
try:
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Get file preview using the correct method
|
||||
preview = managementInterface.getFileContent(fileId)
|
||||
|
|
@ -365,253 +399,4 @@ async def preview_file(
|
|||
detail=f"Error previewing file: {str(e)}"
|
||||
)
|
||||
|
||||
# Data Neutralization endpoints
|
||||
|
||||
@router.get("/neutralization/config", response_model=DataNeutraliserConfig)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_config(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> DataNeutraliserConfig:
|
||||
"""Get data neutralization configuration"""
|
||||
try:
|
||||
service = NeutralizationService(currentUser)
|
||||
config = service.get_config()
|
||||
|
||||
if not config:
|
||||
# Return default config instead of 404
|
||||
return DataNeutraliserConfig(
|
||||
mandateId=currentUser.mandateId,
|
||||
userId=currentUser.id,
|
||||
enabled=True,
|
||||
namesToParse="",
|
||||
sharepointSourcePath="",
|
||||
sharepointTargetPath=""
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting neutralization config: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting neutralization config: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/neutralization/config", response_model=DataNeutraliserConfig)
|
||||
@limiter.limit("10/minute")
|
||||
async def save_neutralization_config(
|
||||
request: Request,
|
||||
config_data: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> DataNeutraliserConfig:
|
||||
"""Save or update data neutralization configuration"""
|
||||
try:
|
||||
service = NeutralizationService(currentUser)
|
||||
config = service.save_config(config_data)
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving neutralization config: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error saving neutralization config: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/neutralization/neutralize-text", response_model=Dict[str, Any])
|
||||
@limiter.limit("20/minute")
|
||||
async def neutralize_text(
|
||||
request: Request,
|
||||
text_data: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Neutralize text content"""
|
||||
try:
|
||||
text = text_data.get("text", "")
|
||||
file_id = text_data.get("fileId")
|
||||
|
||||
if not text:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Text content is required"
|
||||
)
|
||||
|
||||
service = NeutralizationService(currentUser)
|
||||
result = service.neutralize_text(text, file_id)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error neutralizing text: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error neutralizing text: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/neutralization/resolve-text", response_model=Dict[str, str])
|
||||
@limiter.limit("20/minute")
|
||||
async def resolve_text(
|
||||
request: Request,
|
||||
text_data: Dict[str, str] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, str]:
|
||||
"""Resolve UIDs in neutralized text back to original text"""
|
||||
try:
|
||||
text = text_data.get("text", "")
|
||||
|
||||
if not text:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Text content is required"
|
||||
)
|
||||
|
||||
service = NeutralizationService(currentUser)
|
||||
resolved_text = service.resolve_text(text)
|
||||
|
||||
return {"resolved_text": resolved_text}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error resolving text: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error resolving text: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/neutralization/attributes", response_model=List[DataNeutralizerAttributes])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_attributes(
|
||||
request: Request,
|
||||
fileId: Optional[str] = Query(None, description="Filter by file ID"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[DataNeutralizerAttributes]:
|
||||
"""Get neutralization attributes, optionally filtered by file ID"""
|
||||
try:
|
||||
service = NeutralizationService(currentUser)
|
||||
attributes = service.get_attributes(fileId)
|
||||
|
||||
return attributes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting neutralization attributes: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting neutralization attributes: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/neutralization/process-sharepoint", response_model=Dict[str, Any])
|
||||
@limiter.limit("5/minute")
|
||||
async def process_sharepoint_files(
|
||||
request: Request,
|
||||
paths_data: Dict[str, str] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Process files from SharePoint source path and store neutralized files in target path"""
|
||||
try:
|
||||
source_path = paths_data.get("sourcePath", "")
|
||||
target_path = paths_data.get("targetPath", "")
|
||||
|
||||
if not source_path or not target_path:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Both source and target paths are required"
|
||||
)
|
||||
|
||||
service = NeutralizationService(currentUser)
|
||||
result = await service.process_sharepoint_files(source_path, target_path)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SharePoint files: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error processing SharePoint files: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/neutralization/batch-process", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def batch_process_files(
|
||||
request: Request,
|
||||
files_data: List[Dict[str, Any]] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Process multiple files for neutralization"""
|
||||
try:
|
||||
if not files_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Files data is required"
|
||||
)
|
||||
|
||||
service = NeutralizationService(currentUser)
|
||||
result = service.batch_neutralize_files(files_data)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error batch processing files: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error batch processing files: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/neutralization/stats", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_stats(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get neutralization processing statistics"""
|
||||
try:
|
||||
service = NeutralizationService(currentUser)
|
||||
stats = service.get_processing_stats()
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting neutralization stats: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting neutralization stats: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/neutralization/attributes/{fileId}", response_model=Dict[str, str])
|
||||
@limiter.limit("10/minute")
|
||||
async def cleanup_file_attributes(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="File ID to cleanup attributes for"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, str]:
|
||||
"""Clean up neutralization attributes for a specific file"""
|
||||
try:
|
||||
service = NeutralizationService(currentUser)
|
||||
success = service.cleanup_file_attributes(fileId)
|
||||
|
||||
if success:
|
||||
return {"message": f"Successfully cleaned up attributes for file {fileId}"}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cleanup file attributes"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up file attributes: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error cleaning up file attributes: {str(e)}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,25 +3,22 @@ Mandate routes for the backend API.
|
|||
Implements the endpoints for mandate management.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import inspect
|
||||
import importlib
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
# Import auth module
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
|
||||
# Import interfaces
|
||||
import modules.interfaces.interfaceAppObjects as interfaceAppObjects
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse, AttributeDefinition
|
||||
import modules.interfaces.interfaceDbAppObjects as interfaceDbAppObjects
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions
|
||||
|
||||
# Import the model classes
|
||||
from modules.interfaces.interfaceAppModel import Mandate, User
|
||||
from modules.datamodels.datamodelUam import Mandate, User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -36,17 +33,60 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[Mandate])
|
||||
@router.get("/", response_model=PaginatedResponse[Mandate])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_mandates(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[Mandate]:
|
||||
"""Get all mandates"""
|
||||
) -> PaginatedResponse[Mandate]:
|
||||
"""
|
||||
Get mandates with optional pagination, sorting, and filtering.
|
||||
|
||||
Query Parameters:
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
|
||||
Examples:
|
||||
- GET /api/mandates/ (no pagination - returns all items)
|
||||
- GET /api/mandates/?pagination={"page":1,"pageSize":10,"sort":[]}
|
||||
"""
|
||||
try:
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
mandates = appInterface.getAllMandates()
|
||||
return mandates
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
result = appInterface.getAllMandates(pagination=paginationParams)
|
||||
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
# If no pagination, result is List[Mandate]
|
||||
if paginationParams:
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=result,
|
||||
pagination=None
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting mandates: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -63,7 +103,7 @@ async def get_mandate(
|
|||
) -> Mandate:
|
||||
"""Get a specific mandate by ID"""
|
||||
try:
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
mandate = appInterface.getMandate(mandateId)
|
||||
|
||||
if not mandate:
|
||||
|
|
@ -91,7 +131,7 @@ async def create_mandate(
|
|||
) -> Mandate:
|
||||
"""Create a new mandate"""
|
||||
try:
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
|
||||
# Create mandate
|
||||
newMandate = appInterface.createMandate(
|
||||
|
|
@ -125,7 +165,7 @@ async def update_mandate(
|
|||
) -> Mandate:
|
||||
"""Update an existing mandate"""
|
||||
try:
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
|
||||
# Check if mandate exists
|
||||
existingMandate = appInterface.getMandate(mandateId)
|
||||
|
|
@ -136,7 +176,7 @@ async def update_mandate(
|
|||
)
|
||||
|
||||
# Update mandate
|
||||
updatedMandate = appInterface.updateMandate(mandateId, mandateData.to_dict())
|
||||
updatedMandate = appInterface.updateMandate(mandateId, mandateData.model_dump())
|
||||
|
||||
if not updatedMandate:
|
||||
raise HTTPException(
|
||||
|
|
@ -163,7 +203,7 @@ async def delete_mandate(
|
|||
) -> Dict[str, Any]:
|
||||
"""Delete a mandate"""
|
||||
try:
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
|
||||
# Check if mandate exists
|
||||
existingMandate = appInterface.getMandate(mandateId)
|
||||
|
|
|
|||
275
modules/routes/routeDataNeutralization.py
Normal file
275
modules/routes/routeDataNeutralization.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
from fastapi import APIRouter, HTTPException, Depends, Path, Request, status, Query, Body
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
# Import auth module
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
|
||||
# Import interfaces
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelNeutralizer import DataNeutraliserConfig, DataNeutralizerAttributes
|
||||
from modules.features.neutralizePlayground.mainNeutralizePlayground import NeutralizationPlayground
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router for neutralization endpoints
|
||||
router = APIRouter(
|
||||
prefix="/api/neutralization",
|
||||
tags=["Data Neutralisation"],
|
||||
responses={
|
||||
404: {"description": "Not found"},
|
||||
400: {"description": "Bad request"},
|
||||
401: {"description": "Unauthorized"},
|
||||
403: {"description": "Forbidden"},
|
||||
500: {"description": "Internal server error"}
|
||||
}
|
||||
)
|
||||
|
||||
@router.get("/config", response_model=DataNeutraliserConfig)
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_config(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> DataNeutraliserConfig:
|
||||
"""Get data neutralization configuration"""
|
||||
try:
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
config = service.getConfig()
|
||||
|
||||
if not config:
|
||||
# Return default config instead of 404
|
||||
return DataNeutraliserConfig(
|
||||
mandateId=currentUser.mandateId,
|
||||
userId=currentUser.id,
|
||||
enabled=True,
|
||||
namesToParse="",
|
||||
sharepointSourcePath="",
|
||||
sharepointTargetPath=""
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting neutralization config: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting neutralization config: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/config", response_model=DataNeutraliserConfig)
|
||||
@limiter.limit("10/minute")
|
||||
async def save_neutralization_config(
|
||||
request: Request,
|
||||
config_data: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> DataNeutraliserConfig:
|
||||
"""Save or update data neutralization configuration"""
|
||||
try:
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
config = service.saveConfig(config_data)
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving neutralization config: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error saving neutralization config: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/neutralize-text", response_model=Dict[str, Any])
|
||||
@limiter.limit("20/minute")
|
||||
async def neutralize_text(
|
||||
request: Request,
|
||||
text_data: Dict[str, Any] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Neutralize text content"""
|
||||
try:
|
||||
text = text_data.get("text", "")
|
||||
file_id = text_data.get("fileId")
|
||||
|
||||
if not text:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Text content is required"
|
||||
)
|
||||
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
result = service.neutralizeText(text, file_id)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error neutralizing text: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error neutralizing text: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/resolve-text", response_model=Dict[str, str])
|
||||
@limiter.limit("20/minute")
|
||||
async def resolve_text(
|
||||
request: Request,
|
||||
text_data: Dict[str, str] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, str]:
|
||||
"""Resolve UIDs in neutralized text back to original text"""
|
||||
try:
|
||||
text = text_data.get("text", "")
|
||||
|
||||
if not text:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Text content is required"
|
||||
)
|
||||
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
resolved_text = service.resolveText(text)
|
||||
|
||||
return {"resolved_text": resolved_text}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error resolving text: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error resolving text: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/attributes", response_model=List[DataNeutralizerAttributes])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_attributes(
|
||||
request: Request,
|
||||
fileId: Optional[str] = Query(None, description="Filter by file ID"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[DataNeutralizerAttributes]:
|
||||
"""Get neutralization attributes, optionally filtered by file ID"""
|
||||
try:
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
attributes = service.getAttributes(fileId)
|
||||
|
||||
return attributes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting neutralization attributes: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting neutralization attributes: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/process-sharepoint", response_model=Dict[str, Any])
|
||||
@limiter.limit("5/minute")
|
||||
async def process_sharepoint_files(
|
||||
request: Request,
|
||||
paths_data: Dict[str, str] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Process files from SharePoint source path and store neutralized files in target path"""
|
||||
try:
|
||||
source_path = paths_data.get("sourcePath", "")
|
||||
target_path = paths_data.get("targetPath", "")
|
||||
|
||||
if not source_path or not target_path:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Both source and target paths are required"
|
||||
)
|
||||
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
result = await service.processSharepointFiles(source_path, target_path)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing SharePoint files: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error processing SharePoint files: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/batch-process", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def batch_process_files(
|
||||
request: Request,
|
||||
files_data: List[Dict[str, Any]] = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Process multiple files for neutralization"""
|
||||
try:
|
||||
if not files_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Files data is required"
|
||||
)
|
||||
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
result = service.batchNeutralizeFiles(files_data)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error batch processing files: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error batch processing files: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/stats", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_neutralization_stats(
|
||||
request: Request,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get neutralization processing statistics"""
|
||||
try:
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
stats = service.getProcessingStats()
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting neutralization stats: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting neutralization stats: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/attributes/{fileId}", response_model=Dict[str, str])
|
||||
@limiter.limit("10/minute")
|
||||
async def cleanup_file_attributes(
|
||||
request: Request,
|
||||
fileId: str = Path(..., description="File ID to cleanup attributes for"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, str]:
|
||||
"""Clean up neutralization attributes for a specific file"""
|
||||
try:
|
||||
service = NeutralizationPlayground(currentUser)
|
||||
success = service.cleanupFileAttributes(fileId)
|
||||
|
||||
if success:
|
||||
return {"message": f"Successfully cleaned up attributes for file {fileId}"}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cleanup file attributes"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up file attributes: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error cleaning up file attributes: {str(e)}"
|
||||
)
|
||||
|
|
@ -1,21 +1,17 @@
|
|||
from fastapi import APIRouter, HTTPException, Depends, Body, Query, Path, Request, Response
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import inspect
|
||||
import importlib
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
# Import auth module
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
|
||||
# Import interfaces
|
||||
import modules.interfaces.interfaceComponentObjects as interfaceComponentObjects
|
||||
from modules.interfaces.interfaceComponentModel import Prompt
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse, AttributeDefinition
|
||||
from modules.interfaces.interfaceAppModel import User
|
||||
import modules.interfaces.interfaceDbComponentObjects as interfaceDbComponentObjects
|
||||
from modules.datamodels.datamodelUtils import Prompt
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -27,16 +23,58 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
@router.get("", response_model=List[Prompt])
|
||||
@router.get("", response_model=PaginatedResponse[Prompt])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_prompts(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[Prompt]:
|
||||
"""Get all prompts"""
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
prompts = managementInterface.getAllPrompts()
|
||||
return prompts
|
||||
) -> PaginatedResponse[Prompt]:
|
||||
"""
|
||||
Get prompts with optional pagination, sorting, and filtering.
|
||||
|
||||
Query Parameters:
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
|
||||
Examples:
|
||||
- GET /api/prompts (no pagination - returns all items)
|
||||
- GET /api/prompts?pagination={"page":1,"pageSize":10,"sort":[]}
|
||||
- GET /api/prompts?pagination={"page":2,"pageSize":20,"sort":[{"field":"name","direction":"asc"}]}
|
||||
"""
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
result = managementInterface.getAllPrompts(pagination=paginationParams)
|
||||
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
# If no pagination, result is List[Prompt]
|
||||
if paginationParams:
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=result,
|
||||
pagination=None
|
||||
)
|
||||
|
||||
@router.post("", response_model=Prompt)
|
||||
@limiter.limit("10/minute")
|
||||
|
|
@ -46,13 +84,10 @@ async def create_prompt(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Prompt:
|
||||
"""Create a new prompt"""
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Convert Prompt to dict for interface
|
||||
prompt_data = prompt.dict()
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Create prompt
|
||||
newPrompt = managementInterface.createPrompt(prompt_data)
|
||||
newPrompt = managementInterface.createPrompt(prompt)
|
||||
|
||||
return Prompt(**newPrompt)
|
||||
|
||||
|
|
@ -64,7 +99,7 @@ async def get_prompt(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Prompt:
|
||||
"""Get a specific prompt"""
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Get prompt
|
||||
prompt = managementInterface.getPrompt(promptId)
|
||||
|
|
@ -85,7 +120,7 @@ async def update_prompt(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt"""
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Check if the prompt exists
|
||||
existingPrompt = managementInterface.getPrompt(promptId)
|
||||
|
|
@ -95,8 +130,11 @@ async def update_prompt(
|
|||
detail=f"Prompt with ID {promptId} not found"
|
||||
)
|
||||
|
||||
# Convert Prompt to dict for interface
|
||||
update_data = promptData.dict()
|
||||
# Convert Prompt to dict for interface, excluding the id field
|
||||
if hasattr(promptData, "model_dump"):
|
||||
update_data = promptData.model_dump(exclude={"id"})
|
||||
else:
|
||||
update_data = promptData.model_dump(exclude={"id"})
|
||||
|
||||
# Update prompt
|
||||
updatedPrompt = managementInterface.updatePrompt(promptId, update_data)
|
||||
|
|
@ -117,7 +155,7 @@ async def delete_prompt(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a prompt"""
|
||||
managementInterface = interfaceComponentObjects.getInterface(currentUser)
|
||||
managementInterface = interfaceDbComponentObjects.getInterface(currentUser)
|
||||
|
||||
# Check if the prompt exists
|
||||
existingPrompt = managementInterface.getPrompt(promptId)
|
||||
|
|
|
|||
|
|
@ -3,23 +3,19 @@ User routes for the backend API.
|
|||
Implements the endpoints for user management.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import inspect
|
||||
import importlib
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
# Import interfaces and models
|
||||
import modules.interfaces.interfaceAppObjects as interfaceAppObjects
|
||||
import modules.interfaces.interfaceDbAppObjects as interfaceDbAppObjects
|
||||
from modules.security.auth import getCurrentUser, limiter, getCurrentUser
|
||||
|
||||
# Import the attribute definition and helper functions
|
||||
from modules.interfaces.interfaceAppModel import User, AttributeDefinition
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse
|
||||
from modules.datamodels.datamodelUam import User, UserPrivilege
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -30,21 +26,65 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}}
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[User])
|
||||
@router.get("/", response_model=PaginatedResponse[User])
|
||||
@limiter.limit("30/minute")
|
||||
async def get_users(
|
||||
request: Request,
|
||||
mandateId: Optional[str] = None,
|
||||
mandateId: Optional[str] = Query(None, description="Mandate ID to filter users"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[User]:
|
||||
"""Get all users in the current mandate"""
|
||||
) -> PaginatedResponse[User]:
|
||||
"""
|
||||
Get users with optional pagination, sorting, and filtering.
|
||||
|
||||
Query Parameters:
|
||||
- mandateId: Optional mandate ID to filter users
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
|
||||
Examples:
|
||||
- GET /api/users/ (no pagination - returns all users)
|
||||
- GET /api/users/?pagination={"page":1,"pageSize":10,"sort":[]}
|
||||
"""
|
||||
try:
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
# If mandateId is provided, use it, otherwise use the current user's mandate
|
||||
targetMandateId = mandateId or currentUser.mandateId
|
||||
# Get all users without filtering by enabled status
|
||||
users = appInterface.getUsersByMandate(targetMandateId)
|
||||
return users
|
||||
# Get users with optional pagination
|
||||
result = appInterface.getUsersByMandate(targetMandateId, pagination=paginationParams)
|
||||
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
# If no pagination, result is List[User]
|
||||
if paginationParams:
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=result,
|
||||
pagination=None
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting users: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -61,7 +101,7 @@ async def get_user(
|
|||
) -> User:
|
||||
"""Get a specific user by ID"""
|
||||
try:
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
# Get user without filtering by enabled status
|
||||
user = appInterface.getUser(userId)
|
||||
|
||||
|
|
@ -89,13 +129,10 @@ async def create_user(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
"""Create a new user"""
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
|
||||
# Convert User to dict for interface
|
||||
user_dict = user_data.dict()
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
|
||||
# Create user
|
||||
newUser = appInterface.createUser(user_dict)
|
||||
newUser = appInterface.createUser(user_data)
|
||||
|
||||
return newUser
|
||||
|
||||
|
|
@ -108,7 +145,7 @@ async def update_user(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> User:
|
||||
"""Update an existing user"""
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
|
||||
# Check if the user exists
|
||||
existingUser = appInterface.getUser(userId)
|
||||
|
|
@ -118,11 +155,8 @@ async def update_user(
|
|||
detail=f"User with ID {userId} not found"
|
||||
)
|
||||
|
||||
# Convert User to dict for interface
|
||||
update_data = userData.dict()
|
||||
|
||||
# Update user
|
||||
updatedUser = appInterface.updateUser(userId, update_data)
|
||||
updatedUser = appInterface.updateUser(userId, userData)
|
||||
|
||||
if not updatedUser:
|
||||
raise HTTPException(
|
||||
|
|
@ -132,6 +166,165 @@ async def update_user(
|
|||
|
||||
return updatedUser
|
||||
|
||||
@router.post("/{userId}/reset-password")
|
||||
@limiter.limit("5/minute")
|
||||
async def reset_user_password(
|
||||
request: Request,
|
||||
userId: str = Path(..., description="ID of the user to reset password for"),
|
||||
newPassword: str = Body(..., embed=True),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Reset user password (Admin only)"""
|
||||
try:
|
||||
# Check if current user is admin
|
||||
if currentUser.privilege != UserPrivilege.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only administrators can reset passwords"
|
||||
)
|
||||
|
||||
# Get user interface
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
|
||||
# Get target user
|
||||
target_user = appInterface.getUserById(userId)
|
||||
if not target_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found"
|
||||
)
|
||||
|
||||
# Validate password strength
|
||||
if len(newPassword) < 8:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Password must be at least 8 characters long"
|
||||
)
|
||||
|
||||
# Reset password
|
||||
success = appInterface.resetUserPassword(userId, newPassword)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to reset password"
|
||||
)
|
||||
|
||||
# SECURITY: Automatically revoke all tokens for the user after password reset
|
||||
try:
|
||||
from modules.datamodels.datamodelUam import AuthAuthority
|
||||
revoked_count = appInterface.revokeTokensByUser(
|
||||
userId=userId,
|
||||
authority=None, # Revoke all authorities
|
||||
mandateId=None, # Revoke across all mandates
|
||||
revokedBy=currentUser.id,
|
||||
reason="password_reset"
|
||||
)
|
||||
logger.info(f"Revoked {revoked_count} tokens for user {userId} after password reset")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke tokens after password reset for user {userId}: {str(e)}")
|
||||
# Don't fail the password reset if token revocation fails
|
||||
|
||||
# Log password reset
|
||||
try:
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
audit_logger.logSecurityEvent(
|
||||
userId=str(currentUser.id),
|
||||
mandateId=str(currentUser.mandateId),
|
||||
action="password_reset",
|
||||
details=f"Reset password for user {userId}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"message": "Password reset successfully",
|
||||
"user_id": userId
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error resetting password: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Password reset failed: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/change-password")
|
||||
@limiter.limit("5/minute")
|
||||
async def change_password(
|
||||
request: Request,
|
||||
currentPassword: str = Body(..., embed=True),
|
||||
newPassword: str = Body(..., embed=True),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Change current user's password"""
|
||||
try:
|
||||
# Get user interface
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
|
||||
# Verify current password
|
||||
if not appInterface.verifyPassword(currentPassword, currentUser.passwordHash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Current password is incorrect"
|
||||
)
|
||||
|
||||
# Validate new password strength
|
||||
if len(newPassword) < 8:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="New password must be at least 8 characters long"
|
||||
)
|
||||
|
||||
# Change password
|
||||
success = appInterface.resetUserPassword(str(currentUser.id), newPassword)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to change password"
|
||||
)
|
||||
|
||||
# SECURITY: Automatically revoke all tokens for the user after password change
|
||||
try:
|
||||
from modules.datamodels.datamodelUam import AuthAuthority
|
||||
revoked_count = appInterface.revokeTokensByUser(
|
||||
userId=str(currentUser.id),
|
||||
authority=None, # Revoke all authorities
|
||||
mandateId=None, # Revoke across all mandates
|
||||
revokedBy=currentUser.id,
|
||||
reason="password_change"
|
||||
)
|
||||
logger.info(f"Revoked {revoked_count} tokens for user {currentUser.id} after password change")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to revoke tokens after password change for user {currentUser.id}: {str(e)}")
|
||||
# Don't fail the password change if token revocation fails
|
||||
|
||||
# Log password change
|
||||
try:
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
audit_logger.logSecurityEvent(
|
||||
userId=str(currentUser.id),
|
||||
mandateId=str(currentUser.mandateId),
|
||||
action="password_change",
|
||||
details="User changed their own password"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"message": "Password changed successfully. Please log in again with your new password."
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error changing password: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Password change failed: {str(e)}"
|
||||
)
|
||||
|
||||
@router.delete("/{userId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
async def delete_user(
|
||||
|
|
@ -140,7 +333,7 @@ async def delete_user(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a user"""
|
||||
appInterface = interfaceAppObjects.getInterface(currentUser)
|
||||
appInterface = interfaceDbAppObjects.getInterface(currentUser)
|
||||
|
||||
# Check if the user exists
|
||||
existingUser = appInterface.getUser(userId)
|
||||
|
|
|
|||
|
|
@ -5,15 +5,16 @@ import os
|
|||
import logging
|
||||
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
from modules.interfaces.interfaceAppObjects import getInterface, getRootInterface
|
||||
from modules.interfaces.interfaceAppModel import User, UserInDB, AuthAuthority, Token
|
||||
from modules.interfaces.interfaceDbAppObjects import getInterface, getRootInterface
|
||||
from modules.datamodels.datamodelUam import User, UserInDB, AuthAuthority
|
||||
from modules.datamodels.datamodelSecurity import Token
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/admin",
|
||||
tags=["Admin"],
|
||||
tags=["Security Administration"],
|
||||
responses={
|
||||
404: {"description": "Not found"},
|
||||
400: {"description": "Bad request"},
|
||||
|
|
@ -248,9 +249,145 @@ async def list_databases(
|
|||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
_ensure_admin_scope(currentUser)
|
||||
# For safety, expose only configured database name
|
||||
db_name = APP_CONFIG.get("DB_DATABASE") or APP_CONFIG.get("DB_NAME") or "poweron"
|
||||
return {"databases": [db_name]}
|
||||
|
||||
# Get database names from configuration for each interface
|
||||
databases = []
|
||||
|
||||
# App database (interfaceDbAppObjects.py)
|
||||
app_db = APP_CONFIG.get("DB_APP_DATABASE")
|
||||
if app_db:
|
||||
databases.append(app_db)
|
||||
|
||||
# Chat database (interfaceDbChatObjects.py)
|
||||
chat_db = APP_CONFIG.get("DB_CHAT_DATABASE")
|
||||
if chat_db:
|
||||
databases.append(chat_db)
|
||||
|
||||
# Management database (interfaceDbComponentObjects.py)
|
||||
management_db = APP_CONFIG.get("DB_MANAGEMENT_DATABASE")
|
||||
if management_db:
|
||||
databases.append(management_db)
|
||||
|
||||
# Fallback to default if no databases configured
|
||||
if not databases:
|
||||
databases = ["poweron"]
|
||||
|
||||
return {"databases": databases}
|
||||
|
||||
|
||||
@router.get("/databases/{database_name}/tables")
|
||||
@limiter.limit("30/minute")
|
||||
async def get_database_tables(
|
||||
request: Request,
|
||||
database_name: str,
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
_ensure_admin_scope(currentUser)
|
||||
|
||||
# Get all configured database names
|
||||
configured_dbs = []
|
||||
app_db = APP_CONFIG.get("DB_APP_DATABASE")
|
||||
if app_db:
|
||||
configured_dbs.append(app_db)
|
||||
chat_db = APP_CONFIG.get("DB_CHAT_DATABASE")
|
||||
if chat_db:
|
||||
configured_dbs.append(chat_db)
|
||||
management_db = APP_CONFIG.get("DB_MANAGEMENT_DATABASE")
|
||||
if management_db:
|
||||
configured_dbs.append(management_db)
|
||||
|
||||
if not configured_dbs:
|
||||
configured_dbs = ["poweron"]
|
||||
|
||||
if database_name not in configured_dbs:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid database name. Available databases: {configured_dbs}")
|
||||
|
||||
try:
|
||||
# Use the appropriate interface based on database name
|
||||
if database_name == app_db:
|
||||
appInterface = getRootInterface()
|
||||
tables = appInterface.db.getTables()
|
||||
elif database_name == chat_db:
|
||||
from modules.interfaces.interfaceDbChatObjects import getInterface as getChatInterface
|
||||
chatInterface = getChatInterface(currentUser)
|
||||
tables = chatInterface.db.getTables()
|
||||
elif database_name == management_db:
|
||||
from modules.interfaces.interfaceDbComponentObjects import getInterface as getComponentInterface
|
||||
componentInterface = getComponentInterface(currentUser)
|
||||
tables = componentInterface.db.getTables()
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Database not found")
|
||||
|
||||
return {"tables": tables}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting database tables: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get database tables")
|
||||
|
||||
|
||||
@router.post("/databases/{database_name}/tables/{table_name}/drop")
|
||||
@limiter.limit("10/minute")
|
||||
async def drop_table(
|
||||
request: Request,
|
||||
database_name: str,
|
||||
table_name: str,
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
payload: Dict[str, Any] = Body(...)
|
||||
) -> Dict[str, Any]:
|
||||
_ensure_admin_scope(currentUser)
|
||||
|
||||
# Get all configured database names
|
||||
configured_dbs = []
|
||||
app_db = APP_CONFIG.get("DB_APP_DATABASE")
|
||||
if app_db:
|
||||
configured_dbs.append(app_db)
|
||||
chat_db = APP_CONFIG.get("DB_CHAT_DATABASE")
|
||||
if chat_db:
|
||||
configured_dbs.append(chat_db)
|
||||
management_db = APP_CONFIG.get("DB_MANAGEMENT_DATABASE")
|
||||
if management_db:
|
||||
configured_dbs.append(management_db)
|
||||
|
||||
if not configured_dbs:
|
||||
configured_dbs = ["poweron"]
|
||||
|
||||
if database_name not in configured_dbs:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid database name. Available databases: {configured_dbs}")
|
||||
|
||||
try:
|
||||
# Use the appropriate interface based on database name
|
||||
if database_name == app_db:
|
||||
interface = getRootInterface()
|
||||
elif database_name == chat_db:
|
||||
from modules.interfaces.interfaceDbChatObjects import getInterface as getChatInterface
|
||||
interface = getChatInterface(currentUser)
|
||||
elif database_name == management_db:
|
||||
from modules.interfaces.interfaceDbComponentObjects import getInterface as getComponentInterface
|
||||
interface = getComponentInterface(currentUser)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Database not found")
|
||||
|
||||
conn = interface.db.connection
|
||||
with conn.cursor() as cursor:
|
||||
# Check if table exists
|
||||
cursor.execute("""
|
||||
SELECT table_name FROM information_schema.tables
|
||||
WHERE table_schema = 'public' AND table_name = %s
|
||||
""", (table_name,))
|
||||
if not cursor.fetchone():
|
||||
raise HTTPException(status_code=404, detail="Table not found")
|
||||
|
||||
# Drop the table
|
||||
cursor.execute(f'DROP TABLE IF EXISTS "{table_name}" CASCADE')
|
||||
conn.commit()
|
||||
logger.warning(f"Admin drop_table executed by {currentUser.id}: dropped table '{table_name}' from database '{database_name}'")
|
||||
return {"message": f"Table '{table_name}' dropped successfully from database '{database_name}'"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping table: {str(e)}")
|
||||
if 'interface' in locals() and interface and interface.db and interface.db.connection:
|
||||
interface.db.connection.rollback()
|
||||
raise HTTPException(status_code=500, detail="Failed to drop table")
|
||||
|
||||
|
||||
@router.post("/databases/drop")
|
||||
|
|
@ -262,13 +399,39 @@ async def drop_database(
|
|||
) -> Dict[str, Any]:
|
||||
_ensure_admin_scope(currentUser)
|
||||
db_name = payload.get("database")
|
||||
configured_db = APP_CONFIG.get("DB_DATABASE") or APP_CONFIG.get("DB_NAME") or "poweron"
|
||||
if not db_name or db_name != configured_db:
|
||||
raise HTTPException(status_code=400, detail="Invalid database name")
|
||||
|
||||
# Get all configured database names
|
||||
configured_dbs = []
|
||||
app_db = APP_CONFIG.get("DB_APP_DATABASE")
|
||||
if app_db:
|
||||
configured_dbs.append(app_db)
|
||||
chat_db = APP_CONFIG.get("DB_CHAT_DATABASE")
|
||||
if chat_db:
|
||||
configured_dbs.append(chat_db)
|
||||
management_db = APP_CONFIG.get("DB_MANAGEMENT_DATABASE")
|
||||
if management_db:
|
||||
configured_dbs.append(management_db)
|
||||
|
||||
if not configured_dbs:
|
||||
configured_dbs = ["poweron"]
|
||||
|
||||
if not db_name or db_name not in configured_dbs:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid database name. Available databases: {configured_dbs}")
|
||||
|
||||
try:
|
||||
appInterface = getRootInterface()
|
||||
conn = appInterface.db.connection
|
||||
# Use the appropriate interface based on database name
|
||||
if db_name == app_db:
|
||||
interface = getRootInterface()
|
||||
elif db_name == chat_db:
|
||||
from modules.interfaces.interfaceDbChatObjects import getInterface as getChatInterface
|
||||
interface = getChatInterface(currentUser)
|
||||
elif db_name == management_db:
|
||||
from modules.interfaces.interfaceDbComponentObjects import getInterface as getComponentInterface
|
||||
interface = getComponentInterface(currentUser)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Database not found")
|
||||
|
||||
conn = interface.db.connection
|
||||
with conn.cursor() as cursor:
|
||||
# Drop all user tables (public schema) except system table
|
||||
cursor.execute("""
|
||||
|
|
@ -281,12 +444,12 @@ async def drop_database(
|
|||
cursor.execute(f'DROP TABLE IF EXISTS "{tbl}" CASCADE')
|
||||
dropped.append(tbl)
|
||||
conn.commit()
|
||||
logger.warning(f"Admin drop_database executed by {currentUser.id}: dropped tables: {dropped}")
|
||||
logger.warning(f"Admin drop_database executed by {currentUser.id}: dropped tables from '{db_name}': {dropped}")
|
||||
return {"droppedTables": dropped}
|
||||
except Exception as e:
|
||||
logger.error(f"Error dropping database tables: {str(e)}")
|
||||
if appInterface and appInterface.db and appInterface.db.connection:
|
||||
appInterface.db.connection.rollback()
|
||||
if 'interface' in locals() and interface and interface.db and interface.db.connection:
|
||||
interface.db.connection.rollback()
|
||||
raise HTTPException(status_code=500, detail="Failed to drop database tables")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,16 +7,14 @@ from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
|||
import logging
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from requests_oauthlib import OAuth2Session
|
||||
import httpx
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.interfaceAppObjects import getInterface, getRootInterface
|
||||
from modules.interfaces.interfaceAppModel import AuthAuthority, User, Token, ConnectionStatus, UserConnection
|
||||
from modules.interfaces.interfaceDbAppObjects import getInterface, getRootInterface
|
||||
from modules.datamodels.datamodelUam import AuthAuthority, User, ConnectionStatus, UserConnection
|
||||
from modules.security.auth import getCurrentUser, limiter
|
||||
from modules.shared.attributeUtils import ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_now, create_expiration_timestamp, get_utc_timestamp
|
||||
from modules.shared.timezoneUtils import createExpirationTimestamp, getUtcTimestamp
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -169,7 +167,7 @@ async def login(
|
|||
try:
|
||||
if connectionId:
|
||||
rootInterface = getRootInterface()
|
||||
from modules.interfaces.interfaceAppModel import UserConnection
|
||||
from modules.datamodels.datamodelUam import UserConnection
|
||||
records = rootInterface.db.getRecordset(UserConnection, recordFilter={"id": connectionId})
|
||||
if records:
|
||||
record = records[0]
|
||||
|
|
@ -208,7 +206,7 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
"""Handle Google OAuth callback"""
|
||||
try:
|
||||
# Import Token at function level to avoid scoping issues
|
||||
from modules.interfaces.interfaceAppModel import Token
|
||||
from modules.datamodels.datamodelSecurity import Token
|
||||
|
||||
# Parse state
|
||||
state_data = json.loads(state)
|
||||
|
|
@ -339,7 +337,7 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
)
|
||||
|
||||
# Create JWT token data (like Microsoft does)
|
||||
from modules.security.auth import createAccessToken
|
||||
from modules.security.jwtService import createAccessToken
|
||||
jwt_token_data = {
|
||||
"sub": user.username,
|
||||
"mandateId": str(user.mandateId),
|
||||
|
|
@ -358,7 +356,7 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
tokenRefresh=token_response.get("refresh_token", ""),
|
||||
tokenType="bearer",
|
||||
expiresAt=jwt_expires_at.timestamp(),
|
||||
createdAt=get_utc_timestamp()
|
||||
createdAt=getUtcTimestamp()
|
||||
)
|
||||
|
||||
# Save access token (no connectionId)
|
||||
|
|
@ -376,7 +374,7 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
window.opener.postMessage({{
|
||||
type: 'google_auth_success',
|
||||
access_token: {json.dumps(token_response["access_token"])},
|
||||
token_data: {json.dumps(token.to_dict())}
|
||||
token_data: {json.dumps(token.model_dump())}
|
||||
}}, '*');
|
||||
}}
|
||||
setTimeout(() => window.close(), 1000);
|
||||
|
|
@ -462,15 +460,15 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
logger.info(f"Updating connection {connection_id} for user {user.username}")
|
||||
# Update connection with external service details
|
||||
connection.status = ConnectionStatus.ACTIVE
|
||||
connection.lastChecked = get_utc_timestamp()
|
||||
connection.expiresAt = get_utc_timestamp() + token_response.get("expires_in", 0)
|
||||
connection.lastChecked = getUtcTimestamp()
|
||||
connection.expiresAt = getUtcTimestamp() + token_response.get("expires_in", 0)
|
||||
connection.externalId = user_info.get("id")
|
||||
connection.externalUsername = user_info.get("email")
|
||||
connection.externalEmail = user_info.get("email")
|
||||
|
||||
# Update connection record directly
|
||||
from modules.interfaces.interfaceAppModel import UserConnection
|
||||
rootInterface.db.recordModify(UserConnection, connection_id, connection.to_dict())
|
||||
from modules.datamodels.datamodelUam import UserConnection
|
||||
rootInterface.db.recordModify(UserConnection, connection_id, connection.model_dump())
|
||||
|
||||
|
||||
# Save token
|
||||
|
|
@ -481,8 +479,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
tokenAccess=token_response["access_token"],
|
||||
tokenRefresh=token_response.get("refresh_token", ""),
|
||||
tokenType=token_response.get("token_type", "bearer"),
|
||||
expiresAt=create_expiration_timestamp(token_response.get("expires_in", 0)),
|
||||
createdAt=get_utc_timestamp()
|
||||
expiresAt=createExpirationTimestamp(token_response.get("expires_in", 0)),
|
||||
createdAt=getUtcTimestamp()
|
||||
)
|
||||
interface.saveConnectionToken(token)
|
||||
|
||||
|
|
@ -500,8 +498,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
id: '{connection.id}',
|
||||
status: 'connected',
|
||||
type: 'google',
|
||||
lastChecked: {get_utc_timestamp()},
|
||||
expiresAt: {create_expiration_timestamp(token_response.get("expires_in", 0))}
|
||||
lastChecked: {getUtcTimestamp()},
|
||||
expiresAt: {createExpirationTimestamp(token_response.get("expires_in", 0))}
|
||||
}}
|
||||
}}, '*');
|
||||
// Wait for message to be sent before closing
|
||||
|
|
@ -590,6 +588,20 @@ async def logout(
|
|||
try:
|
||||
appInterface = getInterface(currentUser)
|
||||
appInterface.logout()
|
||||
|
||||
# Log successful logout
|
||||
try:
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
audit_logger.logUserAccess(
|
||||
userId=str(currentUser.id),
|
||||
mandateId=str(currentUser.mandateId),
|
||||
action="logout",
|
||||
successInfo="google_auth_logout"
|
||||
)
|
||||
except Exception:
|
||||
# Don't fail if audit logging fails
|
||||
pass
|
||||
|
||||
return {"message": "Logged out successfully"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout: {str(e)}")
|
||||
|
|
@ -623,29 +635,19 @@ async def verify_token(
|
|||
detail="No Google connection found for current user"
|
||||
)
|
||||
|
||||
# Get the current token
|
||||
current_token = appInterface.getConnectionToken(google_connection.id, auto_refresh=False)
|
||||
|
||||
# Get a fresh token via TokenManager convenience method
|
||||
from modules.security.tokenManager import TokenManager
|
||||
current_token = TokenManager().getFreshToken(google_connection.id)
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No Google token found for this connection"
|
||||
)
|
||||
|
||||
# Verify the token
|
||||
# Verify the (fresh) token
|
||||
token_verification = await verify_google_token(current_token.tokenAccess)
|
||||
|
||||
if not token_verification.get("valid"):
|
||||
# Try to refresh the token if verification failed
|
||||
from modules.security.tokenManager import TokenManager
|
||||
token_manager = TokenManager()
|
||||
refreshed_token = token_manager.refresh_token(current_token)
|
||||
|
||||
if refreshed_token:
|
||||
appInterface.saveConnectionToken(refreshed_token)
|
||||
# Verify the refreshed token
|
||||
token_verification = await verify_google_token(refreshed_token.tokenAccess)
|
||||
|
||||
return {
|
||||
"valid": token_verification.get("valid", False),
|
||||
"scopes": token_verification.get("scopes", []),
|
||||
|
|
@@ -707,8 +709,9 @@ async def refresh_token(

logger.debug(f"Found Google connection: {google_connection.id}, status={google_connection.status}")

# Get the token for this specific connection using the new method
current_token = appInterface.getConnectionToken(google_connection.id, auto_refresh=False)
# Get the token for this specific connection (fresh if expiring soon)
from modules.security.tokenManager import TokenManager
current_token = TokenManager().getFreshToken(google_connection.id)

if not current_token:
raise HTTPException(

@@ -717,38 +720,25 @@ async def refresh_token(
)


# If we could not obtain a fresh token, report error
if not current_token:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to refresh token")

# Always attempt refresh (as per your requirement)
from modules.security.tokenManager import TokenManager
token_manager = TokenManager()
# Update the connection status and timing
google_connection.expiresAt = float(current_token.expiresAt) if current_token.expiresAt else google_connection.expiresAt
google_connection.lastChecked = getUtcTimestamp()
google_connection.status = ConnectionStatus.ACTIVE
appInterface.db.recordModify(UserConnection, google_connection.id, google_connection.model_dump())

refreshed_token = token_manager.refresh_token(current_token)
if refreshed_token:
# Save the new connection token (which will automatically replace old ones)
appInterface.saveConnectionToken(refreshed_token)

# Update the connection's expiration time
google_connection.expiresAt = float(refreshed_token.expiresAt)
google_connection.lastChecked = get_utc_timestamp()
google_connection.status = ConnectionStatus.ACTIVE

# Save updated connection
appInterface.db.recordModify(UserConnection, google_connection.id, google_connection.to_dict())

# Calculate time until expiration
current_time = get_utc_timestamp()
expires_in = int(refreshed_token.expiresAt - current_time)

return {
"message": "Token refreshed successfully",
"expires_at": refreshed_token.expiresAt,
"expires_in_seconds": expires_in
}
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to refresh token"
)
# Calculate time until expiration
current_time = getUtcTimestamp()
expires_in = int(current_token.expiresAt - current_time) if current_token.expiresAt else 0

return {
"message": "Token refreshed successfully",
"expires_at": current_token.expiresAt,
"expires_in_seconds": expires_in
}

except HTTPException:
raise
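
For orientation, a minimal client-side sketch of the refresh flow this hunk implements. The route path and host are assumptions (the router prefix is not shown in the diff); only the response fields expires_at and expires_in_seconds come from the code above.

import asyncio
import httpx

async def refreshGoogleToken(baseUrl: str, jwtToken: str) -> int:
    # Call the (assumed) Google refresh endpoint with the caller's bearer token and
    # return how many seconds remain on the stored Google access token.
    async with httpx.AsyncClient(base_url=baseUrl) as client:
        resp = await client.post(
            "/google/refresh-token",  # hypothetical path; not shown in the diff
            headers={"Authorization": f"Bearer {jwtToken}"},
        )
        resp.raise_for_status()
        return int(resp.json()["expires_in_seconds"])

# asyncio.run(refreshGoogleToken("https://gateway.example", "<jwt>"))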

@@ -5,18 +5,18 @@ Routes for local security and authentication.
from fastapi import APIRouter, HTTPException, status, Depends, Request, Response, Body
from fastapi.security import OAuth2PasswordRequestForm
import logging
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
from typing import Dict, Any
from datetime import datetime
from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse
import uuid
from jose import jwt
from pydantic import BaseModel

# Import auth modules
from modules.security.auth import createAccessToken, getCurrentUser, limiter, SECRET_KEY, ALGORITHM
from modules.interfaces.interfaceAppObjects import getInterface, getRootInterface
from modules.interfaces.interfaceAppModel import User, UserInDB, AuthAuthority, UserPrivilege, Token
from modules.shared.attributeUtils import ModelMixin
from modules.security.auth import getCurrentUser, limiter, SECRET_KEY, ALGORITHM
from modules.security.jwtService import createAccessToken, createRefreshToken, setAccessTokenCookie, setRefreshTokenCookie, clearAccessTokenCookie, clearRefreshTokenCookie
from modules.interfaces.interfaceDbAppObjects import getInterface, getRootInterface
from modules.datamodels.datamodelUam import User, UserInDB, AuthAuthority, UserPrivilege
from modules.datamodels.datamodelSecurity import Token

# Configure logger
logger = logging.getLogger(__name__)
|
|
@ -38,6 +38,7 @@ router = APIRouter(
|
|||
@limiter.limit("30/minute")
|
||||
async def login(
|
||||
request: Request,
|
||||
response: Response,
|
||||
formData: OAuth2PasswordRequestForm = Depends(),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get access token for local user authentication"""
|
||||
|
|
@ -54,7 +55,7 @@ async def login(
|
|||
rootInterface = getRootInterface()
|
||||
|
||||
# Get default mandate ID
|
||||
from modules.interfaces.interfaceAppModel import Mandate
|
||||
from modules.datamodels.datamodelUam import Mandate
|
||||
defaultMandateId = rootInterface.getInitialId(Mandate)
|
||||
if not defaultMandateId:
|
||||
raise HTTPException(
|
||||
|
|
@ -90,24 +91,27 @@ async def login(
|
|||
session_id = str(uuid.uuid4())
|
||||
token_data["sid"] = session_id
|
||||
|
||||
# Create access token
|
||||
access_token, expires_at = createAccessToken(token_data)
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create access token"
|
||||
)
|
||||
# Create access token + set cookie
|
||||
access_token, _access_expires = createAccessToken(token_data)
|
||||
setAccessTokenCookie(response, access_token)
|
||||
|
||||
# Create refresh token + set cookie
|
||||
refresh_token, _refresh_expires = createRefreshToken(token_data)
|
||||
setRefreshTokenCookie(response, refresh_token)
|
||||
|
||||
# Get expiration time for response
|
||||
try:
|
||||
payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
expires_at = datetime.fromtimestamp(payload.get("exp"))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode access token: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to finalize token")
|
||||
|
||||
# Get user-specific interface for token operations
|
||||
userInterface = getInterface(user)
|
||||
|
||||
# Decode JWT to get jti for DB persistence
|
||||
try:
|
||||
payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
jti = payload.get("jti")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to decode created JWT: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to finalize token")
|
||||
# Get jti from already decoded payload
|
||||
jti = payload.get("jti")
|
||||
|
||||
# Create token
|
||||
token = Token(
|
||||
|
|
@ -124,12 +128,25 @@ async def login(
|
|||
# Save access token
|
||||
userInterface.saveAccessToken(token)
|
||||
|
||||
# Create response data
|
||||
# Log successful login
|
||||
try:
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
audit_logger.logUserAccess(
|
||||
userId=str(user.id),
|
||||
mandateId=str(user.mandateId),
|
||||
action="login",
|
||||
successInfo="local_auth_success"
|
||||
)
|
||||
except Exception:
|
||||
# Don't fail if audit logging fails
|
||||
pass
|
||||
|
||||
# Create response data (tokens are now in httpOnly cookies)
|
||||
response_data = {
|
||||
"type": "local_auth_success",
|
||||
"access_token": access_token,
|
||||
"token_data": token.dict(),
|
||||
"authenticationAuthority": "local"
|
||||
"message": "Login successful - tokens set in httpOnly cookies",
|
||||
"authenticationAuthority": "local",
|
||||
"expires_at": expires_at.isoformat()
|
||||
}
|
||||
|
||||
return response_data
|
||||
|
|
@ -138,6 +155,24 @@ async def login(
|
|||
# Handle authentication errors
|
||||
error_msg = str(e)
|
||||
logger.warning(f"Authentication failed for user {formData.username}: {error_msg}")
|
||||
|
||||
# Check if user is disabled and provide specific message
|
||||
if error_msg == "User is disabled":
|
||||
error_msg = "Your account is disabled. Please send an email to p.motsch@valueon.ch to get access to the PowerOn center."
|
||||
|
||||
# Log failed login attempt
|
||||
try:
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
audit_logger.logUserAccess(
|
||||
userId="unknown",
|
||||
mandateId="unknown",
|
||||
action="login",
|
||||
successInfo=f"failed: {error_msg}"
|
||||
)
|
||||
except Exception:
|
||||
# Don't fail if audit logging fails
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=error_msg,
|
||||
|
|
@ -165,7 +200,7 @@ async def register_user(
|
|||
appInterface = getRootInterface()
|
||||
|
||||
# Get default mandate ID
|
||||
from modules.interfaces.interfaceAppModel import Mandate
|
||||
from modules.datamodels.datamodelUam import Mandate
|
||||
defaultMandateId = appInterface.getInitialId(Mandate)
|
||||
if not defaultMandateId:
|
||||
raise HTTPException(
|
||||
|
|
@ -178,14 +213,15 @@ async def register_user(
|
|||
|
||||
# Create user with local authentication
|
||||
# Set safe default privilege level for new registrations
|
||||
from modules.interfaces.interfaceAppModel import UserPrivilege
|
||||
# New users are disabled by default and require admin approval
|
||||
from modules.datamodels.datamodelUam import UserPrivilege
|
||||
user = appInterface.createUser(
|
||||
username=userData.username,
|
||||
password=password,
|
||||
email=userData.email,
|
||||
fullName=userData.fullName,
|
||||
language=userData.language,
|
||||
enabled=userData.enabled,
|
||||
enabled=False, # New users are disabled by default
|
||||
privilege=UserPrivilege.USER, # Always set to USER for new registrations
|
||||
authenticationAuthority=AuthAuthority.LOCAL
|
||||
)
|
||||
|
|
@@ -226,20 +262,100 @@ async def read_user_me(
            detail=f"Failed to get current user: {str(e)}"
        )

@router.post("/refresh")
@limiter.limit("60/minute")
async def refresh_token(
    request: Request,
    response: Response
) -> Dict[str, Any]:
    """Refresh access token using refresh token from cookie"""
    try:
        # Get refresh token from cookie
        refresh_token = request.cookies.get('refresh_token')
        if not refresh_token:
            raise HTTPException(status_code=401, detail="No refresh token found")

        # Validate refresh token
        try:
            payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
            if payload.get("type") != "refresh":
                raise HTTPException(status_code=401, detail="Invalid refresh token type")
        except jwt.ExpiredSignatureError:
            raise HTTPException(status_code=401, detail="Refresh token expired")
        except jwt.JWTError:
            raise HTTPException(status_code=401, detail="Invalid refresh token")

        # Get user information from refresh token payload
        user_id = payload.get("userId")
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid refresh token - missing user ID")

        # Get user from database using the user ID from refresh token
        try:
            app_interface = getRootInterface()
            current_user = app_interface.getUser(user_id)
            if not current_user:
                raise HTTPException(status_code=401, detail="User not found")
        except Exception as e:
            logger.error(f"Failed to get user from database: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to validate user")

        # Create new token data
        token_data = {
            "sub": current_user.username,
            "mandateId": str(current_user.mandateId),
            "userId": str(current_user.id),
            "authenticationAuthority": current_user.authenticationAuthority
        }

        # Create new access token + set cookie
        access_token, _expires = createAccessToken(token_data)
        setAccessTokenCookie(response, access_token)

        # Get expiration time
        try:
            payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
            expires_at = datetime.fromtimestamp(payload.get("exp"))
        except Exception as e:
            logger.error(f"Failed to decode new access token: {str(e)}")
            raise HTTPException(status_code=500, detail="Failed to create new token")

        return {
            "type": "token_refresh_success",
            "message": "Token refreshed successfully",
            "expires_at": expires_at.isoformat()
        }

    except HTTPException as e:
        # If it's a 503 error (service unavailable due to missing token table), return it as-is
        if e.status_code == 503:
            raise
        # For other HTTP exceptions, re-raise them
        raise
    except Exception as e:
        logger.error(f"Token refresh error: {str(e)}")
        raise HTTPException(status_code=500, detail="Token refresh failed")
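
A hedged usage sketch of the new cookie-based refresh flow: the login path is hypothetical, the "/refresh" path and the response shape come from the hunk above, and httpx's client cookie jar stands in for the browser.

import httpx

def refreshSession(baseUrl: str, username: str, password: str) -> dict:
    with httpx.Client(base_url=baseUrl) as client:
        # Login sets the httpOnly access_token/refresh_token cookies on the client's cookie jar
        client.post("/auth/token", data={"username": username, "password": password})  # hypothetical login path
        # The refresh endpoint reads the refresh_token cookie and re-issues the access cookie
        resp = client.post("/auth/refresh")  # prefix assumed; only "/refresh" is shown in the diff
        resp.raise_for_status()
        return resp.json()  # {"type": "token_refresh_success", "message": ..., "expires_at": ...}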
|
||||
|
||||
@router.post("/logout")
|
||||
@limiter.limit("30/minute")
|
||||
async def logout(request: Request, currentUser: User = Depends(getCurrentUser)) -> JSONResponse:
|
||||
async def logout(request: Request, response: Response, currentUser: User = Depends(getCurrentUser)) -> JSONResponse:
|
||||
"""Logout from local authentication"""
|
||||
try:
|
||||
# Get user interface with current user context
|
||||
appInterface = getInterface(currentUser)
|
||||
# Read bearer token from Authorization header to obtain session id / jti
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing Authorization header")
|
||||
raw_token = auth_header.split(" ", 1)[1].strip()
|
||||
|
||||
# Get token from cookie or Authorization header
|
||||
token = request.cookies.get('auth_token')
|
||||
if not token:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No token found")
|
||||
|
||||
try:
|
||||
payload = jwt.decode(raw_token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
session_id = payload.get("sid") or payload.get("sessionId")
|
||||
jti = payload.get("jti")
|
||||
except Exception as e:
|
||||
|
|
@ -253,11 +369,31 @@ async def logout(request: Request, currentUser: User = Depends(getCurrentUser))
|
|||
appInterface.revokeTokenById(jti, revokedBy=currentUser.id, reason="logout")
|
||||
revoked = 1
|
||||
|
||||
return JSONResponse({
|
||||
"message": "Successfully logged out",
|
||||
# Log successful logout
|
||||
try:
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
audit_logger.logUserAccess(
|
||||
userId=str(currentUser.id),
|
||||
mandateId=str(currentUser.mandateId),
|
||||
action="logout",
|
||||
successInfo=f"revoked_tokens: {revoked}"
|
||||
)
|
||||
except Exception:
|
||||
# Don't fail if audit logging fails
|
||||
pass
|
||||
|
||||
# Create the JSON response first
|
||||
json_response = JSONResponse({
|
||||
"message": "Successfully logged out - cookies cleared",
|
||||
"revokedTokens": revoked
|
||||
})
|
||||
|
||||
# Clear httpOnly cookies on the response we're actually returning
|
||||
clearAccessTokenCookie(json_response)
|
||||
clearRefreshTokenCookie(json_response)
|
||||
|
||||
return json_response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@@ -7,16 +7,16 @@ from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
import logging
import json
from typing import Dict, Any, Optional
from datetime import datetime, timedelta
import msal
import httpx

from modules.shared.configuration import APP_CONFIG
from modules.interfaces.interfaceAppObjects import getInterface, getRootInterface
from modules.interfaces.interfaceAppModel import AuthAuthority, User, Token, ConnectionStatus, UserConnection
from modules.security.auth import getCurrentUser, limiter, createAccessToken
from modules.shared.attributeUtils import ModelMixin
from modules.shared.timezoneUtils import get_utc_now, create_expiration_timestamp, get_utc_timestamp
from modules.interfaces.interfaceDbAppObjects import getInterface, getRootInterface
from modules.datamodels.datamodelUam import AuthAuthority, User, ConnectionStatus, UserConnection
from modules.datamodels.datamodelSecurity import Token
from modules.security.auth import getCurrentUser, limiter
from modules.security.jwtService import createAccessToken
from modules.shared.timezoneUtils import createExpirationTimestamp, getUtcTimestamp

# Configure logger
logger = logging.getLogger(__name__)
|
|
@ -199,8 +199,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
tokenAccess=token_response["access_token"],
|
||||
tokenRefresh=token_response.get("refresh_token", ""),
|
||||
tokenType=token_response.get("token_type", "bearer"),
|
||||
expiresAt=create_expiration_timestamp(token_response.get("expires_in", 0)),
|
||||
createdAt=get_utc_timestamp()
|
||||
expiresAt=createExpirationTimestamp(token_response.get("expires_in", 0)),
|
||||
createdAt=getUtcTimestamp()
|
||||
)
|
||||
|
||||
# Save access token (no connectionId)
|
||||
|
|
@ -225,14 +225,14 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
tokenAccess=jwt_token,
|
||||
tokenType="bearer",
|
||||
expiresAt=jwt_expires_at.timestamp(),
|
||||
createdAt=get_utc_timestamp()
|
||||
createdAt=getUtcTimestamp()
|
||||
)
|
||||
|
||||
# Save JWT access token
|
||||
appInterface.saveAccessToken(jwt_token_obj)
|
||||
|
||||
# Convert token to dict and ensure proper timestamp handling
|
||||
token_dict = jwt_token_obj.to_dict()
|
||||
token_dict = jwt_token_obj.model_dump()
|
||||
# Remove datetime conversion logic - models now handle this automatically
|
||||
# The token model already returns float timestamps
|
||||
|
||||
|
|
@ -332,14 +332,14 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
logger.info(f"Updating connection {connection_id} for user {user.username}")
|
||||
# Update connection with external service details
|
||||
connection.status = ConnectionStatus.ACTIVE
|
||||
connection.lastChecked = get_utc_timestamp()
|
||||
connection.expiresAt = get_utc_timestamp() + token_response.get("expires_in", 0)
|
||||
connection.lastChecked = getUtcTimestamp()
|
||||
connection.expiresAt = getUtcTimestamp() + token_response.get("expires_in", 0)
|
||||
connection.externalId = user_info.get("id")
|
||||
connection.externalUsername = user_info.get("userPrincipalName")
|
||||
connection.externalEmail = user_info.get("mail")
|
||||
|
||||
# Update connection record directly
|
||||
rootInterface.db.recordModify(UserConnection, connection_id, connection.to_dict())
|
||||
rootInterface.db.recordModify(UserConnection, connection_id, connection.model_dump())
|
||||
|
||||
|
||||
# Save token
|
||||
|
|
@ -351,8 +351,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
tokenAccess=token_response["access_token"],
|
||||
tokenRefresh=token_response.get("refresh_token", ""),
|
||||
tokenType=token_response.get("token_type", "bearer"),
|
||||
expiresAt=create_expiration_timestamp(token_response.get("expires_in", 0)),
|
||||
createdAt=get_utc_timestamp()
|
||||
expiresAt=createExpirationTimestamp(token_response.get("expires_in", 0)),
|
||||
createdAt=getUtcTimestamp()
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -373,8 +373,8 @@ async def auth_callback(code: str, state: str, request: Request) -> HTMLResponse
|
|||
id: '{connection.id}',
|
||||
status: 'connected',
|
||||
type: 'msft',
|
||||
lastChecked: {get_utc_timestamp()},
|
||||
expiresAt: {create_expiration_timestamp(token_response.get("expires_in", 0))}
|
||||
lastChecked: {getUtcTimestamp()},
|
||||
expiresAt: {createExpirationTimestamp(token_response.get("expires_in", 0))}
|
||||
}}
|
||||
}}, '*');
|
||||
// Wait for message to be sent before closing
|
||||
|
|
@ -463,6 +463,20 @@ async def logout(
|
|||
try:
|
||||
appInterface = getInterface(currentUser)
|
||||
appInterface.logout()
|
||||
|
||||
# Log successful logout
|
||||
try:
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
audit_logger.logUserAccess(
|
||||
userId=str(currentUser.id),
|
||||
mandateId=str(currentUser.mandateId),
|
||||
action="logout",
|
||||
successInfo="microsoft_auth_logout"
|
||||
)
|
||||
except Exception:
|
||||
# Don't fail if audit logging fails
|
||||
pass
|
||||
|
||||
return {"message": "Logged out successfully"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during logout: {str(e)}")
|
||||
|
|
@ -545,9 +559,9 @@ async def refresh_token(
|
|||
|
||||
logger.debug(f"Found Microsoft connection: {msft_connection.id}, status={msft_connection.status}")
|
||||
|
||||
# Get the token for this specific connection using the new method
|
||||
# Enable auto-refresh to handle expired tokens gracefully
|
||||
current_token = appInterface.getConnectionToken(msft_connection.id, auto_refresh=True)
|
||||
# Get a fresh token via TokenManager convenience method
|
||||
from modules.security.tokenManager import TokenManager
|
||||
current_token = TokenManager().getFreshToken(msft_connection.id)
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(
|
||||
|
|
@ -561,27 +575,27 @@ async def refresh_token(
|
|||
from modules.security.tokenManager import TokenManager
|
||||
token_manager = TokenManager()
|
||||
|
||||
refreshed_token = token_manager.refresh_token(current_token)
|
||||
if refreshed_token:
|
||||
refreshedToken = token_manager.refreshToken(current_token)
|
||||
if refreshedToken:
|
||||
# Save the new connection token (which will automatically replace old ones)
|
||||
appInterface.saveConnectionToken(refreshed_token)
|
||||
appInterface.saveConnectionToken(refreshedToken)
|
||||
|
||||
# Update the connection's expiration time
|
||||
msft_connection.expiresAt = float(refreshed_token.expiresAt)
|
||||
msft_connection.lastChecked = get_utc_timestamp()
|
||||
msft_connection.expiresAt = float(refreshedToken.expiresAt)
|
||||
msft_connection.lastChecked = getUtcTimestamp()
|
||||
msft_connection.status = ConnectionStatus.ACTIVE
|
||||
|
||||
# Save updated connection
|
||||
appInterface.db.recordModify(UserConnection, msft_connection.id, msft_connection.to_dict())
|
||||
appInterface.db.recordModify(UserConnection, msft_connection.id, msft_connection.model_dump())
|
||||
|
||||
# Calculate time until expiration
|
||||
current_time = get_utc_timestamp()
|
||||
expires_in = int(refreshed_token.expiresAt - current_time)
|
||||
current_time = getUtcTimestamp()
|
||||
expiresIn = int(refreshedToken.expiresAt - current_time)
|
||||
|
||||
return {
|
||||
"message": "Token refreshed successfully",
|
||||
"expires_at": refreshed_token.expiresAt,
|
||||
"expires_in_seconds": expires_in
|
||||
"expires_at": refreshedToken.expiresAt,
|
||||
"expires_in_seconds": expiresIn
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
146
modules/routes/routeSharepoint.py
Normal file
146
modules/routes/routeSharepoint.py
Normal file

@@ -0,0 +1,146 @@
"""
SharePoint routes for folder browsing
Provides endpoints for listing SharePoint sites and browsing folders
"""

import logging
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Depends, Path, Query, Request, status

from modules.security.auth import limiter, getCurrentUser
from modules.datamodels.datamodelUam import User, UserConnection
from modules.interfaces.interfaceDbAppObjects import getInterface
from modules.services import getInterface as getServices

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/api/sharepoint",
    tags=["SharePoint"],
    responses={
        404: {"description": "Not found"},
        400: {"description": "Bad request"},
        401: {"description": "Unauthorized"},
        500: {"description": "Internal server error"}
    }
)

def _getUserConnection(interface, connectionId: str, userId: str) -> Optional[UserConnection]:
    """Get a user connection by ID, ensuring it belongs to the user"""
    try:
        connections = interface.getUserConnections(userId)
        for conn in connections:
            if conn.id == connectionId:
                return conn
        return None
    except Exception as e:
        logger.error(f"Error getting user connection: {str(e)}")
        return None

@router.get("/{connectionId}/sites", response_model=List[Dict[str, Any]])
@limiter.limit("30/minute")
async def get_sharepoint_sites(
    request: Request,
    connectionId: str = Path(..., description="Microsoft connection ID"),
    currentUser: User = Depends(getCurrentUser)
) -> List[Dict[str, Any]]:
    """Get all SharePoint sites accessible via a Microsoft connection"""
    try:
        interface = getInterface(currentUser)

        # Get the connection and verify it belongs to the user
        connection = _getUserConnection(interface, connectionId, currentUser.id)
        if not connection:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Connection {connectionId} not found or does not belong to user"
            )

        # Verify it's a Microsoft connection
        authority = connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority)
        if authority.lower() != 'msft':
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Connection {connectionId} is not a Microsoft connection"
            )

        # Initialize services
        services = getServices(currentUser, None)

        # Set access token on SharePoint service
        if not services.sharepoint.setAccessTokenFromConnection(connection):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Failed to set SharePoint access token. Connection may be expired or invalid."
            )

        # Discover SharePoint sites
        sites = await services.sharepoint.discoverSites()

        return sites

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting SharePoint sites: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error getting SharePoint sites: {str(e)}"
        )

@router.get("/{connectionId}/sites/{siteId}/folders", response_model=List[Dict[str, Any]])
@limiter.limit("60/minute")
async def list_sharepoint_folders(
    request: Request,
    connectionId: str = Path(..., description="Microsoft connection ID"),
    siteId: str = Path(..., description="SharePoint site ID"),
    path: Optional[str] = Query(None, description="Folder path (empty for root)"),
    currentUser: User = Depends(getCurrentUser)
) -> List[Dict[str, Any]]:
    """List folder contents for a SharePoint site and folder path"""
    try:
        interface = getInterface(currentUser)

        # Get the connection and verify it belongs to the user
        connection = _getUserConnection(interface, connectionId, currentUser.id)
        if not connection:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Connection {connectionId} not found or does not belong to user"
            )

        # Verify it's a Microsoft connection
        authority = connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority)
        if authority.lower() != 'msft':
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Connection {connectionId} is not a Microsoft connection"
            )

        # Initialize services
        services = getServices(currentUser, None)

        # Set access token on SharePoint service
        if not services.sharepoint.setAccessTokenFromConnection(connection):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Failed to set SharePoint access token. Connection may be expired or invalid."
            )

        # Normalize folder path (empty string for root)
        folderPath = path or ''

        # List folder contents
        items = await services.sharepoint.listFolderContents(siteId, folderPath)

        return items or []

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error listing SharePoint folders: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing SharePoint folders: {str(e)}"
        )
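
A minimal client sketch for the two endpoints added in this file. The /api/sharepoint prefix and the path query parameter come from the code above; the host, auth handling, and the site "id" field name are assumptions.

import asyncio
import httpx

async def browseSharepoint(baseUrl: str, connectionId: str) -> None:
    # Credentials are omitted here; attach the gateway's cookie or bearer auth as appropriate.
    async with httpx.AsyncClient(base_url=baseUrl) as client:
        sites = (await client.get(f"/api/sharepoint/{connectionId}/sites")).json()
        for site in sites:
            siteId = site.get("id")  # field name assumed; the endpoint returns the raw site dicts
            folders = await client.get(
                f"/api/sharepoint/{connectionId}/sites/{siteId}/folders",
                params={"path": ""},  # empty path lists the root folder
            )
            print(siteId, folders.json())

# asyncio.run(browseSharepoint("https://gateway.example", "<connection-id>"))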
|
||||
|
||||
|
|
@ -1,59 +1,80 @@
|
|||
"""
|
||||
Google Cloud Voice Services Routes
|
||||
Replaces Azure voice services with Google Cloud Speech-to-Text and Translation
|
||||
Includes WebSocket support for real-time voice streaming
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException, Body
|
||||
import json
|
||||
import base64
|
||||
from fastapi import APIRouter, File, Form, UploadFile, Depends, HTTPException, Body, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import Response
|
||||
from typing import Optional, Dict, Any
|
||||
from modules.connectors.connectorGoogleSpeech import ConnectorGoogleSpeech
|
||||
from typing import Optional, Dict, Any, List
|
||||
from modules.security.auth import getCurrentUser
|
||||
from modules.interfaces.interfaceAppModel import User
|
||||
from modules.interfaces.interfaceComponentObjects import getInterface
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.interfaces.interfaceVoiceObjects import getVoiceInterface, VoiceObjects
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/voice-google", tags=["Voice Google"])
|
||||
|
||||
# Global connector instance
|
||||
_google_speech_connector = None
|
||||
# Store active WebSocket connections
|
||||
activeConnections: Dict[str, WebSocket] = {}
|
||||
|
||||
def get_google_speech_connector() -> ConnectorGoogleSpeech:
|
||||
"""Get or create Google Cloud Speech connector instance."""
|
||||
global _google_speech_connector
|
||||
|
||||
if _google_speech_connector is None:
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.activeConnections: List[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket, connectionId: str):
|
||||
await websocket.accept()
|
||||
self.activeConnections.append(websocket)
|
||||
activeConnections[connectionId] = websocket
|
||||
logger.info(f"WebSocket connected: {connectionId}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket, connectionId: str):
|
||||
if websocket in self.activeConnections:
|
||||
self.activeConnections.remove(websocket)
|
||||
if connectionId in activeConnections:
|
||||
del activeConnections[connectionId]
|
||||
logger.info(f"WebSocket disconnected: {connectionId}")
|
||||
|
||||
async def sendPersonalMessage(self, message: dict, websocket: WebSocket):
|
||||
try:
|
||||
_google_speech_connector = ConnectorGoogleSpeech()
|
||||
logger.info("✅ Google Cloud Speech connector initialized")
|
||||
|
||||
await websocket.send_text(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to initialize Google Cloud Speech connector: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to initialize Google Cloud Speech connector: {str(e)}"
|
||||
)
|
||||
|
||||
return _google_speech_connector
|
||||
logger.error(f"Error sending message: {e}")
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
def _getVoiceInterface(currentUser: User) -> VoiceObjects:
|
||||
"""Get voice interface instance with user context."""
|
||||
try:
|
||||
return getVoiceInterface(currentUser)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize voice interface: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to initialize voice interface: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post("/speech-to-text")
|
||||
async def speech_to_text(
|
||||
audio_file: UploadFile = File(...),
|
||||
audioFile: UploadFile = File(...),
|
||||
language: str = Form("de-DE"),
|
||||
current_user: User = Depends(getCurrentUser)
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
):
|
||||
"""Convert speech to text using Google Cloud Speech-to-Text API."""
|
||||
try:
|
||||
logger.info(f"🎤 Speech-to-text request: {audio_file.filename}, language: {language}")
|
||||
logger.info(f"🎤 Speech-to-text request: {audioFile.filename}, language: {language}")
|
||||
|
||||
# Read audio file
|
||||
audio_content = await audio_file.read()
|
||||
logger.info(f"📊 Audio file size: {len(audio_content)} bytes")
|
||||
audioContent = await audioFile.read()
|
||||
logger.info(f"📊 Audio file size: {len(audioContent)} bytes")
|
||||
|
||||
# Get voice interface
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
|
||||
# Validate audio format
|
||||
connector = get_google_speech_connector()
|
||||
validation = connector.validate_audio_format(audio_content)
|
||||
validation = voiceInterface.validateAudioFormat(audioContent)
|
||||
|
||||
if not validation["valid"]:
|
||||
raise HTTPException(
|
||||
|
|
@ -62,8 +83,8 @@ async def speech_to_text(
|
|||
)
|
||||
|
||||
# Perform speech recognition
|
||||
result = await connector.speech_to_text(
|
||||
audio_content=audio_content,
|
||||
result = await voiceInterface.speechToText(
|
||||
audioContent=audioContent,
|
||||
language=language
|
||||
)
|
||||
|
||||
|
|
@ -74,7 +95,7 @@ async def speech_to_text(
|
|||
"confidence": result["confidence"],
|
||||
"language": result["language"],
|
||||
"audio_info": {
|
||||
"size": len(audio_content),
|
||||
"size": len(audioContent),
|
||||
"format": validation["format"],
|
||||
"estimated_duration": validation.get("estimated_duration", 0)
|
||||
}
|
||||
|
|
@ -97,13 +118,13 @@ async def speech_to_text(
|
|||
@router.post("/translate")
|
||||
async def translate_text(
|
||||
text: str = Form(...),
|
||||
source_language: str = Form("de"),
|
||||
target_language: str = Form("en"),
|
||||
current_user: User = Depends(getCurrentUser)
|
||||
sourceLanguage: str = Form("de"),
|
||||
targetLanguage: str = Form("en"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
):
|
||||
"""Translate text using Google Cloud Translation API."""
|
||||
try:
|
||||
logger.info(f"🌐 Translation request: '{text}' ({source_language} -> {target_language})")
|
||||
logger.info(f"🌐 Translation request: '{text}' ({sourceLanguage} -> {targetLanguage})")
|
||||
|
||||
if not text.strip():
|
||||
raise HTTPException(
|
||||
|
|
@ -111,12 +132,14 @@ async def translate_text(
|
|||
detail="Empty text provided for translation"
|
||||
)
|
||||
|
||||
# Get voice interface
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
|
||||
# Perform translation
|
||||
connector = get_google_speech_connector()
|
||||
result = await connector.translate_text(
|
||||
result = await voiceInterface.translateText(
|
||||
text=text,
|
||||
source_language=source_language,
|
||||
target_language=target_language
|
||||
sourceLanguage=sourceLanguage,
|
||||
targetLanguage=targetLanguage
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
|
|
@ -144,33 +167,35 @@ async def translate_text(
|
|||
|
||||
@router.post("/realtime-interpreter")
|
||||
async def realtime_interpreter(
|
||||
audio_file: UploadFile = File(...),
|
||||
from_language: str = Form("de-DE"),
|
||||
to_language: str = Form("en-US"),
|
||||
connection_id: str = Form(None),
|
||||
current_user: User = Depends(getCurrentUser)
|
||||
audioFile: UploadFile = File(...),
|
||||
fromLanguage: str = Form("de-DE"),
|
||||
toLanguage: str = Form("en-US"),
|
||||
connectionId: str = Form(None),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
):
|
||||
"""Real-time interpreter: speech to translated text using Google Cloud APIs."""
|
||||
try:
|
||||
logger.info(f"🔄 Real-time interpreter request: {audio_file.filename}")
|
||||
logger.info(f" From: {from_language} -> To: {to_language}")
|
||||
logger.info(f" MIME type: {audio_file.content_type}")
|
||||
logger.info(f"🔄 Real-time interpreter request: {audioFile.filename}")
|
||||
logger.info(f" From: {fromLanguage} -> To: {toLanguage}")
|
||||
logger.info(f" MIME type: {audioFile.content_type}")
|
||||
|
||||
# Read audio file
|
||||
audio_content = await audio_file.read()
|
||||
logger.info(f"📊 Audio file size: {len(audio_content)} bytes")
|
||||
audioContent = await audioFile.read()
|
||||
logger.info(f"📊 Audio file size: {len(audioContent)} bytes")
|
||||
|
||||
# Save audio file for debugging with correct extension
|
||||
file_extension = "webm" if audio_file.filename.endswith('.webm') else "wav"
|
||||
debug_filename = f"debug_audio/audio_google_{audio_file.filename.replace('.wav', '.webm')}"
|
||||
os.makedirs("debug_audio", exist_ok=True)
|
||||
with open(debug_filename, "wb") as f:
|
||||
f.write(audio_content)
|
||||
logger.info(f"💾 Saved audio file for debugging: {debug_filename}")
|
||||
# file_extension = "webm" if audio_file.filename.endswith('.webm') else "wav"
|
||||
# debug_filename = f"debug_audio/audio_google_{audio_file.filename.replace('.wav', '.webm')}"
|
||||
# os.makedirs("debug_audio", exist_ok=True)
|
||||
# with open(debug_filename, "wb") as f:
|
||||
# f.write(audio_content)
|
||||
# logger.info(f"💾 Saved audio file for debugging: {debug_filename}")
|
||||
|
||||
# Get voice interface
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
|
||||
# Validate audio format
|
||||
connector = get_google_speech_connector()
|
||||
validation = connector.validate_audio_format(audio_content)
|
||||
validation = voiceInterface.validateAudioFormat(audioContent)
|
||||
|
||||
if not validation["valid"]:
|
||||
raise HTTPException(
|
||||
|
|
@ -179,10 +204,10 @@ async def realtime_interpreter(
|
|||
)
|
||||
|
||||
# Perform complete pipeline: Speech-to-Text + Translation
|
||||
result = await connector.speech_to_translated_text(
|
||||
audio_content=audio_content,
|
||||
from_language=from_language,
|
||||
to_language=to_language
|
||||
result = await voiceInterface.speechToTranslatedText(
|
||||
audioContent=audioContent,
|
||||
fromLanguage=fromLanguage,
|
||||
toLanguage=toLanguage
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
|
|
@ -198,7 +223,7 @@ async def realtime_interpreter(
|
|||
"source_language": result["source_language"],
|
||||
"target_language": result["target_language"],
|
||||
"audio_info": {
|
||||
"size": len(audio_content),
|
||||
"size": len(audioContent),
|
||||
"format": validation["format"],
|
||||
"estimated_duration": validation.get("estimated_duration", 0)
|
||||
}
|
||||
|
|
@ -224,7 +249,7 @@ async def text_to_speech(
|
|||
text: str = Form(...),
|
||||
language: str = Form("de-DE"),
|
||||
voice: str = Form(None),
|
||||
current_user: User = Depends(getCurrentUser)
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
):
|
||||
"""Convert text to speech using Google Cloud Text-to-Speech."""
|
||||
try:
|
||||
|
|
@ -236,11 +261,11 @@ async def text_to_speech(
|
|||
detail="Empty text provided for text-to-speech"
|
||||
)
|
||||
|
||||
connector = get_google_speech_connector()
|
||||
result = await connector.text_to_speech(
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
result = await voiceInterface.textToSpeech(
|
||||
text=text,
|
||||
language_code=language,
|
||||
voice_name=voice
|
||||
languageCode=language,
|
||||
voiceName=voice
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
|
|
@ -268,30 +293,84 @@ async def text_to_speech(
|
|||
detail=f"Text-to-Speech processing failed: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check(current_user: User = Depends(getCurrentUser)):
|
||||
"""Health check for Google Cloud voice services."""
|
||||
@router.get("/languages")
|
||||
async def get_available_languages(currentUser: User = Depends(getCurrentUser)):
|
||||
"""Get available languages from Google Cloud Text-to-Speech."""
|
||||
try:
|
||||
connector = get_google_speech_connector()
|
||||
logger.info("🌐 Getting available languages from Google Cloud TTS")
|
||||
|
||||
# Test with a simple translation
|
||||
test_result = await connector.translate_text(
|
||||
text="Hello",
|
||||
source_language="en",
|
||||
target_language="de"
|
||||
)
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
result = await voiceInterface.getAvailableLanguages()
|
||||
|
||||
if test_result["success"]:
|
||||
if result["success"]:
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "Google Cloud Speech-to-Text & Translation",
|
||||
"test_translation": test_result["translated_text"]
|
||||
"success": True,
|
||||
"languages": result["languages"]
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to get languages: {result.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get languages error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get available languages: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/voices")
|
||||
async def get_available_voices(
|
||||
languageCode: Optional[str] = None,
|
||||
language_code: Optional[str] = None, # Accept both camelCase and snake_case
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
):
|
||||
"""
|
||||
Get available voices from Google Cloud Text-to-Speech.
|
||||
Accepts languageCode (camelCase) or language_code (snake_case) query parameter.
|
||||
"""
|
||||
# Use language_code if provided (frontend sends this), otherwise use languageCode
|
||||
if language_code:
|
||||
languageCode = language_code
|
||||
|
||||
try:
|
||||
logger.info(f"🎤 Getting available voices, language filter: {languageCode}")
|
||||
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
result = await voiceInterface.getAvailableVoices(languageCode=languageCode)
|
||||
|
||||
if result["success"]:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": test_result.get("error", "Unknown error")
|
||||
"success": True,
|
||||
"voices": result["voices"],
|
||||
"language_filter": languageCode
|
||||
}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to get voices: {result.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get voices error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get available voices: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check(currentUser: User = Depends(getCurrentUser)):
|
||||
"""Health check for Google Cloud voice services."""
|
||||
try:
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
test_result = await voiceInterface.healthCheck()
|
||||
|
||||
return test_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Health check failed: {e}")
|
||||
|
|
@ -301,23 +380,23 @@ async def health_check(current_user: User = Depends(getCurrentUser)):
|
|||
}
|
||||
|
||||
@router.get("/settings")
|
||||
async def get_voice_settings(current_user: User = Depends(getCurrentUser)):
|
||||
async def get_voice_settings(currentUser: User = Depends(getCurrentUser)):
|
||||
"""Get voice settings for the current user."""
|
||||
try:
|
||||
logger.info(f"Getting voice settings for user: {current_user.id}")
|
||||
logger.info(f"Getting voice settings for user: {currentUser.id}")
|
||||
|
||||
# Get database interface with user context
|
||||
interface = getInterface(current_user)
|
||||
# Get voice interface
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
|
||||
# Get or create voice settings for the user
|
||||
voice_settings = interface.getOrCreateVoiceSettings(current_user.id)
|
||||
voice_settings = voiceInterface.getOrCreateVoiceSettings(currentUser.id)
|
||||
|
||||
if voice_settings:
|
||||
# Return user settings
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"user_settings": voice_settings.to_dict(),
|
||||
"user_settings": voice_settings.model_dump(),
|
||||
"default_settings": {
|
||||
"sttLanguage": "de-DE",
|
||||
"ttsLanguage": "de-DE",
|
||||
|
|
@ -354,16 +433,16 @@ async def get_voice_settings(current_user: User = Depends(getCurrentUser)):
|
|||
@router.post("/settings")
|
||||
async def save_voice_settings(
|
||||
settings: Dict[str, Any] = Body(...),
|
||||
current_user: User = Depends(getCurrentUser)
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
):
|
||||
"""Save voice settings for the current user."""
|
||||
try:
|
||||
logger.info(f"Saving voice settings for user: {current_user.id}")
|
||||
logger.info(f"Saving voice settings for user: {currentUser.id}")
|
||||
logger.info(f"Settings: {settings}")
|
||||
|
||||
# Validate required settings
|
||||
required_fields = ["sttLanguage", "ttsLanguage", "ttsVoice"]
|
||||
for field in required_fields:
|
||||
requiredFields = ["sttLanguage", "ttsLanguage", "ttsVoice"]
|
||||
for field in requiredFields:
|
||||
if field not in settings:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
|
@ -376,24 +455,24 @@ async def save_voice_settings(
|
|||
if "targetLanguage" not in settings:
|
||||
settings["targetLanguage"] = "en-US"
|
||||
|
||||
# Get database interface with user context
|
||||
interface = getInterface(current_user)
|
||||
# Get voice interface
|
||||
voiceInterface = _getVoiceInterface(currentUser)
|
||||
|
||||
# Check if settings already exist for this user
|
||||
existing_settings = interface.getVoiceSettings(current_user.id)
|
||||
existing_settings = voiceInterface.getVoiceSettings(currentUser.id)
|
||||
|
||||
if existing_settings:
|
||||
# Update existing settings
|
||||
logger.info(f"Updating existing voice settings for user {current_user.id}")
|
||||
updated_settings = interface.updateVoiceSettings(current_user.id, settings)
|
||||
logger.info(f"Voice settings updated for user {current_user.id}: {updated_settings}")
|
||||
logger.info(f"Updating existing voice settings for user {currentUser.id}")
|
||||
updated_settings = voiceInterface.updateVoiceSettings(currentUser.id, settings)
|
||||
logger.info(f"Voice settings updated for user {currentUser.id}: {updated_settings}")
|
||||
else:
|
||||
# Create new settings
|
||||
logger.info(f"Creating new voice settings for user {current_user.id}")
|
||||
logger.info(f"Creating new voice settings for user {currentUser.id}")
|
||||
# Add userId to settings
|
||||
settings["userId"] = current_user.id
|
||||
created_settings = interface.createVoiceSettings(settings)
|
||||
logger.info(f"Voice settings created for user {current_user.id}: {created_settings}")
|
||||
settings["userId"] = currentUser.id
|
||||
created_settings = voiceInterface.createVoiceSettings(settings)
|
||||
logger.info(f"Voice settings created for user {currentUser.id}: {created_settings}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
|
@ -409,3 +488,190 @@ async def save_voice_settings(
|
|||
status_code=500,
|
||||
detail=f"Failed to save voice settings: {str(e)}"
|
||||
)
|
||||
|
||||
# WebSocket endpoints for real-time voice streaming
|
||||
|
||||
@router.websocket("/ws/realtime-interpreter")
|
||||
async def websocket_realtime_interpreter(
|
||||
websocket: WebSocket,
|
||||
userId: str = "default",
|
||||
fromLanguage: str = "de-DE",
|
||||
toLanguage: str = "en-US"
|
||||
):
|
||||
"""WebSocket endpoint for real-time voice interpretation"""
|
||||
connectionId = f"realtime_{userId}_{fromLanguage}_{toLanguage}"
|
||||
|
||||
try:
|
||||
await manager.connect(websocket, connectionId)
|
||||
|
||||
# Send connection confirmation
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "connected",
|
||||
"connection_id": connectionId,
|
||||
"message": "Connected to real-time interpreter"
|
||||
}, websocket)
|
||||
|
||||
# Initialize voice interface
|
||||
voiceInterface = _getVoiceInterface(User(id=userId))
|
||||
|
||||
while True:
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message["type"] == "audio_chunk":
|
||||
# Process audio chunk
|
||||
try:
|
||||
# Decode base64 audio data
|
||||
audioData = base64.b64decode(message["data"])
|
||||
|
||||
# For now, just acknowledge receipt
|
||||
# In a full implementation, this would:
|
||||
# 1. Buffer audio chunks
|
||||
# 2. Process with Google Cloud Speech-to-Text streaming
|
||||
# 3. Send partial results back
|
||||
# 4. Handle translation
|
||||
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "audio_received",
|
||||
"chunk_size": len(audioData),
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio chunk: {e}")
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "error",
|
||||
"error": f"Failed to process audio: {str(e)}"
|
||||
}, websocket)
|
||||
|
||||
elif message["type"] == "ping":
|
||||
# Respond to ping
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown message type: {message['type']}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connectionId)
|
||||
logger.info(f"Client disconnected: {connectionId}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connectionId)
|
||||
|
||||
@router.websocket("/ws/speech-to-text")
|
||||
async def websocket_speech_to_text(
|
||||
websocket: WebSocket,
|
||||
userId: str = "default",
|
||||
language: str = "de-DE"
|
||||
):
|
||||
"""WebSocket endpoint for real-time speech-to-text"""
|
||||
connectionId = f"stt_{userId}_{language}"
|
||||
|
||||
try:
|
||||
await manager.connect(websocket, connectionId)
|
||||
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "connected",
|
||||
"connection_id": connectionId,
|
||||
"message": "Connected to speech-to-text"
|
||||
}, websocket)
|
||||
|
||||
# Initialize voice interface
|
||||
voiceInterface = _getVoiceInterface(User(id=userId))
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message["type"] == "audio_chunk":
|
||||
try:
|
||||
audioData = base64.b64decode(message["data"])
|
||||
|
||||
# Process audio chunk
|
||||
# This would integrate with Google Cloud Speech-to-Text streaming API
|
||||
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "transcription_result",
|
||||
"text": "Audio chunk received", # Placeholder
|
||||
"confidence": 0.95,
|
||||
"is_final": False
|
||||
}, websocket)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio: {e}")
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "error",
|
||||
"error": f"Failed to process audio: {str(e)}"
|
||||
}, websocket)
|
||||
|
||||
elif message["type"] == "ping":
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connectionId)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connectionId)
|
||||
|
||||
@router.websocket("/ws/text-to-speech")
|
||||
async def websocket_text_to_speech(
|
||||
websocket: WebSocket,
|
||||
userId: str = "default",
|
||||
language: str = "de-DE",
|
||||
voice: str = "de-DE-Wavenet-A"
|
||||
):
|
||||
"""WebSocket endpoint for real-time text-to-speech"""
|
||||
connectionId = f"tts_{userId}_{language}_{voice}"
|
||||
|
||||
try:
|
||||
await manager.connect(websocket, connectionId)
|
||||
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "connected",
|
||||
"connection_id": connectionId,
|
||||
"message": "Connected to text-to-speech"
|
||||
}, websocket)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message["type"] == "text_to_speak":
|
||||
try:
|
||||
text = message["text"]
|
||||
|
||||
# Process text-to-speech
|
||||
# This would integrate with Google Cloud Text-to-Speech API
|
||||
|
||||
# For now, send a placeholder response
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "audio_data",
|
||||
"audio": "base64_encoded_audio_here", # Placeholder
|
||||
"format": "mp3"
|
||||
}, websocket)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing text-to-speech: {e}")
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "error",
|
||||
"error": f"Failed to process text: {str(e)}"
|
||||
}, websocket)
|
||||
|
||||
elif message["type"] == "ping":
|
||||
await manager.sendPersonalMessage({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connectionId)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connectionId)
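# --- Hedged sketch, not the shipped implementation: replacing the text_to_speak placeholder
# --- with google-cloud-texttospeech. `synthesizeToBase64` is a hypothetical helper name.
import base64
from google.cloud import texttospeech

def synthesizeToBase64(text: str, languageCode: str = "de-DE", voiceName: str = "de-DE-Wavenet-A") -> str:
    """Return MP3 audio for `text`, base64-encoded for the audio_data response."""
    client = texttospeech.TextToSpeechClient()
    response = client.synthesize_speech(
        input=texttospeech.SynthesisInput(text=text),
        voice=texttospeech.VoiceSelectionParams(language_code=languageCode, name=voiceName),
        audio_config=texttospeech.AudioConfig(audio_encoding=texttospeech.AudioEncoding.MP3),
    )
    return base64.b64encode(response.audio_content).decode()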
|
||||
|
@ -1,231 +0,0 @@
|
|||
"""
|
||||
Voice Streaming WebSocket Routes
|
||||
Provides real-time audio streaming for voice services
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
import asyncio
|
||||
from typing import Dict, List
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.connectors.connectorGoogleSpeech import ConnectorGoogleSpeech
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/voice/ws", tags=["Voice Streaming"])
|
||||
|
||||
# Store active connections
|
||||
active_connections: Dict[str, WebSocket] = {}
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket, connection_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
active_connections[connection_id] = websocket
|
||||
logger.info(f"WebSocket connected: {connection_id}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket, connection_id: str):
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
if connection_id in active_connections:
|
||||
del active_connections[connection_id]
|
||||
logger.info(f"WebSocket disconnected: {connection_id}")
|
||||
|
||||
async def send_personal_message(self, message: dict, websocket: WebSocket):
|
||||
try:
|
||||
await websocket.send_text(json.dumps(message))
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending message: {e}")
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
@router.websocket("/realtime-interpreter")
|
||||
async def websocket_realtime_interpreter(
|
||||
websocket: WebSocket,
|
||||
user_id: str = "default",
|
||||
from_language: str = "de-DE",
|
||||
to_language: str = "en-US"
|
||||
):
|
||||
"""WebSocket endpoint for real-time voice interpretation"""
|
||||
connection_id = f"realtime_{user_id}_{from_language}_{to_language}"
|
||||
|
||||
try:
|
||||
await manager.connect(websocket, connection_id)
|
||||
|
||||
# Send connection confirmation
|
||||
await manager.send_personal_message({
|
||||
"type": "connected",
|
||||
"connection_id": connection_id,
|
||||
"message": "Connected to real-time interpreter"
|
||||
}, websocket)
|
||||
|
||||
# Initialize Google Speech connector
|
||||
google_speech = ConnectorGoogleSpeech()
|
||||
|
||||
while True:
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message["type"] == "audio_chunk":
|
||||
# Process audio chunk
|
||||
try:
|
||||
# Decode base64 audio data
|
||||
audio_data = base64.b64decode(message["data"])
|
||||
|
||||
# For now, just acknowledge receipt
|
||||
# In a full implementation, this would:
|
||||
# 1. Buffer audio chunks
|
||||
# 2. Process with Google Cloud Speech-to-Text streaming
|
||||
# 3. Send partial results back
|
||||
# 4. Handle translation
|
||||
|
||||
await manager.send_personal_message({
|
||||
"type": "audio_received",
|
||||
"chunk_size": len(audio_data),
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio chunk: {e}")
|
||||
await manager.send_personal_message({
|
||||
"type": "error",
|
||||
"error": f"Failed to process audio: {str(e)}"
|
||||
}, websocket)
|
||||
|
||||
elif message["type"] == "ping":
|
||||
# Respond to ping
|
||||
await manager.send_personal_message({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown message type: {message['type']}")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connection_id)
|
||||
logger.info(f"Client disconnected: {connection_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connection_id)
|
||||
|
||||
@router.websocket("/speech-to-text")
|
||||
async def websocket_speech_to_text(
|
||||
websocket: WebSocket,
|
||||
user_id: str = "default",
|
||||
language: str = "de-DE"
|
||||
):
|
||||
"""WebSocket endpoint for real-time speech-to-text"""
|
||||
connection_id = f"stt_{user_id}_{language}"
|
||||
|
||||
try:
|
||||
await manager.connect(websocket, connection_id)
|
||||
|
||||
await manager.send_personal_message({
|
||||
"type": "connected",
|
||||
"connection_id": connection_id,
|
||||
"message": "Connected to speech-to-text"
|
||||
}, websocket)
|
||||
|
||||
# Initialize Google Speech connector
|
||||
google_speech = ConnectorGoogleSpeech()
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message["type"] == "audio_chunk":
|
||||
try:
|
||||
audio_data = base64.b64decode(message["data"])
|
||||
|
||||
# Process audio chunk
|
||||
# This would integrate with Google Cloud Speech-to-Text streaming API
|
||||
|
||||
await manager.send_personal_message({
|
||||
"type": "transcription_result",
|
||||
"text": "Audio chunk received", # Placeholder
|
||||
"confidence": 0.95,
|
||||
"is_final": False
|
||||
}, websocket)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio: {e}")
|
||||
await manager.send_personal_message({
|
||||
"type": "error",
|
||||
"error": f"Failed to process audio: {str(e)}"
|
||||
}, websocket)
|
||||
|
||||
elif message["type"] == "ping":
|
||||
await manager.send_personal_message({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connection_id)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connection_id)
|
||||
|
||||
@router.websocket("/text-to-speech")
|
||||
async def websocket_text_to_speech(
|
||||
websocket: WebSocket,
|
||||
user_id: str = "default",
|
||||
language: str = "de-DE",
|
||||
voice: str = "de-DE-Wavenet-A"
|
||||
):
|
||||
"""WebSocket endpoint for real-time text-to-speech"""
|
||||
connection_id = f"tts_{user_id}_{language}_{voice}"
|
||||
|
||||
try:
|
||||
await manager.connect(websocket, connection_id)
|
||||
|
||||
await manager.send_personal_message({
|
||||
"type": "connected",
|
||||
"connection_id": connection_id,
|
||||
"message": "Connected to text-to-speech"
|
||||
}, websocket)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_text()
|
||||
message = json.loads(data)
|
||||
|
||||
if message["type"] == "text_to_speak":
|
||||
try:
|
||||
text = message["text"]
|
||||
|
||||
# Process text-to-speech
|
||||
# This would integrate with Google Cloud Text-to-Speech API
|
||||
|
||||
# For now, send a placeholder response
|
||||
await manager.send_personal_message({
|
||||
"type": "audio_data",
|
||||
"audio": "base64_encoded_audio_here", # Placeholder
|
||||
"format": "mp3"
|
||||
}, websocket)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing text-to-speech: {e}")
|
||||
await manager.send_personal_message({
|
||||
"type": "error",
|
||||
"error": f"Failed to process text: {str(e)}"
|
||||
}, websocket)
|
||||
|
||||
elif message["type"] == "ping":
|
||||
await manager.send_personal_message({
|
||||
"type": "pong",
|
||||
"timestamp": message.get("timestamp")
|
||||
}, websocket)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, connection_id)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error: {e}")
|
||||
manager.disconnect(websocket, connection_id)
|
||||
|
|
@ -3,33 +3,30 @@ Workflow routes for the backend API.
|
|||
Implements the endpoints for workflow management according to the state machine.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Query, Response, status, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Import auth modules
|
||||
from modules.security.auth import limiter, getCurrentUser
|
||||
|
||||
# Import interfaces
|
||||
import modules.interfaces.interfaceChatObjects as interfaceChatObjects
|
||||
from modules.interfaces.interfaceChatObjects import getInterface
|
||||
import modules.interfaces.interfaceDbChatObjects as interfaceDbChatObjects
|
||||
from modules.interfaces.interfaceDbChatObjects import getInterface
|
||||
|
||||
# Import models
|
||||
from modules.interfaces.interfaceChatModel import (
|
||||
from modules.datamodels.datamodelChat import (
|
||||
ChatWorkflow,
|
||||
ChatMessage,
|
||||
ChatLog,
|
||||
ChatStat,
|
||||
ChatDocument,
|
||||
UserInputRequest
|
||||
ChatDocument
|
||||
)
|
||||
from modules.shared.attributeUtils import getModelAttributeDefinitions, AttributeResponse
|
||||
from modules.interfaces.interfaceAppModel import User
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -45,21 +42,54 @@ router = APIRouter(
|
|||
)
|
||||
|
||||
def getServiceChat(currentUser: User):
|
||||
return interfaceChatObjects.getInterface(currentUser)
|
||||
return interfaceDbChatObjects.getInterface(currentUser)
|
||||
|
||||
# Consolidated endpoint for getting all workflows
|
||||
@router.get("/", response_model=List[ChatWorkflow])
|
||||
@router.get("/", response_model=PaginatedResponse[ChatWorkflow])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflows(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[ChatWorkflow]:
|
||||
"""Get all workflows for the current user."""
|
||||
) -> PaginatedResponse[ChatWorkflow]:
|
||||
"""
|
||||
Get workflows with optional pagination, sorting, and filtering.
|
||||
|
||||
Query Parameters:
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
|
||||
Examples:
|
||||
- GET /api/workflows/ (no pagination - returns all workflows)
|
||||
- GET /api/workflows/?pagination={"page":1,"pageSize":10,"sort":[]}
|
||||
"""
|
||||
try:
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
appInterface = getInterface(currentUser)
|
||||
workflows_data = appInterface.getWorkflows()
|
||||
result = appInterface.getWorkflows(pagination=paginationParams)
|
||||
|
||||
# Convert raw dictionaries to ChatWorkflow objects by loading each workflow properly
|
||||
# If pagination was requested, result is PaginatedResult with items as dicts
|
||||
# If no pagination, result is List[Dict]
|
||||
if paginationParams:
|
||||
workflows_data = result.items
|
||||
totalItems = result.totalItems
|
||||
totalPages = result.totalPages
|
||||
else:
|
||||
workflows_data = result
|
||||
totalItems = len(result)
|
||||
totalPages = 1
|
||||
|
||||
workflows = []
|
||||
for workflow_data in workflows_data:
|
||||
try:
|
||||
|
|
@ -72,7 +102,25 @@ async def get_workflows(
|
|||
# Skip invalid workflows instead of failing the entire request
|
||||
continue
|
||||
|
||||
return workflows
|
||||
if paginationParams:
|
||||
return PaginatedResponse(
|
||||
items=workflows,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=workflows,
|
||||
pagination=None
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workflows: {str(e)}")
|
||||
raise HTTPException(
|
||||
|
|
@ -169,10 +217,10 @@ async def get_workflow_status(
|
|||
"""Get the current status of a workflow."""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
interfaceDbChat = getServiceChat(currentUser)
|
||||
|
||||
# Retrieve workflow
|
||||
workflow = interfaceChat.getWorkflow(workflowId)
|
||||
workflow = interfaceDbChat.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -190,39 +238,84 @@ async def get_workflow_status(
|
|||
)
|
||||
|
||||
# API Endpoint for workflow logs with selective data transfer
|
||||
@router.get("/{workflowId}/logs", response_model=List[ChatLog])
|
||||
@router.get("/{workflowId}/logs", response_model=PaginatedResponse[ChatLog])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow_logs(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
logId: Optional[str] = Query(None, description="Optional log ID to get only newer logs"),
|
||||
logId: Optional[str] = Query(None, description="Optional log ID to get only newer logs (legacy selective data transfer)"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[ChatLog]:
|
||||
"""Get logs for a workflow with support for selective data transfer."""
|
||||
) -> PaginatedResponse[ChatLog]:
|
||||
"""
|
||||
Get logs for a workflow with optional pagination, sorting, and filtering.
|
||||
Also supports legacy selective data transfer via logId parameter.
|
||||
|
||||
Query Parameters:
|
||||
- logId: Optional log ID for selective data transfer (returns only logs after this ID)
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
"""
|
||||
try:
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
interfaceDbChat = getServiceChat(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceChat.getWorkflow(workflowId)
|
||||
workflow = interfaceDbChat.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Workflow with ID {workflowId} not found"
|
||||
)
|
||||
|
||||
# Get all logs
|
||||
allLogs = interfaceChat.getLogs(workflowId)
|
||||
# Get logs with optional pagination
|
||||
result = interfaceDbChat.getLogs(workflowId, pagination=paginationParams)
|
||||
|
||||
# Apply selective data transfer if logId is provided
|
||||
# Handle legacy selective data transfer if logId is provided (takes precedence over pagination)
|
||||
if logId:
|
||||
# If pagination was requested, result is PaginatedResult, otherwise List[ChatLog]
|
||||
allLogs = result.items if paginationParams else result
|
||||
|
||||
# Find the index of the log with the given ID
|
||||
logIndex = next((i for i, log in enumerate(allLogs) if log.id == logId), -1)
|
||||
if logIndex >= 0:
|
||||
# Return only logs after the specified log
|
||||
return allLogs[logIndex + 1:]
|
||||
filteredLogs = allLogs[logIndex + 1:]
|
||||
return PaginatedResponse(
|
||||
items=filteredLogs,
|
||||
pagination=None
|
||||
)
|
||||
|
||||
return allLogs
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
# If no pagination, result is List[ChatLog]
|
||||
if paginationParams:
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=result,
|
||||
pagination=None
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -233,40 +326,84 @@ async def get_workflow_logs(
|
|||
)
|
||||
|
||||
# API Endpoint for workflow messages with selective data transfer
|
||||
@router.get("/{workflowId}/messages", response_model=List[ChatMessage])
|
||||
@router.get("/{workflowId}/messages", response_model=PaginatedResponse[ChatMessage])
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow_messages(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
messageId: Optional[str] = Query(None, description="Optional message ID to get only newer messages"),
|
||||
messageId: Optional[str] = Query(None, description="Optional message ID to get only newer messages (legacy selective data transfer)"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[ChatMessage]:
|
||||
"""Get messages for a workflow with support for selective data transfer."""
|
||||
) -> PaginatedResponse[ChatMessage]:
|
||||
"""
|
||||
Get messages for a workflow with optional pagination, sorting, and filtering.
|
||||
Also supports legacy selective data transfer via messageId parameter.
|
||||
|
||||
Query Parameters:
|
||||
- messageId: Optional message ID for selective data transfer (returns only messages after this ID)
|
||||
- pagination: JSON-encoded PaginationParams object, or None for no pagination
|
||||
"""
|
||||
try:
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid pagination parameter: {str(e)}"
|
||||
)
|
||||
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
interfaceDbChat = getServiceChat(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceChat.getWorkflow(workflowId)
|
||||
workflow = interfaceDbChat.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Workflow with ID {workflowId} not found"
|
||||
)
|
||||
|
||||
# Get all messages
|
||||
allMessages = interfaceChat.getMessages(workflowId)
|
||||
# Get messages with optional pagination
|
||||
result = interfaceDbChat.getMessages(workflowId, pagination=paginationParams)
|
||||
|
||||
# Apply selective data transfer if messageId is provided
|
||||
# Handle legacy selective data transfer if messageId is provided (takes precedence over pagination)
|
||||
if messageId:
|
||||
# If pagination was requested, result is PaginatedResult, otherwise List[ChatMessage]
|
||||
allMessages = result.items if paginationParams else result
|
||||
|
||||
# Find the index of the message with the given ID
|
||||
messageIndex = next((i for i, msg in enumerate(allMessages) if msg.id == messageId), -1)
|
||||
if messageIndex >= 0:
|
||||
# Return only messages after the specified message
|
||||
filteredMessages = allMessages[messageIndex + 1:]
|
||||
return filteredMessages
|
||||
return PaginatedResponse(
|
||||
items=filteredMessages,
|
||||
pagination=None
|
||||
)
|
||||
|
||||
return allMessages
|
||||
# If pagination was requested, result is PaginatedResult
|
||||
# If no pagination, result is List[ChatMessage]
|
||||
if paginationParams:
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
return PaginatedResponse(
|
||||
items=result,
|
||||
pagination=None
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -276,59 +413,6 @@ async def get_workflow_messages(
|
|||
detail=f"Error getting workflow messages: {str(e)}"
|
||||
)
|
||||
|
||||
# State 1: Workflow Initialization endpoint
|
||||
@router.post("/start", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def start_workflow(
|
||||
request: Request,
|
||||
workflowId: Optional[str] = Query(None, description="Optional ID of the workflow to continue"),
|
||||
userInput: UserInputRequest = Body(...),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""
|
||||
Starts a new workflow or continues an existing one.
|
||||
Corresponds to State 1 in the state machine documentation.
|
||||
"""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
|
||||
# Start or continue workflow using ChatObjects
|
||||
workflow = await interfaceChat.workflowStart(currentUser, userInput, workflowId)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in start_workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# State 8: Workflow Stopped endpoint
|
||||
@router.post("/{workflowId}/stop", response_model=ChatWorkflow)
|
||||
@limiter.limit("120/minute")
|
||||
async def stop_workflow(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow to stop"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> ChatWorkflow:
|
||||
"""Stops a running workflow."""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
|
||||
# Stop workflow using ChatObjects
|
||||
workflow = await interfaceChat.workflowStop(workflowId)
|
||||
|
||||
return workflow
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stop_workflow: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
# State 11: Workflow Reset/Deletion endpoint
|
||||
@router.delete("/{workflowId}", response_model=Dict[str, Any])
|
||||
|
|
@ -341,10 +425,10 @@ async def delete_workflow(
|
|||
"""Deletes a workflow and its associated data."""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
interfaceDbChat = getServiceChat(currentUser)
|
||||
|
||||
# Get raw workflow data from database to check permissions
|
||||
workflows = interfaceChat.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
workflows = interfaceDbChat.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
|
||||
if not workflows:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -354,14 +438,14 @@ async def delete_workflow(
|
|||
workflow_data = workflows[0]
|
||||
|
||||
# Check if user has permission to delete using the interface's permission system
|
||||
if not interfaceChat._canModify("workflows", workflowId):
|
||||
if not interfaceDbChat._canModify("workflows", workflowId):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You don't have permission to delete this workflow"
|
||||
)
|
||||
|
||||
# Delete workflow
|
||||
success = interfaceChat.deleteWorkflow(workflowId)
|
||||
success = interfaceDbChat.deleteWorkflow(workflowId)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
|
@ -383,45 +467,6 @@ async def delete_workflow(
|
|||
)
|
||||
|
||||
|
||||
# Unified Chat Data Endpoint for Polling
|
||||
@router.get("/{workflowId}/chatData")
|
||||
@limiter.limit("120/minute")
|
||||
async def get_workflow_chat_data(
|
||||
request: Request,
|
||||
workflowId: str = Path(..., description="ID of the workflow"),
|
||||
afterTimestamp: Optional[float] = Query(None, description="Unix timestamp to get data after"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get unified chat data (messages, logs, stats) for a workflow with timestamp-based selective data transfer.
|
||||
Returns all data types in chronological order based on _createdAt timestamp.
|
||||
"""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceChat.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Workflow with ID {workflowId} not found"
|
||||
)
|
||||
|
||||
# Get unified chat data using the new method
|
||||
chatData = interfaceChat.getUnifiedChatData(workflowId, afterTimestamp)
|
||||
|
||||
return chatData
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unified chat data: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error getting unified chat data: {str(e)}"
|
||||
)
|
||||
|
||||
# Document Management Endpoints
|
||||
|
||||
@router.delete("/{workflowId}/messages/{messageId}", response_model=Dict[str, Any])
|
||||
|
|
@ -435,10 +480,10 @@ async def delete_workflow_message(
|
|||
"""Delete a message from a workflow."""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
interfaceDbChat = getServiceChat(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceChat.getWorkflow(workflowId)
|
||||
workflow = interfaceDbChat.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -446,7 +491,7 @@ async def delete_workflow_message(
|
|||
)
|
||||
|
||||
# Delete the message
|
||||
success = interfaceChat.deleteMessage(workflowId, messageId)
|
||||
success = interfaceDbChat.deleteMessage(workflowId, messageId)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
|
@ -458,7 +503,7 @@ async def delete_workflow_message(
|
|||
messageIds = workflow.get("messageIds", [])
|
||||
if messageId in messageIds:
|
||||
messageIds.remove(messageId)
|
||||
interfaceChat.updateWorkflow(workflowId, {"messageIds": messageIds})
|
||||
interfaceDbChat.updateWorkflow(workflowId, {"messageIds": messageIds})
|
||||
|
||||
return {
|
||||
"workflowId": workflowId,
|
||||
|
|
@ -486,10 +531,10 @@ async def delete_file_from_message(
|
|||
"""Delete a file reference from a message in a workflow."""
|
||||
try:
|
||||
# Get service center
|
||||
interfaceChat = getServiceChat(currentUser)
|
||||
interfaceDbChat = getServiceChat(currentUser)
|
||||
|
||||
# Verify workflow exists
|
||||
workflow = interfaceChat.getWorkflow(workflowId)
|
||||
workflow = interfaceDbChat.getWorkflow(workflowId)
|
||||
if not workflow:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
|
@ -497,7 +542,7 @@ async def delete_file_from_message(
|
|||
)
|
||||
|
||||
# Delete file reference from message
|
||||
success = interfaceChat.deleteFileFromMessage(workflowId, messageId, fileId)
|
||||
success = interfaceDbChat.deleteFileFromMessage(workflowId, messageId, fileId)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -3,29 +3,48 @@ Authentication module for backend API.
|
|||
Handles JWT-based authentication, token generation, and user context.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi import Depends, HTTPException, status, Request, Response
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from jose import JWTError, jwt
|
||||
import logging
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp
|
||||
from modules.interfaces.interfaceAppObjects import getRootInterface
|
||||
from modules.interfaces.interfaceAppModel import User, AuthAuthority, Token
|
||||
from modules.interfaces.interfaceDbAppObjects import getRootInterface
|
||||
from modules.datamodels.datamodelUam import User, AuthAuthority
|
||||
from modules.datamodels.datamodelSecurity import Token
|
||||
|
||||
# Get Config Data
|
||||
SECRET_KEY = APP_CONFIG.get("APP_JWT_SECRET_SECRET")
|
||||
SECRET_KEY = APP_CONFIG.get("APP_JWT_KEY_SECRET")
|
||||
ALGORITHM = APP_CONFIG.get("Auth_ALGORITHM")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(APP_CONFIG.get("APP_TOKEN_EXPIRY"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = int(APP_CONFIG.get("APP_REFRESH_TOKEN_EXPIRY", "7"))
|
||||
|
||||
# OAuth2 Setup
|
||||
oauth2Scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
# Cookie-based Authentication Setup
|
||||
class CookieAuth(HTTPBearer):
|
||||
"""Cookie-based authentication that checks httpOnly cookies first, then Authorization header"""
|
||||
def __init__(self, auto_error: bool = True):
|
||||
super().__init__(auto_error=auto_error)
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
# 1. Check httpOnly cookie first (preferred method)
|
||||
token = request.cookies.get('auth_token')
|
||||
if token:
|
||||
return token
|
||||
|
||||
# 2. Fallback to Authorization header for API calls
|
||||
authorization = request.headers.get("Authorization")
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.split(" ")[1]
|
||||
|
||||
if self.auto_error:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return None
|
||||
|
||||
# Initialize cookie-based auth
|
||||
cookieAuth = CookieAuth(auto_error=False)
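# --- Hedged client sketch (assumption, not part of this module): authenticated calls carry
# --- the httpOnly cookie set at login, plus an X-CSRF-Token header for state-changing
# --- requests. Base URL, endpoint path, cookie value, and token value are placeholders.
import httpx

session = httpx.Client(base_url="http://localhost:8000",
                       cookies={"auth_token": "<jwt-from-login>"})
session.post("/api/some-protected-endpoint",
             json={"message": "Hello"},
             headers={"X-CSRF-Token": "0123456789abcdef0123456789abcdef"})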
|
||||
|
||||
# Rate Limiter
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
|
@ -33,33 +52,9 @@ limiter = Limiter(key_func=get_remote_address)
|
|||
# Logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def createAccessToken(data: dict, expiresDelta: Optional[timedelta] = None) -> Tuple[str, datetime]:
|
||||
"""
|
||||
Creates a JWT Access Token.
|
||||
|
||||
Args:
|
||||
data: Data to encode (usually user ID or username)
|
||||
expiresDelta: Validity duration of the token (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (JWT Token as string, expiration datetime)
|
||||
"""
|
||||
toEncode = data.copy()
|
||||
# Ensure a token id (jti) exists for revocation tracking (only required for local, harmless otherwise)
|
||||
if "jti" not in toEncode or not toEncode.get("jti"):
|
||||
toEncode["jti"] = str(uuid.uuid4())
|
||||
|
||||
if expiresDelta:
|
||||
expire = get_utc_now() + expiresDelta
|
||||
else:
|
||||
expire = get_utc_now() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
toEncode.update({"exp": expire})
|
||||
encodedJwt = jwt.encode(toEncode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
return encodedJwt, expire
|
||||
# Note: JWT creation and cookie helpers moved to modules.security.jwtService
|
||||
|
||||
def _getUserBase(token: str = Depends(oauth2Scheme)) -> User:
|
||||
def _getUserBase(token: str = Depends(cookieAuth)) -> User:
|
||||
"""
|
||||
Extracts and validates the current user from the JWT token.
|
||||
|
||||
|
|
@ -78,6 +73,19 @@ def _getUserBase(token: str = Depends(oauth2Scheme)) -> User:
|
|||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# Guard: token may be None or malformed when cookie/header is missing or bad
|
||||
if not token or not isinstance(token, str):
|
||||
logger.warning("Missing JWT Token (no cookie/header)")
|
||||
raise credentialsException
|
||||
# Basic JWT format check (header.payload.signature)
|
||||
try:
|
||||
if token.count(".") != 2:
|
||||
logger.warning("Malformed JWT token format")
|
||||
raise credentialsException
|
||||
except Exception:
|
||||
# If anything odd happens while checking format, treat as invalid creds
|
||||
raise credentialsException
|
||||
|
||||
try:
|
||||
# Decode token
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
|
|
@ -138,7 +146,14 @@ def _getUserBase(token: str = Depends(oauth2Scheme)) -> User:
|
|||
db_tokens = appInterface.db.getRecordset(
|
||||
Token, recordFilter={"id": tokenId}
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Check if this is a table not found error (token table was deleted)
|
||||
if "does not exist" in str(e).lower() or "relation" in str(e).lower():
|
||||
logger.error("Token table does not exist - database may have been reset")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Authentication service temporarily unavailable. Please contact administrator."
|
||||
)
|
||||
db_tokens = []
|
||||
|
||||
if db_tokens:
|
||||
|
|
|
|||
98
modules/security/csrf.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
"""
|
||||
CSRF Protection Middleware for PowerOn Gateway
|
||||
|
||||
This module provides CSRF protection for state-changing operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import Request, HTTPException, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
CSRF protection middleware that validates CSRF tokens for state-changing operations.
|
||||
"""
|
||||
|
||||
def __init__(self, app, exempt_paths: Set[str] = None):
|
||||
super().__init__(app)
|
||||
# Paths that are exempt from CSRF protection
|
||||
self.exempt_paths = exempt_paths or {
|
||||
"/api/local/login",
|
||||
"/api/local/register",
|
||||
"/api/msft/login",
|
||||
"/api/google/login",
|
||||
"/api/msft/callback",
|
||||
"/api/google/callback"
|
||||
}
|
||||
|
||||
# State-changing HTTP methods that require CSRF protection
|
||||
self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
Check CSRF token for state-changing operations.
|
||||
"""
|
||||
# Skip CSRF check for exempt paths
|
||||
if request.url.path in self.exempt_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip CSRF check for non-state-changing methods
|
||||
if request.method not in self.protected_methods:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip CSRF check for OPTIONS requests (CORS preflight)
|
||||
if request.method == "OPTIONS":
|
||||
return await call_next(request)
|
||||
|
||||
# Get CSRF token from header
|
||||
csrf_token = request.headers.get("X-CSRF-Token")
|
||||
if not csrf_token:
|
||||
logger.warning(f"CSRF token missing for {request.method} {request.url.path}")
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={"detail": "CSRF token missing"}
|
||||
)
|
||||
|
||||
# Validate CSRF token format (basic validation)
|
||||
if not self._is_valid_csrf_token(csrf_token):
|
||||
logger.warning(f"Invalid CSRF token format for {request.method} {request.url.path}")
|
||||
from fastapi.responses import JSONResponse
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={"detail": "Invalid CSRF token format"}
|
||||
)
|
||||
|
||||
# Additional CSRF validation could be added here:
|
||||
# - Check token against session
|
||||
# - Validate token expiration
|
||||
# - Verify token origin
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
def _is_valid_csrf_token(self, token: str) -> bool:
|
||||
"""
|
||||
Basic validation of CSRF token format.
|
||||
|
||||
Args:
|
||||
token: The CSRF token to validate
|
||||
|
||||
Returns:
|
||||
bool: True if token format is valid
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False
|
||||
|
||||
# Basic format validation (hex string, reasonable length)
|
||||
if len(token) < 16 or len(token) > 64:
|
||||
return False
|
||||
|
||||
# Check if token contains only valid hex characters
|
||||
try:
|
||||
int(token, 16)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
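# --- Hedged wiring sketch (assumption, not part of this file): registering the middleware on
# --- the FastAPI app with the default exempt paths defined above.
from fastapi import FastAPI

app = FastAPI()
app.add_middleware(CSRFMiddleware)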
|
||||
115
modules/security/jwtService.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""
|
||||
JWT Service
|
||||
Centralizes local JWT creation and cookie helpers.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Tuple
|
||||
from fastapi import Response
|
||||
from jose import jwt
|
||||
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.timezoneUtils import getUtcNow
|
||||
|
||||
# Config
|
||||
SECRET_KEY = APP_CONFIG.get("APP_JWT_KEY_SECRET")
|
||||
ALGORITHM = APP_CONFIG.get("Auth_ALGORITHM")
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(APP_CONFIG.get("APP_TOKEN_EXPIRY"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = int(APP_CONFIG.get("APP_REFRESH_TOKEN_EXPIRY", "7"))
|
||||
|
||||
# Cookie security settings - use secure cookies based on whether API uses HTTPS
|
||||
# Cookies must have secure=True on HTTPS sites, secure=False on HTTP sites
|
||||
APP_API_URL = APP_CONFIG.get("APP_API_URL", "http://localhost:8000")
|
||||
USE_SECURE_COOKIES = APP_API_URL.startswith("https://") if APP_API_URL else False
|
||||
|
||||
|
||||
def createAccessToken(data: dict, expiresDelta: Optional[timedelta] = None) -> Tuple[str, datetime]:
|
||||
"""Create a JWT access token and return (token, expiresAt)."""
|
||||
toEncode = data.copy()
|
||||
if "jti" not in toEncode or not toEncode.get("jti"):
|
||||
import uuid
|
||||
toEncode["jti"] = str(uuid.uuid4())
|
||||
|
||||
expire = getUtcNow() + (expiresDelta if expiresDelta else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
toEncode.update({"exp": expire})
|
||||
encodedJwt = jwt.encode(toEncode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encodedJwt, expire
|
||||
|
||||
|
||||
def createRefreshToken(data: dict) -> Tuple[str, datetime]:
|
||||
"""Create a JWT refresh token and return (token, expiresAt)."""
|
||||
toEncode = data.copy()
|
||||
if "jti" not in toEncode or not toEncode.get("jti"):
|
||||
import uuid
|
||||
toEncode["jti"] = str(uuid.uuid4())
|
||||
toEncode["type"] = "refresh"
|
||||
|
||||
expire = getUtcNow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
toEncode.update({"exp": expire})
|
||||
encodedJwt = jwt.encode(toEncode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encodedJwt, expire
|
||||
|
||||
|
||||
def setAccessTokenCookie(response: Response, token: str, expiresDelta: Optional[timedelta] = None) -> None:
|
||||
"""Set access token as httpOnly cookie."""
|
||||
maxAge = int(expiresDelta.total_seconds()) if expiresDelta else ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
response.set_cookie(
|
||||
key="auth_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=USE_SECURE_COOKIES, # Only secure in production (HTTPS)
|
||||
samesite="strict",
|
||||
path="/",
|
||||
max_age=maxAge
|
||||
)
|
||||
|
||||
|
||||
def setRefreshTokenCookie(response: Response, token: str) -> None:
|
||||
"""Set refresh token as httpOnly cookie."""
|
||||
response.set_cookie(
|
||||
key="refresh_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=USE_SECURE_COOKIES, # Only secure in production (HTTPS)
|
||||
samesite="strict",
|
||||
path="/",
|
||||
max_age=REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||
)
|
||||
|
||||
|
||||
def clearAccessTokenCookie(response: Response) -> None:
|
||||
"""
|
||||
Clear access token cookie by setting it to expire immediately.
|
||||
Uses both raw header manipulation and FastAPI's delete_cookie for maximum browser compatibility.
|
||||
"""
|
||||
# Build secure flag based on environment
|
||||
secure_flag = "; Secure" if USE_SECURE_COOKIES else ""
|
||||
|
||||
# Primary method: Raw Set-Cookie header for guaranteed deletion
|
||||
response.headers.append(
|
||||
"Set-Cookie",
|
||||
f"auth_token=deleted; Path=/; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly{secure_flag}; SameSite=Strict"
|
||||
)
|
||||
|
||||
# Fallback: Also use FastAPI's built-in method
|
||||
response.delete_cookie(key="auth_token", path="/")
|
||||
|
||||
|
||||
def clearRefreshTokenCookie(response: Response) -> None:
|
||||
"""
|
||||
Clear refresh token cookie by setting it to expire immediately.
|
||||
Uses both raw header manipulation and FastAPI's delete_cookie for maximum browser compatibility.
|
||||
"""
|
||||
# Build secure flag based on environment
|
||||
secure_flag = "; Secure" if USE_SECURE_COOKIES else ""
|
||||
|
||||
# Primary method: Raw Set-Cookie header for guaranteed deletion
|
||||
response.headers.append(
|
||||
"Set-Cookie",
|
||||
f"refresh_token=deleted; Path=/; Max-Age=0; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly{secure_flag}; SameSite=Strict"
|
||||
)
|
||||
|
||||
# Fallback: Also use FastAPI's built-in method
|
||||
response.delete_cookie(key="refresh_token", path="/")
|
||||
|
||||
|
||||
|
|
@ -5,12 +5,12 @@ Handles all token operations including automatic refresh for backend services.
|
|||
|
||||
import logging
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, Callable
|
||||
|
||||
from modules.interfaces.interfaceAppModel import Token, AuthAuthority
|
||||
from modules.datamodels.datamodelSecurity import Token
|
||||
from modules.datamodels.datamodelUam import AuthAuthority
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp, create_expiration_timestamp
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp, createExpirationTimestamp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -27,54 +27,54 @@ class TokenManager:
|
|||
self.google_client_id = APP_CONFIG.get("Service_GOOGLE_CLIENT_ID")
|
||||
self.google_client_secret = APP_CONFIG.get("Service_GOOGLE_CLIENT_SECRET")
|
||||
|
||||
def refresh_microsoft_token(self, refresh_token: str, user_id: str, old_token: Token) -> Optional[Token]:
|
||||
def refreshMicrosoftToken(self, refreshToken: str, userId: str, oldToken: Token) -> Optional[Token]:
|
||||
"""Refresh Microsoft OAuth token using refresh token"""
|
||||
try:
|
||||
logger.debug(f"refresh_microsoft_token: Starting Microsoft token refresh for user {user_id}")
|
||||
logger.debug(f"refresh_microsoft_token: Configuration check - client_id: {bool(self.msft_client_id)}, client_secret: {bool(self.msft_client_secret)}")
|
||||
logger.debug(f"refreshMicrosoftToken: Starting Microsoft token refresh for user {userId}")
|
||||
logger.debug(f"refreshMicrosoftToken: Configuration check - client_id: {bool(self.msft_client_id)}, client_secret: {bool(self.msft_client_secret)}")
|
||||
|
||||
if not self.msft_client_id or not self.msft_client_secret:
|
||||
logger.error("Microsoft OAuth configuration not found")
|
||||
return None
|
||||
|
||||
# Microsoft token refresh endpoint
|
||||
token_url = f"https://login.microsoftonline.com/{self.msft_tenant_id}/oauth2/v2.0/token"
|
||||
logger.debug(f"refresh_microsoft_token: Using token URL: {token_url}")
|
||||
tokenUrl = f"https://login.microsoftonline.com/{self.msft_tenant_id}/oauth2/v2.0/token"
|
||||
logger.debug(f"refreshMicrosoftToken: Using token URL: {tokenUrl}")
|
||||
|
||||
# Prepare refresh request
|
||||
data = {
|
||||
"client_id": self.msft_client_id,
|
||||
"client_secret": self.msft_client_secret,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"refresh_token": refreshToken,
|
||||
"scope": "Mail.ReadWrite Mail.Send Mail.ReadWrite.Shared User.Read"
|
||||
}
|
||||
logger.debug(f"refresh_microsoft_token: Refresh request data prepared (refresh_token length: {len(refresh_token) if refresh_token else 0})")
|
||||
logger.debug(f"refreshMicrosoftToken: Refresh request data prepared (refreshToken length: {len(refreshToken) if refreshToken else 0})")
|
||||
|
||||
# Make refresh request
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
logger.debug(f"refresh_microsoft_token: Making HTTP request to Microsoft OAuth endpoint")
|
||||
response = client.post(token_url, data=data)
|
||||
logger.debug(f"refresh_microsoft_token: HTTP response status: {response.status_code}")
|
||||
logger.debug(f"refreshMicrosoftToken: Making HTTP request to Microsoft OAuth endpoint")
|
||||
response = client.post(tokenUrl, data=data)
|
||||
logger.debug(f"refreshMicrosoftToken: HTTP response status: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
token_data = response.json()
|
||||
logger.debug(f"refresh_microsoft_token: Token refresh successful, creating new token")
|
||||
tokenData = response.json()
|
||||
logger.debug(f"refreshMicrosoftToken: Token refresh successful, creating new token")
|
||||
|
||||
# Create new token
|
||||
new_token = Token(
|
||||
userId=user_id,
|
||||
newToken = Token(
|
||||
userId=userId,
|
||||
authority=AuthAuthority.MSFT,
|
||||
connectionId=old_token.connectionId, # Preserve connection ID
|
||||
tokenAccess=token_data["access_token"],
|
||||
tokenRefresh=token_data.get("refresh_token", refresh_token), # Keep old refresh token if new one not provided
|
||||
tokenType=token_data.get("token_type", "bearer"),
|
||||
expiresAt=create_expiration_timestamp(token_data.get("expires_in", 3600)),
|
||||
createdAt=get_utc_timestamp()
|
||||
connectionId=oldToken.connectionId, # Preserve connection ID
|
||||
tokenAccess=tokenData["access_token"],
|
||||
tokenRefresh=tokenData.get("refresh_token", refreshToken), # Keep old refresh token if new one not provided
|
||||
tokenType=tokenData.get("token_type", "bearer"),
|
||||
expiresAt=createExpirationTimestamp(tokenData.get("expires_in", 3600)),
|
||||
createdAt=getUtcTimestamp()
|
||||
)
|
||||
|
||||
logger.debug(f"refresh_microsoft_token: New token created with ID: {new_token.id}")
|
||||
return new_token
|
||||
logger.debug(f"refreshMicrosoftToken: New token created with ID: {newToken.id}")
|
||||
return newToken
|
||||
else:
|
||||
logger.error(f"Failed to refresh Microsoft token: {response.status_code} - {response.text}")
|
||||
return None
|
||||
|
|
@ -83,70 +83,70 @@ class TokenManager:
|
|||
logger.error(f"Error refreshing Microsoft token: {str(e)}")
|
||||
return None
|
||||
|
||||
def refresh_google_token(self, refresh_token: str, user_id: str, old_token: Token) -> Optional[Token]:
|
||||
def refreshGoogleToken(self, refreshToken: str, userId: str, oldToken: Token) -> Optional[Token]:
|
||||
"""Refresh Google OAuth token using refresh token"""
|
||||
try:
|
||||
logger.debug(f"refresh_google_token: Starting Google token refresh for user {user_id}")
|
||||
logger.debug(f"refresh_google_token: Configuration check - client_id: {bool(self.google_client_id)}, client_secret: {bool(self.google_client_secret)}")
|
||||
logger.debug(f"refreshGoogleToken: Starting Google token refresh for user {userId}")
|
||||
logger.debug(f"refreshGoogleToken: Configuration check - client_id: {bool(self.google_client_id)}, client_secret: {bool(self.google_client_secret)}")
|
||||
|
||||
if not self.google_client_id or not self.google_client_secret:
|
||||
logger.error("Google OAuth configuration not found")
|
||||
return None
|
||||
|
||||
# Google token refresh endpoint
|
||||
token_url = "https://oauth2.googleapis.com/token"
|
||||
logger.debug(f"refresh_google_token: Using token URL: {token_url}")
|
||||
tokenUrl = "https://oauth2.googleapis.com/token"
|
||||
logger.debug(f"refreshGoogleToken: Using token URL: {tokenUrl}")
|
||||
|
||||
# Prepare refresh request
|
||||
data = {
|
||||
"client_id": self.google_client_id,
|
||||
"client_secret": self.google_client_secret,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token
|
||||
"refresh_token": refreshToken
|
||||
}
|
||||
logger.debug(f"refresh_google_token: Refresh request data prepared (refresh_token length: {len(refresh_token) if refresh_token else 0})")
|
||||
logger.debug(f"refreshGoogleToken: Refresh request data prepared (refreshToken length: {len(refreshToken) if refreshToken else 0})")
|
||||
|
||||
# Make refresh request
|
||||
with httpx.Client(timeout=30.0) as client:
|
||||
logger.debug(f"refresh_google_token: Making HTTP request to Google OAuth endpoint")
|
||||
response = client.post(token_url, data=data)
|
||||
logger.debug(f"refresh_google_token: HTTP response status: {response.status_code}")
|
||||
logger.debug(f"refreshGoogleToken: Making HTTP request to Google OAuth endpoint")
|
||||
response = client.post(tokenUrl, data=data)
|
||||
logger.debug(f"refreshGoogleToken: HTTP response status: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
token_data = response.json()
|
||||
logger.debug(f"refresh_google_token: Token refresh successful, creating new token")
|
||||
tokenData = response.json()
|
||||
logger.debug(f"refreshGoogleToken: Token refresh successful, creating new token")
|
||||
|
||||
# Validate the response contains required fields
|
||||
if "access_token" not in token_data:
|
||||
if "access_token" not in tokenData:
|
||||
logger.error("Google token refresh response missing access_token")
|
||||
return None
|
||||
|
||||
# Create new token
|
||||
new_token = Token(
|
||||
userId=user_id,
|
||||
newToken = Token(
|
||||
userId=userId,
|
||||
authority=AuthAuthority.GOOGLE,
|
||||
connectionId=old_token.connectionId, # Preserve connection ID
|
||||
tokenAccess=token_data["access_token"],
|
||||
tokenRefresh=token_data.get("refresh_token", refresh_token), # Use new refresh token if provided
|
||||
tokenType=token_data.get("token_type", "bearer"),
|
||||
expiresAt=create_expiration_timestamp(token_data.get("expires_in", 3600)),
|
||||
createdAt=get_utc_timestamp()
|
||||
connectionId=oldToken.connectionId, # Preserve connection ID
|
||||
tokenAccess=tokenData["access_token"],
|
||||
tokenRefresh=tokenData.get("refresh_token", refreshToken), # Use new refresh token if provided
|
||||
tokenType=tokenData.get("token_type", "bearer"),
|
||||
expiresAt=createExpirationTimestamp(tokenData.get("expires_in", 3600)),
|
||||
createdAt=getUtcTimestamp()
|
||||
)
|
||||
|
||||
logger.debug(f"refresh_google_token: New token created with ID: {new_token.id}")
|
||||
return new_token
|
||||
logger.debug(f"refreshGoogleToken: New token created with ID: {newToken.id}")
|
||||
return newToken
|
||||
else:
|
||||
error_details = response.text
|
||||
logger.error(f"Failed to refresh Google token: {response.status_code} - {error_details}")
|
||||
errorDetails = response.text
|
||||
logger.error(f"Failed to refresh Google token: {response.status_code} - {errorDetails}")
|
||||
|
||||
# Handle specific error cases
|
||||
if response.status_code == 400:
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_code = error_data.get("error")
|
||||
if error_code == "invalid_grant":
|
||||
errorData = response.json()
|
||||
errorCode = errorData.get("error")
|
||||
if errorCode == "invalid_grant":
|
||||
logger.warning("Google refresh token is invalid or expired - user needs to re-authenticate")
|
||||
elif error_code == "invalid_client":
|
||||
elif errorCode == "invalid_client":
|
||||
logger.error("Google OAuth client configuration is invalid")
|
||||
except:
|
||||
pass
|
||||
|
|
@ -157,28 +157,110 @@ class TokenManager:
|
|||
logger.error(f"Error refreshing Google token: {str(e)}")
|
||||
return None
|
||||
|
||||
def refresh_token(self, old_token: Token) -> Optional[Token]:
|
||||
def refreshToken(self, oldToken: Token) -> Optional[Token]:
|
||||
"""Refresh an expired token using the appropriate OAuth service"""
|
||||
try:
|
||||
logger.debug(f"refresh_token: Starting refresh for token {old_token.id}, authority: {old_token.authority}")
|
||||
logger.debug(f"refresh_token: Token details: userId={old_token.userId}, connectionId={old_token.connectionId}, hasRefreshToken={bool(old_token.tokenRefresh)}")
|
||||
logger.debug(f"refreshToken: Starting refresh for token {oldToken.id}, authority: {oldToken.authority}")
|
||||
logger.debug(f"refreshToken: Token details: userId={oldToken.userId}, connectionId={oldToken.connectionId}, hasRefreshToken={bool(oldToken.tokenRefresh)}")
|
||||
|
||||
if not old_token.tokenRefresh:
|
||||
logger.warning(f"No refresh token available for {old_token.authority}")
|
||||
# Cooldown: avoid refreshing too frequently if a workflow triggers refresh repeatedly
|
||||
# Only allow a new refresh if at least 10 minutes passed since the token was created/refreshed
|
||||
try:
|
||||
nowTs = getUtcTimestamp()
|
||||
createdTs = float(oldToken.createdAt) if oldToken.createdAt is not None else 0.0
|
||||
secondsSinceLastRefresh = nowTs - createdTs
|
||||
if secondsSinceLastRefresh < 10 * 60:
|
||||
logger.info(
|
||||
f"refreshToken: Skipping refresh for connection {oldToken.connectionId} due to cooldown. "
|
||||
f"Last refresh {int(secondsSinceLastRefresh)}s ago (< 600s)."
|
||||
)
|
||||
# Return the existing token to avoid caller errors while preventing provider rate limits
|
||||
return oldToken
|
||||
except Exception:
|
||||
# If any issue reading timestamps, proceed with normal refresh to be safe
|
||||
pass
|
||||
|
||||
if not oldToken.tokenRefresh:
|
||||
logger.warning(f"No refresh token available for {oldToken.authority}")
|
||||
return None
|
||||
|
||||
# Route to appropriate refresh method
|
||||
if old_token.authority == AuthAuthority.MSFT:
|
||||
logger.debug(f"refresh_token: Refreshing Microsoft token")
|
||||
return self.refresh_microsoft_token(old_token.tokenRefresh, old_token.userId, old_token)
|
||||
elif old_token.authority == AuthAuthority.GOOGLE:
|
||||
logger.debug(f"refresh_token: Refreshing Google token")
|
||||
return self.refresh_google_token(old_token.tokenRefresh, old_token.userId, old_token)
|
||||
if oldToken.authority == AuthAuthority.MSFT:
|
||||
logger.debug(f"refreshToken: Refreshing Microsoft token")
|
||||
return self.refreshMicrosoftToken(oldToken.tokenRefresh, oldToken.userId, oldToken)
|
||||
elif oldToken.authority == AuthAuthority.GOOGLE:
|
||||
logger.debug(f"refreshToken: Refreshing Google token")
|
||||
return self.refreshGoogleToken(oldToken.tokenRefresh, oldToken.userId, oldToken)
|
||||
else:
|
||||
logger.warning(f"Unknown authority for token refresh: {old_token.authority}")
|
||||
logger.warning(f"Unknown authority for token refresh: {oldToken.authority}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing token: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def ensureFreshToken(self, token: Token, *, secondsBeforeExpiry: int = 30 * 60, saveCallback: Optional[Callable[[Token], None]] = None) -> Optional[Token]:
|
||||
"""Ensure a token is fresh; refresh if expiring within threshold.
|
||||
|
||||
Args:
|
||||
token: Existing token to validate/refresh.
|
||||
secondsBeforeExpiry: Threshold window to proactively refresh.
|
||||
saveCallback: Optional function to persist a refreshed token.
|
||||
|
||||
Returns:
|
||||
A fresh token (refreshed or original) or None if refresh failed.
|
||||
"""
|
||||
try:
|
||||
if token is None:
|
||||
return None
|
||||
|
||||
nowTs = getUtcTimestamp()
|
||||
expiresAt = token.expiresAt or 0
|
||||
|
||||
# If token expires within the threshold, try to refresh
|
||||
if expiresAt and expiresAt < (nowTs + secondsBeforeExpiry):
|
||||
logger.info(
|
||||
f"ensureFreshToken: Token for connection {token.connectionId} expiring soon "
|
||||
f"(in {max(0, expiresAt - nowTs)}s). Attempting proactive refresh."
|
||||
)
|
||||
refreshed = self.refreshToken(token)
|
||||
if refreshed:
|
||||
if saveCallback is not None:
|
||||
try:
|
||||
saveCallback(refreshed)
|
||||
except Exception as e:
|
||||
logger.warning(f"ensureFreshToken: Failed to persist refreshed token: {e}")
|
||||
return refreshed
|
||||
else:
|
||||
logger.warning("ensureFreshToken: Token refresh failed")
|
||||
return None
|
||||
|
||||
# Token is sufficiently fresh
|
||||
return token
|
||||
except Exception as e:
|
||||
logger.error(f"ensureFreshToken: Error ensuring fresh token: {e}")
|
||||
return None
|
||||
|
||||
# Convenience wrapper to fetch and ensure fresh token for a connection via interface layer
|
||||
def getFreshToken(self, connectionId: str, secondsBeforeExpiry: int = 30 * 60) -> Optional[Token]:
|
||||
"""Return a fresh token for a connection, refreshing when expiring soon.
|
||||
|
||||
Reads the latest stored token via interface layer, then
|
||||
uses ensureFreshToken to refresh if needed and persists the refreshed
|
||||
token via interface layer.
|
||||
"""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbAppObjects import getRootInterface
|
||||
interfaceDbApp = getRootInterface()
|
||||
|
||||
token = interfaceDbApp.getConnectionToken(connectionId)
|
||||
if not token:
|
||||
return None
|
||||
return self.ensureFreshToken(
|
||||
token,
|
||||
secondsBeforeExpiry=secondsBeforeExpiry,
|
||||
saveCallback=lambda t: interfaceDbApp.saveConnectionToken(t)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"getFreshToken: Error fetching or refreshing token for connection {connectionId}: {e}")
|
||||
return None
|
||||
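A minimal usage sketch for the new helpers above (the connection id is a placeholder; the import path matches the one used by tokenRefreshService.py below):

    from modules.security.tokenManager import TokenManager

    tokenManager = TokenManager()
    # Refresh only if the stored token expires within the next 15 minutes
    token = tokenManager.getFreshToken("conn-123", secondsBeforeExpiry=15 * 60)
    if token is None:
        # Refresh failed or no token stored - caller should ask the user to re-authenticate
        print("re-auth required")

ensureFreshToken can be used instead when the caller already holds a Token object and wants to control persistence through saveCallback.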
185
modules/security/tokenRefreshMiddleware.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
"""
|
||||
Token Refresh Middleware for PowerOn Gateway
|
||||
|
||||
This middleware automatically refreshes expired OAuth tokens
|
||||
when API endpoints are accessed, providing a seamless user experience.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from typing import Callable, Optional
|
||||
import asyncio
|
||||
from modules.security.tokenRefreshService import token_refresh_service
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TokenRefreshMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that automatically refreshes expired OAuth tokens
|
||||
when API endpoints are accessed.
|
||||
"""
|
||||
|
||||
def __init__(self, app, enabled: bool = True):
|
||||
super().__init__(app)
|
||||
self.enabled = enabled
|
||||
self.refresh_endpoints = {
|
||||
'/api/connections',
|
||||
'/api/files',
|
||||
'/api/chat',
|
||||
'/api/msft',
|
||||
'/api/google'
|
||||
}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process request and refresh tokens if needed
|
||||
"""
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# Check if this is an endpoint that might need token refresh
|
||||
if not self._should_check_tokens(request):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract user ID from request (if available)
|
||||
user_id = self._extract_user_id(request)
|
||||
if not user_id:
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
# Perform silent token refresh in background
|
||||
# Don't wait for completion to avoid slowing down the request
|
||||
asyncio.create_task(self._silent_refresh_tokens(user_id))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error scheduling token refresh: {str(e)}")
|
||||
# Continue with request even if refresh scheduling fails
|
||||
|
||||
# Process the original request
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def _should_check_tokens(self, request: Request) -> bool:
|
||||
"""
|
||||
Check if this request should trigger token refresh
|
||||
"""
|
||||
path = request.url.path
|
||||
|
||||
# Only check specific API endpoints
|
||||
for endpoint in self.refresh_endpoints:
|
||||
if path.startswith(endpoint):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_user_id(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract user ID from request context
|
||||
"""
|
||||
try:
|
||||
# Try to get user from request state (set by auth middleware)
|
||||
if hasattr(request.state, 'user_id'):
|
||||
return request.state.user_id
|
||||
|
||||
# Try to get from JWT token in cookies or headers
|
||||
# This is a fallback if user state is not available
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract user ID: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _silent_refresh_tokens(self, user_id: str) -> None:
|
||||
"""
|
||||
Perform silent token refresh for the user
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Starting silent token refresh for user {user_id}")
|
||||
|
||||
# Refresh expired tokens
|
||||
result = await token_refresh_service.refresh_expired_tokens(user_id)
|
||||
|
||||
if result.get("refreshed", 0) > 0:
|
||||
logger.info(f"Silently refreshed {result['refreshed']} tokens for user {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in silent token refresh for user {user_id}: {str(e)}")
|
||||
|
||||
class ProactiveTokenRefreshMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware that proactively refreshes tokens before they expire
|
||||
"""
|
||||
|
||||
def __init__(self, app, enabled: bool = True, check_interval_minutes: int = 5):
|
||||
super().__init__(app)
|
||||
self.enabled = enabled
|
||||
self.check_interval_minutes = check_interval_minutes
|
||||
self.last_check = {}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process request and check for proactive refresh needs
|
||||
"""
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract user ID from request
|
||||
user_id = self._extract_user_id(request)
|
||||
if not user_id:
|
||||
return await call_next(request)
|
||||
|
||||
# Check if we need to do proactive refresh
|
||||
if self._should_check_proactive_refresh(user_id):
|
||||
try:
|
||||
# Perform proactive refresh in background
|
||||
asyncio.create_task(self._proactive_refresh_tokens(user_id))
|
||||
self.last_check[user_id] = getUtcTimestamp()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error scheduling proactive refresh: {str(e)}")
|
||||
|
||||
# Process the original request
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def _extract_user_id(self, request: Request) -> Optional[str]:
|
||||
"""
|
||||
Extract user ID from request context
|
||||
"""
|
||||
try:
|
||||
if hasattr(request.state, 'user_id'):
|
||||
return request.state.user_id
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _should_check_proactive_refresh(self, user_id: str) -> bool:
|
||||
"""
|
||||
Check if we should perform proactive refresh for this user
|
||||
"""
|
||||
try:
|
||||
current_time = getUtcTimestamp()
|
||||
last_check = self.last_check.get(user_id, 0)
|
||||
|
||||
# Check every 5 minutes
|
||||
return (current_time - last_check) > (self.check_interval_minutes * 60)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _proactive_refresh_tokens(self, user_id: str) -> None:
|
||||
"""
|
||||
Perform proactive token refresh for the user
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Starting proactive token refresh for user {user_id}")
|
||||
|
||||
result = await token_refresh_service.proactive_refresh(user_id)
|
||||
|
||||
if result.get("refreshed", 0) > 0:
|
||||
logger.info(f"Proactively refreshed {result['refreshed']} tokens for user {user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in proactive token refresh for user {user_id}: {str(e)}")
|
||||
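A hedged sketch of how these middlewares might be registered in app.py (the registration point is an assumption; it is not part of this diff):

    from fastapi import FastAPI
    from modules.security.tokenRefreshMiddleware import TokenRefreshMiddleware, ProactiveTokenRefreshMiddleware

    app = FastAPI()
    # Silent refresh of already-expired tokens when matching /api/* paths are hit
    app.add_middleware(TokenRefreshMiddleware, enabled=True)
    # Proactive refresh of tokens expiring within ~5 minutes, throttled per user
    app.add_middleware(ProactiveTokenRefreshMiddleware, enabled=True, check_interval_minutes=5)

Note that both middlewares currently read only request.state.user_id and return None otherwise, so they are effectively no-ops for requests the auth layer has not annotated.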
289
modules/security/tokenRefreshService.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
"""
|
||||
Token Refresh Service for PowerOn Gateway
|
||||
|
||||
This service handles automatic token refresh for OAuth connections
|
||||
when they are accessed via API calls. It runs silently in the background
|
||||
to ensure users don't experience token expiration issues.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from modules.datamodels.datamodelUam import UserConnection, AuthAuthority
|
||||
from modules.shared.timezoneUtils import getUtcTimestamp
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TokenRefreshService:
|
||||
"""Service for automatic token refresh operations"""
|
||||
|
||||
def __init__(self):
|
||||
self.rate_limit_map = {} # Track refresh attempts per connection
|
||||
self.max_attempts_per_hour = 3
|
||||
self.refresh_window_minutes = 60
|
||||
|
||||
def _is_rate_limited(self, connection_id: str) -> bool:
|
||||
"""Check if connection is rate limited for refresh attempts"""
|
||||
now = getUtcTimestamp()
|
||||
if connection_id not in self.rate_limit_map:
|
||||
return False
|
||||
|
||||
# Remove attempts older than 1 hour
|
||||
recent_attempts = [
|
||||
attempt_time for attempt_time in self.rate_limit_map[connection_id]
|
||||
if now - attempt_time < (self.refresh_window_minutes * 60)
|
||||
]
|
||||
self.rate_limit_map[connection_id] = recent_attempts
|
||||
|
||||
return len(recent_attempts) >= self.max_attempts_per_hour
|
||||
|
||||
def _record_refresh_attempt(self, connection_id: str) -> None:
|
||||
"""Record a refresh attempt for rate limiting"""
|
||||
now = getUtcTimestamp()
|
||||
if connection_id not in self.rate_limit_map:
|
||||
self.rate_limit_map[connection_id] = []
|
||||
self.rate_limit_map[connection_id].append(now)
|
||||
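# Illustrative only (not part of this change): with max_attempts_per_hour = 3, a fourth
# refresh attempt for the same connection inside a rolling 60-minute window is reported
# as rate limited; older attempts are dropped on the next check.
#
#   svc = TokenRefreshService()
#   for _ in range(3):
#       svc._record_refresh_attempt("conn-1")
#   svc._is_rate_limited("conn-1")   # True until the recorded attempts age past 60 minutes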
|
||||
async def _refresh_google_token(self, interface, connection: UserConnection) -> bool:
|
||||
"""Refresh Google OAuth token"""
|
||||
try:
|
||||
logger.debug(f"Refreshing Google token for connection {connection.id}")
|
||||
|
||||
# Get current token (no refresh in interface layer)
|
||||
current_token = interface.getConnectionToken(connection.id)
|
||||
if not current_token:
|
||||
logger.warning(f"No Google token found for connection {connection.id}")
|
||||
return False
|
||||
|
||||
# Import Google token refresh logic
|
||||
from modules.security.tokenManager import TokenManager
|
||||
token_manager = TokenManager()
|
||||
|
||||
# Attempt to refresh the token
|
||||
refreshedToken = token_manager.refreshToken(current_token)
|
||||
if refreshedToken:
|
||||
# Save the refreshed token
|
||||
interface.saveConnectionToken(refreshedToken)
|
||||
|
||||
# Update connection status
|
||||
interface.db.recordModify(UserConnection, connection.id, {
|
||||
"lastChecked": getUtcTimestamp(),
|
||||
"expiresAt": refreshed_token.expiresAt
|
||||
})
|
||||
|
||||
logger.info(f"Successfully refreshed Google token for connection {connection.id}")
|
||||
|
||||
# Log audit event
|
||||
try:
|
||||
audit_logger.logSecurityEvent(
|
||||
userId=str(connection.userId),
|
||||
mandateId="system",
|
||||
action="token_refresh",
|
||||
details=f"Google token refreshed for connection {connection.id}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to refresh Google token for connection {connection.id}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing Google token for connection {connection.id}: {str(e)}")
|
||||
return False
|
||||
|
||||
async def _refresh_microsoft_token(self, interface, connection: UserConnection) -> bool:
|
||||
"""Refresh Microsoft OAuth token"""
|
||||
try:
|
||||
logger.debug(f"Refreshing Microsoft token for connection {connection.id}")
|
||||
|
||||
# Get current token (no refresh in interface layer)
|
||||
current_token = interface.getConnectionToken(connection.id)
|
||||
if not current_token:
|
||||
logger.warning(f"No Microsoft token found for connection {connection.id}")
|
||||
return False
|
||||
|
||||
# Import Microsoft token refresh logic
|
||||
from modules.security.tokenManager import TokenManager
|
||||
token_manager = TokenManager()
|
||||
|
||||
# Attempt to refresh the token
|
||||
refreshedToken = token_manager.refreshToken(current_token)
|
||||
if refreshedToken:
|
||||
# Save the refreshed token
|
||||
interface.saveConnectionToken(refreshedToken)
|
||||
|
||||
# Update connection status
|
||||
interface.db.recordModify(UserConnection, connection.id, {
|
||||
"lastChecked": getUtcTimestamp(),
|
||||
"expiresAt": refreshed_token.expiresAt
|
||||
})
|
||||
|
||||
logger.info(f"Successfully refreshed Microsoft token for connection {connection.id}")
|
||||
|
||||
# Log audit event
|
||||
try:
|
||||
audit_logger.logSecurityEvent(
|
||||
userId=str(connection.userId),
|
||||
mandateId="system",
|
||||
action="token_refresh",
|
||||
details=f"Microsoft token refreshed for connection {connection.id}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to refresh Microsoft token for connection {connection.id}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing Microsoft token for connection {connection.id}: {str(e)}")
|
||||
return False
|
||||
|
||||
async def refresh_expired_tokens(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Refresh expired OAuth tokens for a user
|
||||
|
||||
Args:
|
||||
user_id: User ID to refresh tokens for
|
||||
|
||||
Returns:
|
||||
Dict with refresh results
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Starting silent token refresh for user {user_id}")
|
||||
|
||||
# Get user interface
|
||||
from modules.interfaces.interfaceDbAppObjects import getRootInterface
|
||||
root_interface = getRootInterface()
|
||||
|
||||
# Get user connections
|
||||
connections = root_interface.getUserConnections(user_id)
|
||||
if not connections:
|
||||
logger.debug(f"No connections found for user {user_id}")
|
||||
return {"refreshed": 0, "failed": 0, "rate_limited": 0}
|
||||
|
||||
refreshed_count = 0
|
||||
failed_count = 0
|
||||
rate_limited_count = 0
|
||||
|
||||
# Process each connection
|
||||
for connection in connections:
|
||||
# Only refresh expired OAuth connections
|
||||
if (connection.tokenStatus == 'expired' and
|
||||
connection.authority in [AuthAuthority.GOOGLE, AuthAuthority.MSFT]):
|
||||
|
||||
# Check rate limiting
|
||||
if self._is_rate_limited(connection.id):
|
||||
logger.warning(f"Rate limited for connection {connection.id}")
|
||||
rate_limited_count += 1
|
||||
continue
|
||||
|
||||
# Record attempt
|
||||
self._record_refresh_attempt(connection.id)
|
||||
|
||||
# Refresh based on authority
|
||||
success = False
|
||||
if connection.authority == AuthAuthority.GOOGLE:
|
||||
success = await self._refresh_google_token(root_interface, connection)
|
||||
elif connection.authority == AuthAuthority.MSFT:
|
||||
success = await self._refresh_microsoft_token(root_interface, connection)
|
||||
|
||||
if success:
|
||||
refreshed_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
result = {
|
||||
"refreshed": refreshed_count,
|
||||
"failed": failed_count,
|
||||
"rate_limited": rate_limited_count
|
||||
}
|
||||
|
||||
logger.info(f"Silent token refresh completed for user {user_id}: {result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during silent token refresh for user {user_id}: {str(e)}")
|
||||
return {"refreshed": 0, "failed": 0, "rate_limited": 0, "error": str(e)}
|
||||
|
||||
async def proactive_refresh(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Proactively refresh tokens that expire within 5 minutes
|
||||
|
||||
Args:
|
||||
user_id: User ID to check tokens for
|
||||
|
||||
Returns:
|
||||
Dict with refresh results
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"Starting proactive token refresh for user {user_id}")
|
||||
|
||||
# Get user interface
|
||||
from modules.interfaces.interfaceDbAppObjects import getRootInterface
|
||||
root_interface = getRootInterface()
|
||||
|
||||
# Get user connections
|
||||
connections = root_interface.getUserConnections(user_id)
|
||||
if not connections:
|
||||
return {"refreshed": 0, "failed": 0, "rate_limited": 0}
|
||||
|
||||
refreshed_count = 0
|
||||
failed_count = 0
|
||||
rate_limited_count = 0
|
||||
current_time = getUtcTimestamp()
|
||||
five_minutes = 5 * 60 # 5 minutes in seconds
|
||||
|
||||
# Process each connection
|
||||
for connection in connections:
|
||||
# Only refresh active tokens that expire soon
|
||||
if (connection.tokenStatus == 'active' and
|
||||
connection.tokenExpiresAt and
|
||||
connection.authority in [AuthAuthority.GOOGLE, AuthAuthority.MSFT]):
|
||||
|
||||
# Check if token expires within 5 minutes
|
||||
time_until_expiry = connection.tokenExpiresAt - current_time
|
||||
if 0 < time_until_expiry <= five_minutes:
|
||||
|
||||
# Check rate limiting
|
||||
if self._is_rate_limited(connection.id):
|
||||
logger.warning(f"Rate limited for proactive refresh of connection {connection.id}")
|
||||
rate_limited_count += 1
|
||||
continue
|
||||
|
||||
# Record attempt
|
||||
self._record_refresh_attempt(connection.id)
|
||||
|
||||
# Refresh based on authority
|
||||
success = False
|
||||
if connection.authority == AuthAuthority.GOOGLE:
|
||||
success = await self._refresh_google_token(root_interface, connection)
|
||||
elif connection.authority == AuthAuthority.MSFT:
|
||||
success = await self._refresh_microsoft_token(root_interface, connection)
|
||||
|
||||
if success:
|
||||
refreshed_count += 1
|
||||
logger.info(f"Proactively refreshed {connection.authority} token for connection {connection.id}")
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
result = {
|
||||
"refreshed": refreshed_count,
|
||||
"failed": failed_count,
|
||||
"rate_limited": rate_limited_count
|
||||
}
|
||||
|
||||
if refreshed_count > 0:
|
||||
logger.info(f"Proactive token refresh completed for user {user_id}: {result}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during proactive token refresh for user {user_id}: {str(e)}")
|
||||
return {"refreshed": 0, "failed": 0, "rate_limited": 0, "error": str(e)}
|
||||
|
||||
# Global service instance
|
||||
token_refresh_service = TokenRefreshService()
|
||||
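Both entry points of the service above are coroutines, so a caller needs a running event loop; a minimal sketch (the user id is a placeholder):

    import asyncio
    from modules.security.tokenRefreshService import token_refresh_service

    async def refreshForUser(userId: str) -> dict:
        # First repair tokens already marked 'expired', then top up ones expiring within 5 minutes
        expired = await token_refresh_service.refresh_expired_tokens(userId)
        upcoming = await token_refresh_service.proactive_refresh(userId)
        return {"expired": expired, "upcoming": upcoming}

    print(asyncio.run(refreshForUser("user-123")))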
91
modules/services/__init__.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
from typing import Any
|
||||
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelChat import ChatWorkflow
|
||||
|
||||
class PublicService:
|
||||
"""Lightweight proxy exposing only public callable attributes of a target.
|
||||
|
||||
- Hides names starting with '_'
|
||||
- Optionally restricts to callables only
|
||||
- Optional nameFilter predicate for allow-list patterns
|
||||
"""
|
||||
|
||||
def __init__(self, target: Any, functionsOnly: bool = True, nameFilter=None):
|
||||
self._target = target
|
||||
self._functionsOnly = functionsOnly
|
||||
self._nameFilter = nameFilter
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name.startswith('_'):
|
||||
raise AttributeError(f"'{type(self._target).__name__}' attribute '{name}' is private")
|
||||
if self._nameFilter and not self._nameFilter(name):
|
||||
raise AttributeError(f"'{name}' not exposed by policy")
|
||||
attr = getattr(self._target, name)
|
||||
if self._functionsOnly and not callable(attr):
|
||||
raise AttributeError(f"'{name}' is not a function")
|
||||
return attr
|
||||
|
||||
def __dir__(self):
|
||||
names = [
|
||||
n for n in dir(self._target)
|
||||
if not n.startswith('_')
|
||||
and (not self._functionsOnly or callable(getattr(self._target, n, None)))
|
||||
and (self._nameFilter(n) if self._nameFilter else True)
|
||||
]
|
||||
return sorted(names)
|
||||
|
||||
|
||||
class Services:
|
||||
|
||||
def __init__(self, user: User, workflow: ChatWorkflow = None):
|
||||
self.user: User = user
|
||||
self.workflow: ChatWorkflow = workflow
|
||||
self.currentUserPrompt: str = "" # Cleaned/normalized user intent for the current round
|
||||
self.rawUserPrompt: str = "" # Original raw user message for the current round
|
||||
|
||||
# Initialize interfaces
|
||||
|
||||
from modules.interfaces.interfaceDbChatObjects import getInterface as getChatInterface
|
||||
self.interfaceDbChat = getChatInterface(user)
|
||||
|
||||
from modules.interfaces.interfaceDbAppObjects import getInterface as getAppInterface
|
||||
self.interfaceDbApp = getAppInterface(user)
|
||||
|
||||
from modules.interfaces.interfaceDbComponentObjects import getInterface as getComponentInterface
|
||||
self.interfaceDbComponent = getComponentInterface(user)
|
||||
|
||||
# Initialize service packages
|
||||
|
||||
from .serviceExtraction.mainServiceExtraction import ExtractionService
|
||||
self.extraction = PublicService(ExtractionService(self))
|
||||
|
||||
from .serviceGeneration.mainServiceGeneration import GenerationService
|
||||
self.generation = PublicService(GenerationService(self))
|
||||
|
||||
from .serviceNeutralization.mainServiceNeutralization import NeutralizationService
|
||||
self.neutralization = PublicService(NeutralizationService(self))
|
||||
|
||||
from .serviceSharepoint.mainServiceSharepoint import SharepointService
|
||||
self.sharepoint = PublicService(SharepointService(self))
|
||||
|
||||
from .serviceAi.mainServiceAi import AiService
|
||||
self.ai = PublicService(AiService(self), functionsOnly=False)
|
||||
|
||||
from .serviceTicket.mainServiceTicket import TicketService
|
||||
self.ticket = PublicService(TicketService(self))
|
||||
|
||||
from .serviceChat.mainServiceChat import ChatService
|
||||
self.chat = PublicService(ChatService(self))
|
||||
|
||||
from .serviceUtils.mainServiceUtils import UtilsService
|
||||
self.utils = PublicService(UtilsService(self))
|
||||
|
||||
from .serviceWeb.mainServiceWeb import WebService
|
||||
self.web = PublicService(WebService(self))
|
||||
|
||||
|
||||
def getInterface(user: User, workflow: ChatWorkflow) -> Services:
|
||||
return Services(user, workflow)
|
||||
|
||||
|
||||
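The proxy above only filters attribute access; a short behavioral sketch (the Demo class is hypothetical):

    from modules.services import PublicService

    class Demo:
        def publicMethod(self):
            return "ok"
        def _privateMethod(self):
            return "hidden"

    proxy = PublicService(Demo())
    proxy.publicMethod()    # "ok"
    dir(proxy)              # ['publicMethod'] - underscore names are filtered out
    proxy._privateMethod()  # raises AttributeError: attribute '_privateMethod' is private

getInterface(user, workflow) then hands workflow code a Services container whose sub-services are all wrapped in this proxy, so only their public callables (and, for the AI service, public attributes) are reachable.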
843
modules/services/serviceAi/mainServiceAi.py
Normal file
|
|
@ -0,0 +1,843 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Tuple, Union
|
||||
from modules.datamodels.datamodelChat import PromptPlaceholder, ChatDocument
|
||||
from modules.services.serviceExtraction.mainServiceExtraction import ExtractionService
|
||||
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum
|
||||
from modules.interfaces.interfaceAiObjects import AiObjects
|
||||
from modules.shared.jsonUtils import (
|
||||
extractJsonString,
|
||||
repairBrokenJson,
|
||||
extractSectionsFromDocument,
|
||||
buildContinuationContext
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Rebuild the model to resolve forward references
|
||||
AiCallRequest.model_rebuild()
|
||||
|
||||
class AiService:
|
||||
"""AI service with core operations integrated."""
|
||||
|
||||
def __init__(self, serviceCenter=None) -> None:
|
||||
"""Initialize AI service with service center access.
|
||||
|
||||
Args:
|
||||
serviceCenter: Service center instance for accessing other services
|
||||
"""
|
||||
self.services = serviceCenter
|
||||
# Only depend on interfaces
|
||||
self.aiObjects = None # Will be initialized in create() or _ensureAiObjectsInitialized()
|
||||
# Submodules initialized as None - will be set in _initializeSubmodules() after aiObjects is ready
|
||||
self.extractionService = None
|
||||
|
||||
def _initializeSubmodules(self):
|
||||
"""Initialize all submodules after aiObjects is ready."""
|
||||
if self.aiObjects is None:
|
||||
raise RuntimeError("aiObjects must be initialized before initializing submodules")
|
||||
|
||||
if self.extractionService is None:
|
||||
logger.info("Initializing ExtractionService...")
|
||||
self.extractionService = ExtractionService(self.services)
|
||||
|
||||
async def _ensureAiObjectsInitialized(self):
|
||||
"""Ensure aiObjects is initialized and submodules are ready."""
|
||||
if self.aiObjects is None:
|
||||
logger.info("Lazy initializing AiObjects...")
|
||||
self.aiObjects = await AiObjects.create()
|
||||
logger.info("AiObjects initialization completed")
|
||||
# Initialize submodules after aiObjects is ready
|
||||
self._initializeSubmodules()
|
||||
|
||||
@classmethod
|
||||
async def create(cls, serviceCenter=None) -> "AiService":
|
||||
"""Create AiService instance with all connectors and submodules initialized."""
|
||||
logger.info("AiService.create() called")
|
||||
instance = cls(serviceCenter)
|
||||
logger.info("AiService created, about to call AiObjects.create()...")
|
||||
instance.aiObjects = await AiObjects.create()
|
||||
logger.info("AiObjects.create() completed")
|
||||
# Initialize all submodules after aiObjects is ready
|
||||
instance._initializeSubmodules()
|
||||
logger.info("AiService submodules initialized")
|
||||
return instance
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _buildPromptWithPlaceholders(self, prompt: str, placeholders: Optional[Dict[str, str]]) -> str:
|
||||
"""
|
||||
Build full prompt by replacing placeholders with their content.
|
||||
Uses the new {{KEY:placeholder}} format.
|
||||
|
||||
Args:
|
||||
prompt: The base prompt template
|
||||
placeholders: Dictionary of placeholder key-value pairs
|
||||
|
||||
Returns:
|
||||
Prompt with placeholders replaced
|
||||
"""
|
||||
if not placeholders:
|
||||
return prompt
|
||||
|
||||
full_prompt = prompt
|
||||
for placeholder, content in placeholders.items():
|
||||
# Skip if content is None or empty
|
||||
if content is None:
|
||||
continue
|
||||
# Replace {{KEY:placeholder}}
|
||||
full_prompt = full_prompt.replace(f"{{{{KEY:{placeholder}}}}}", str(content))
|
||||
|
||||
return full_prompt
|
||||
|
||||
async def _analyzePromptAndCreateOptions(self, prompt: str) -> AiCallOptions:
|
||||
"""Analyze prompt to determine appropriate AiCallOptions parameters."""
|
||||
try:
|
||||
# Get dynamic enum values from Pydantic models
|
||||
operationTypes = [e.value for e in OperationTypeEnum]
|
||||
priorities = [e.value for e in PriorityEnum]
|
||||
processingModes = [e.value for e in ProcessingModeEnum]
|
||||
|
||||
# Create analysis prompt for AI to determine operation type and parameters
|
||||
analysisPrompt = f"""
|
||||
You are an AI operation analyzer. Analyze the following prompt and determine the most appropriate operation type and parameters.
|
||||
|
||||
PROMPT TO ANALYZE:
|
||||
{self.services.utils.sanitizePromptContent(prompt, 'userinput')}
|
||||
|
||||
Based on the prompt content, determine:
|
||||
1. operationType: Choose the most appropriate from: {', '.join(operationTypes)}
|
||||
2. priority: Choose from: {', '.join(priorities)}
|
||||
3. processingMode: Choose from: {', '.join(processingModes)}
|
||||
4. compressPrompt: true/false (true for story-like prompts, false for structured prompts with JSON/schemas)
|
||||
5. compressContext: true/false (true to summarize context, false to process fully)
|
||||
|
||||
Respond with ONLY a JSON object in this exact format:
|
||||
{{
|
||||
"operationType": "dataAnalyse",
|
||||
"priority": "balanced",
|
||||
"processingMode": "basic",
|
||||
"compressPrompt": true,
|
||||
"compressContext": true
|
||||
}}
|
||||
"""
|
||||
|
||||
# Use AI to analyze the prompt
|
||||
request = AiCallRequest(
|
||||
prompt=analysisPrompt,
|
||||
options=AiCallOptions(
|
||||
operationType=OperationTypeEnum.DATA_ANALYSE,
|
||||
priority=PriorityEnum.SPEED,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
compressPrompt=True,
|
||||
compressContext=False
|
||||
)
|
||||
)
|
||||
|
||||
response = await self.aiObjects.call(request)
|
||||
|
||||
# Parse AI response
|
||||
try:
|
||||
jsonStart = response.content.find('{')
|
||||
jsonEnd = response.content.rfind('}') + 1
|
||||
if jsonStart != -1 and jsonEnd > jsonStart:
|
||||
analysis = json.loads(response.content[jsonStart:jsonEnd])
|
||||
|
||||
# Map string values to enums
|
||||
operationType = OperationTypeEnum(analysis.get('operationType', 'dataAnalyse'))
|
||||
priority = PriorityEnum(analysis.get('priority', 'balanced'))
|
||||
processingMode = ProcessingModeEnum(analysis.get('processingMode', 'basic'))
|
||||
|
||||
return AiCallOptions(
|
||||
operationType=operationType,
|
||||
priority=priority,
|
||||
processingMode=processingMode,
|
||||
compressPrompt=analysis.get('compressPrompt', True),
|
||||
compressContext=analysis.get('compressContext', True)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse AI analysis response: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Prompt analysis failed: {e}")
|
||||
|
||||
# Fallback to default options
|
||||
return AiCallOptions(
|
||||
operationType=OperationTypeEnum.DATA_ANALYSE,
|
||||
priority=PriorityEnum.BALANCED,
|
||||
processingMode=ProcessingModeEnum.BASIC
|
||||
)
|
||||
|
||||
async def _callAiWithLooping(
|
||||
self,
|
||||
prompt: str,
|
||||
options: AiCallOptions,
|
||||
debugPrefix: str = "ai_call",
|
||||
promptBuilder: Optional[callable] = None,
|
||||
promptArgs: Optional[Dict[str, Any]] = None,
|
||||
operationId: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Shared core function for AI calls with repair-based looping system.
|
||||
Automatically repairs broken JSON and continues generation seamlessly.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send to AI
|
||||
options: AI call configuration options
|
||||
debugPrefix: Prefix for debug file names
|
||||
promptBuilder: Optional function to rebuild prompts for continuation
|
||||
promptArgs: Optional arguments for prompt builder
|
||||
operationId: Optional operation ID for progress tracking
|
||||
|
||||
Returns:
|
||||
Complete AI response after all iterations
|
||||
"""
|
||||
maxIterations = 50 # Prevent infinite loops
|
||||
iteration = 0
|
||||
allSections = [] # Accumulate all sections across iterations
|
||||
lastRawResponse = None # Store last raw JSON response for continuation
|
||||
documentMetadata = None # Store document metadata (title, filename) from first iteration
|
||||
|
||||
while iteration < maxIterations:
|
||||
iteration += 1
|
||||
|
||||
# Update progress for iteration start
|
||||
if operationId:
|
||||
if iteration == 1:
|
||||
self.services.chat.progressLogUpdate(operationId, 0.5, f"Starting AI call iteration {iteration}")
|
||||
else:
|
||||
# For continuation iterations, show progress incrementally
|
||||
baseProgress = 0.5 + (min(iteration - 1, maxIterations) / maxIterations * 0.4) # Progress from 0.5 to 0.9 over maxIterations iterations
|
||||
self.services.chat.progressLogUpdate(operationId, baseProgress, f"Continuing generation (iteration {iteration})")
|
||||
|
||||
# Build iteration prompt
|
||||
if len(allSections) > 0 and promptBuilder and promptArgs:
|
||||
# This is a continuation - build continuation context with raw JSON and rebuild prompt
|
||||
continuationContext = buildContinuationContext(allSections, lastRawResponse)
|
||||
if not lastRawResponse:
|
||||
logger.warning(f"Iteration {iteration}: No previous response available for continuation!")
|
||||
|
||||
# Rebuild prompt with continuation context using the provided prompt builder
|
||||
iterationPrompt = await promptBuilder(**promptArgs, continuationContext=continuationContext)
|
||||
else:
|
||||
# First iteration - use original prompt
|
||||
iterationPrompt = prompt
|
||||
|
||||
# Make AI call
|
||||
try:
|
||||
if operationId and iteration == 1:
|
||||
self.services.chat.progressLogUpdate(operationId, 0.51, "Calling AI model")
|
||||
request = AiCallRequest(
|
||||
prompt=iterationPrompt,
|
||||
context="",
|
||||
options=options
|
||||
)
|
||||
|
||||
# Write the ACTUAL prompt sent to AI
|
||||
if iteration == 1:
|
||||
self.services.utils.writeDebugFile(iterationPrompt, f"{debugPrefix}_prompt")
|
||||
else:
|
||||
self.services.utils.writeDebugFile(iterationPrompt, f"{debugPrefix}_prompt_iteration_{iteration}")
|
||||
|
||||
response = await self.aiObjects.call(request)
|
||||
result = response.content
|
||||
|
||||
# Update progress after AI call
|
||||
if operationId:
|
||||
if iteration == 1:
|
||||
self.services.chat.progressLogUpdate(operationId, 0.6, f"AI response received (iteration {iteration})")
|
||||
else:
|
||||
progress = 0.6 + (min(iteration - 1, 10) * 0.03)
|
||||
self.services.chat.progressLogUpdate(operationId, progress, f"Processing response (iteration {iteration})")
|
||||
|
||||
# Write raw AI response to debug file
|
||||
if iteration == 1:
|
||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
|
||||
else:
|
||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
|
||||
|
||||
# Emit stats for this iteration
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
f"ai.call.{debugPrefix}.iteration_{iteration}"
|
||||
)
|
||||
|
||||
if not result or not result.strip():
|
||||
logger.warning(f"Iteration {iteration}: Empty response, stopping")
|
||||
break
|
||||
|
||||
# Store raw response for continuation (even if broken)
|
||||
lastRawResponse = result
|
||||
|
||||
# Check for complete_response flag in raw response (before parsing)
|
||||
import re
|
||||
if re.search(r'"complete_response"\s*:\s*true', result, re.IGNORECASE):
|
||||
pass # Flag detected, will stop in _shouldContinueGeneration
|
||||
|
||||
# Extract sections from response (handles both valid and broken JSON)
|
||||
extractedSections, wasJsonComplete, parsedResult = self._extractSectionsFromResponse(result, iteration, debugPrefix)
|
||||
|
||||
# Extract document metadata from first iteration if available
|
||||
if iteration == 1 and parsedResult and not documentMetadata:
|
||||
documentMetadata = self._extractDocumentMetadata(parsedResult)
|
||||
|
||||
# Update progress after parsing
|
||||
if operationId:
|
||||
if extractedSections:
|
||||
self.services.chat.progressLogUpdate(operationId, 0.65 + (min(iteration - 1, 10) * 0.025), f"Extracted {len(extractedSections)} sections (iteration {iteration})")
|
||||
|
||||
if not extractedSections:
|
||||
# If we're in continuation mode and JSON was incomplete, don't stop - continue to allow retry
|
||||
if iteration > 1 and not wasJsonComplete:
|
||||
logger.warning(f"Iteration {iteration}: No sections extracted from continuation fragment, continuing for another attempt")
|
||||
continue
|
||||
# Otherwise, stop if no sections
|
||||
logger.warning(f"Iteration {iteration}: No sections extracted, stopping")
|
||||
break
|
||||
|
||||
# Add new sections to accumulator
|
||||
allSections.extend(extractedSections)
|
||||
|
||||
# Check if we should continue (completion detection)
|
||||
if self._shouldContinueGeneration(allSections, iteration, wasJsonComplete, result):
|
||||
continue
|
||||
else:
|
||||
# Done - build final result
|
||||
if operationId:
|
||||
self.services.chat.progressLogUpdate(operationId, 0.95, f"Generation complete ({iteration} iterations, {len(allSections)} sections)")
|
||||
logger.info(f"Generation complete after {iteration} iterations: {len(allSections)} sections")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in AI call iteration {iteration}: {str(e)}")
|
||||
break
|
||||
|
||||
if iteration >= maxIterations:
|
||||
logger.warning(f"AI call stopped after maximum iterations ({maxIterations})")
|
||||
|
||||
# Build final result from accumulated sections
|
||||
final_result = self._buildFinalResultFromSections(allSections, documentMetadata)
|
||||
|
||||
# Write final result to debug file
|
||||
self.services.utils.writeDebugFile(final_result, f"{debugPrefix}_final_result")
|
||||
|
||||
return final_result
|
||||
|
||||
def _extractSectionsFromResponse(
|
||||
self,
|
||||
result: str,
|
||||
iteration: int,
|
||||
debugPrefix: str
|
||||
) -> Tuple[List[Dict[str, Any]], bool, Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Extract sections from AI response, handling both valid and broken JSON.
|
||||
Uses repair mechanism for broken JSON.
|
||||
Checks for "complete_response": true flag to determine completion.
|
||||
Returns (sections, wasJsonComplete, parsedResult)
|
||||
"""
|
||||
# First, try to parse as valid JSON
|
||||
try:
|
||||
extracted = extractJsonString(result)
|
||||
parsed_result = json.loads(extracted)
|
||||
|
||||
# Check if AI marked response as complete
|
||||
isComplete = parsed_result.get("complete_response", False) == True
|
||||
|
||||
# Extract sections from parsed JSON
|
||||
sections = extractSectionsFromDocument(parsed_result)
|
||||
|
||||
# If AI marked as complete, always return as complete
|
||||
if isComplete:
|
||||
return sections, True, parsed_result
|
||||
|
||||
# If in continuation mode (iteration > 1), continuation responses are expected to be fragments
|
||||
# A fragment with 0 extractable sections means JSON is incomplete - need another iteration
|
||||
if len(sections) == 0 and iteration > 1:
|
||||
return sections, False, parsed_result # Mark as incomplete so loop continues
|
||||
|
||||
# First iteration with 0 sections means empty response - stop
|
||||
if len(sections) == 0:
|
||||
return sections, True, parsed_result # Complete but empty
|
||||
|
||||
return sections, True, parsed_result # JSON was complete with sections
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
# Broken JSON - try repair mechanism (normal in iterative generation)
|
||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_broken_json_iteration_{iteration}")
|
||||
|
||||
# Try to repair
|
||||
repaired_json = repairBrokenJson(result)
|
||||
|
||||
if repaired_json:
|
||||
# Extract sections from repaired JSON
|
||||
sections = extractSectionsFromDocument(repaired_json)
|
||||
return sections, False, repaired_json # JSON was broken but repaired
|
||||
else:
|
||||
# Repair failed - log error
|
||||
logger.error(f"Iteration {iteration}: All repair strategies failed")
|
||||
return [], False, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Iteration {iteration}: Unexpected error during parsing: {str(e)}")
|
||||
return [], False, None
|
||||
|
||||
def _shouldContinueGeneration(
|
||||
self,
|
||||
allSections: List[Dict[str, Any]],
|
||||
iteration: int,
|
||||
wasJsonComplete: bool,
|
||||
rawResponse: str = None
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if generation should continue based on JSON completeness, complete_response flag, and task completion.
|
||||
Returns True if we should continue, False if done.
|
||||
"""
|
||||
if len(allSections) == 0:
|
||||
return True # No sections yet, continue
|
||||
|
||||
# Check for complete_response flag in raw response
|
||||
if rawResponse:
|
||||
import re
|
||||
if re.search(r'"complete_response"\s*:\s*true', rawResponse, re.IGNORECASE):
|
||||
logger.info(f"Iteration {iteration}: AI marked response as complete (complete_response flag detected)")
|
||||
return False
|
||||
|
||||
# If JSON was complete, stop (AI should have set complete_response if task is done)
|
||||
# For continuation iterations (iteration > 1), if JSON is complete but no flag was set,
|
||||
# stop to prevent infinite loops - AI had a chance to set the flag
|
||||
if wasJsonComplete:
|
||||
if iteration > 1:
|
||||
# Continuation mode: JSON complete without flag means we're likely done
|
||||
# Stop to prevent infinite loops
|
||||
logger.info(f"Iteration {iteration}: JSON complete without complete_response flag - stopping")
|
||||
return False
|
||||
# First iteration with complete JSON - done
|
||||
return False
|
||||
else:
|
||||
# JSON was incomplete/broken - continue
|
||||
return True
|
||||
|
||||
def _extractDocumentMetadata(
|
||||
self,
|
||||
parsedResult: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract document metadata (title, filename) from parsed AI response.
|
||||
Returns dict with 'title' and 'filename' keys if found, None otherwise.
|
||||
"""
|
||||
if not isinstance(parsedResult, dict):
|
||||
return None
|
||||
|
||||
# Try to get from documents array (preferred structure)
|
||||
if "documents" in parsedResult and isinstance(parsedResult["documents"], list) and len(parsedResult["documents"]) > 0:
|
||||
firstDoc = parsedResult["documents"][0]
|
||||
if isinstance(firstDoc, dict):
|
||||
title = firstDoc.get("title")
|
||||
filename = firstDoc.get("filename")
|
||||
if title or filename:
|
||||
return {
|
||||
"title": title,
|
||||
"filename": filename
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
def _buildFinalResultFromSections(
|
||||
self,
|
||||
allSections: List[Dict[str, Any]],
|
||||
documentMetadata: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Build final JSON result from accumulated sections.
|
||||
Uses AI-provided metadata (title, filename) if available.
|
||||
"""
|
||||
if not allSections:
|
||||
return ""
|
||||
|
||||
# Extract metadata from AI response if available
|
||||
title = "Generated Document"
|
||||
filename = "document.json"
|
||||
if documentMetadata:
|
||||
if documentMetadata.get("title"):
|
||||
title = documentMetadata["title"]
|
||||
if documentMetadata.get("filename"):
|
||||
filename = documentMetadata["filename"]
|
||||
|
||||
# Build documents structure
|
||||
# Assuming single document for now
|
||||
documents = [{
|
||||
"id": "doc_1",
|
||||
"title": title,
|
||||
"filename": filename,
|
||||
"sections": allSections
|
||||
}]
|
||||
|
||||
result = {
|
||||
"metadata": {
|
||||
"split_strategy": "single_document",
|
||||
"source_documents": [],
|
||||
"extraction_method": "ai_generation"
|
||||
},
|
||||
"documents": documents
|
||||
}
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
# Public API Methods
|
||||
|
||||
# Planning AI Call
|
||||
async def callAiPlanning(
|
||||
self,
|
||||
prompt: str,
|
||||
placeholders: Optional[List[PromptPlaceholder]] = None,
|
||||
debugType: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Planning AI call for task planning, action planning, action selection, etc.
|
||||
Always uses static parameters optimized for planning tasks.
|
||||
|
||||
Args:
|
||||
prompt: The planning prompt
|
||||
placeholders: Optional list of placeholder replacements
|
||||
debugType: Optional debug file type identifier (e.g., 'taskplan', 'actionplan', 'intentanalysis')
|
||||
If not provided, defaults to 'plan'
|
||||
|
||||
Returns:
|
||||
Planning JSON response
|
||||
"""
|
||||
await self._ensureAiObjectsInitialized()
|
||||
|
||||
# Planning calls always use static parameters
|
||||
options = AiCallOptions(
|
||||
operationType=OperationTypeEnum.PLAN,
|
||||
priority=PriorityEnum.QUALITY,
|
||||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
compressPrompt=False,
|
||||
compressContext=False
|
||||
)
|
||||
|
||||
# Build full prompt with placeholders
|
||||
if placeholders:
|
||||
placeholdersDict = {p.label: p.content for p in placeholders}
|
||||
fullPrompt = self._buildPromptWithPlaceholders(prompt, placeholdersDict)
|
||||
else:
|
||||
fullPrompt = prompt
|
||||
|
||||
# Root-cause fix: planning must return raw single-shot JSON, not section-based output
|
||||
request = AiCallRequest(
|
||||
prompt=fullPrompt,
|
||||
context="",
|
||||
options=options
|
||||
)
|
||||
|
||||
# Debug: persist prompt/response for analysis with context-specific naming
|
||||
debugPrefix = debugType if debugType else "plan"
|
||||
self.services.utils.writeDebugFile(fullPrompt, f"{debugPrefix}_prompt")
|
||||
response = await self.aiObjects.call(request)
|
||||
result = response.content or ""
|
||||
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
|
||||
return result
|
||||
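# Illustrative call (not part of this change; the placeholder label and prompt text are made up,
# and constructing PromptPlaceholder with label/content keyword arguments is an assumption):
#
#   placeholders = [PromptPlaceholder(label="CONTEXT", content="Previous step output...")]
#   planJson = await self.callAiPlanning(
#       "Plan the next action. Context: {{KEY:CONTEXT}}",
#       placeholders=placeholders,
#       debugType="actionplan",
#   )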
|
||||
# Document Generation AI Call
|
||||
async def callAiDocuments(
|
||||
self,
|
||||
prompt: str,
|
||||
documents: Optional[List[ChatDocument]] = None,
|
||||
options: Optional[AiCallOptions] = None,
|
||||
outputFormat: Optional[str] = None,
|
||||
title: Optional[str] = None
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
"""
|
||||
Document generation AI call for all non-planning calls.
|
||||
Uses the current unified path with extraction and generation.
|
||||
|
||||
Args:
|
||||
prompt: The main prompt for the AI call
|
||||
documents: Optional list of documents to process
|
||||
options: AI call configuration options
|
||||
outputFormat: Optional output format for document generation
|
||||
title: Optional title for generated documents
|
||||
|
||||
Returns:
|
||||
AI response as string, or dict with documents if outputFormat is specified
|
||||
"""
|
||||
await self._ensureAiObjectsInitialized()
|
||||
|
||||
# Create separate operationId for detailed progress tracking
|
||||
workflowId = self.services.workflow.id if self.services.workflow else f"no-workflow-{int(time.time())}"
|
||||
aiOperationId = f"ai_documents_{workflowId}_{int(time.time())}"
|
||||
|
||||
# Start progress tracking for this operation
|
||||
self.services.chat.progressLogStart(
|
||||
aiOperationId,
|
||||
"AI call with documents",
|
||||
"Document Generation",
|
||||
f"Format: {outputFormat or 'text'}"
|
||||
)
|
||||
|
||||
try:
|
||||
if options is None or (hasattr(options, 'operationType') and options.operationType is None):
|
||||
# Use AI to determine parameters ONLY when truly needed (options=None OR operationType=None)
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.1, "Analyzing prompt parameters")
|
||||
options = await self._analyzePromptAndCreateOptions(prompt)
|
||||
|
||||
# Check operationType FIRST - some operations need direct routing (before document generation checks)
|
||||
opType = getattr(options, "operationType", None)
|
||||
|
||||
# Handle image generation requests directly via generic path
|
||||
isImageRequest = (opType == OperationTypeEnum.IMAGE_GENERATE)
|
||||
|
||||
if isImageRequest:
|
||||
# Image generation uses generic call path but bypasses document generation pipeline
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.4, "Calling AI for image generation")
|
||||
|
||||
# Call via generic path (no looping for images)
|
||||
request = AiCallRequest(
|
||||
prompt=prompt,
|
||||
context="",
|
||||
options=options
|
||||
)
|
||||
|
||||
response = await self.aiObjects.call(request)
|
||||
|
||||
# Extract image data from response
|
||||
if response.content:
|
||||
# For base64 format, return in expected format
|
||||
if outputFormat == "base64":
|
||||
result = {
|
||||
"success": True,
|
||||
"image_data": response.content,
|
||||
"documents": [{
|
||||
"documentName": "generated_image.png",
|
||||
"documentData": response.content,
|
||||
"mimeType": "image/png",
|
||||
"title": title or "Generated Image"
|
||||
}]
|
||||
}
|
||||
else:
|
||||
# Return raw content for other formats
|
||||
result = response.content
|
||||
|
||||
# Emit stats for image generation
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
f"ai.generate.image"
|
||||
)
|
||||
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.9, "Image generated")
|
||||
self.services.chat.progressLogFinish(aiOperationId, True)
|
||||
return result
|
||||
else:
|
||||
errorMsg = f"No image data returned: {response.content}"
|
||||
logger.error(f"Error in AI image generation: {errorMsg}")
|
||||
self.services.chat.progressLogFinish(aiOperationId, False)
|
||||
return {"success": False, "error": errorMsg}
|
||||
|
||||
# Handle WEB_SEARCH and WEB_CRAWL operations - route directly to connectors
|
||||
# These operations require raw JSON prompts that connectors parse directly
|
||||
# Must check BEFORE document generation to avoid wrapping the prompt
|
||||
isWebOperation = (opType == OperationTypeEnum.WEB_SEARCH or opType == OperationTypeEnum.WEB_CRAWL)
|
||||
|
||||
if isWebOperation:
|
||||
# Web operations: prompt is already structured JSON (AiCallPromptWebSearch/WebCrawl)
|
||||
# Route directly through centralized AI call - model selector chooses appropriate connector
|
||||
# Connector parses the JSON prompt and executes the operation
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.4, f"Calling AI for {opType.name}")
|
||||
|
||||
request = AiCallRequest(
|
||||
prompt=prompt, # Pass raw JSON prompt unchanged - connector will parse it
|
||||
context="",
|
||||
options=options
|
||||
)
|
||||
|
||||
response = await self.aiObjects.call(request)
|
||||
|
||||
# Extract result from response
|
||||
if response.content:
|
||||
# Emit stats for web operation
|
||||
self.services.chat.storeWorkflowStat(
|
||||
self.services.workflow,
|
||||
response,
|
||||
f"ai.{opType.name.lower()}"
|
||||
)
|
||||
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.9, f"{opType.name} completed")
|
||||
self.services.chat.progressLogFinish(aiOperationId, True)
|
||||
return response.content
|
||||
else:
|
||||
errorMsg = f"No content returned from {opType.name}: {response.content}"
|
||||
logger.error(f"Error in {opType.name}: {errorMsg}")
|
||||
self.services.chat.progressLogFinish(aiOperationId, False)
|
||||
return {"success": False, "error": errorMsg}
|
||||
|
||||
# CRITICAL: For document generation with JSON templates, NEVER compress the prompt
|
||||
# Compressing would truncate the template structure and confuse the AI
|
||||
if outputFormat: # Document generation with structured output
|
||||
if not options:
|
||||
options = AiCallOptions()
|
||||
options.compressPrompt = False # JSON templates must NOT be truncated
|
||||
options.compressContext = False # Context also should not be compressed
|
||||
|
||||
# Handle document generation with specific output format using unified approach
|
||||
if outputFormat:
|
||||
# Use unified generation method for all document generation
|
||||
if documents and len(documents) > 0:
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.2, f"Extracting content from {len(documents)} documents")
|
||||
extracted_content = await self.callAiText(prompt, documents, options, aiOperationId)
|
||||
else:
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.2, "Preparing for direct generation")
|
||||
extracted_content = None
|
||||
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.3, "Building generation prompt")
|
||||
from modules.services.serviceGeneration.subPromptBuilderGeneration import buildGenerationPrompt
|
||||
# First call without continuation context
|
||||
generation_prompt = await buildGenerationPrompt(outputFormat, prompt, title, extracted_content, None)
|
||||
|
||||
# Prepare prompt builder arguments for continuation
|
||||
promptArgs = {
|
||||
"outputFormat": outputFormat,
|
||||
"userPrompt": prompt,
|
||||
"title": title,
|
||||
"extracted_content": extracted_content
|
||||
}
|
||||
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.4, "Calling AI for content generation")
|
||||
generated_json = await self._callAiWithLooping(
|
||||
generation_prompt,
|
||||
options,
|
||||
"document_generation",
|
||||
buildGenerationPrompt,
|
||||
promptArgs,
|
||||
aiOperationId
|
||||
)
|
||||
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.7, "Parsing generated JSON")
|
||||
# Parse the generated JSON (extract fenced/embedded JSON first)
|
||||
try:
|
||||
extracted_json = self.services.utils.jsonExtractString(generated_json)
|
||||
generated_data = json.loads(extracted_json)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse generated JSON: {str(e)}")
|
||||
logger.error(f"JSON content length: {len(generated_json)}")
|
||||
logger.error(f"JSON content preview (last 200 chars): ...{generated_json[-200:]}")
|
||||
logger.error(f"JSON content around error position: {generated_json[max(0, e.pos-50):e.pos+50]}")
|
||||
|
||||
# Write the problematic JSON to debug file
|
||||
self.services.utils.writeDebugFile(generated_json, "failed_json_parsing")
|
||||
|
||||
self.services.chat.progressLogFinish(aiOperationId, False)
|
||||
return {"success": False, "error": f"Generated content is not valid JSON: {str(e)}"}
|
||||
|
||||
# Extract title and filename from generated document structure
|
||||
extractedTitle = title # Default to user-provided title
|
||||
extractedFilename = None
|
||||
if isinstance(generated_data, dict) and "documents" in generated_data:
|
||||
documents = generated_data["documents"]
|
||||
if isinstance(documents, list) and len(documents) > 0:
|
||||
firstDoc = documents[0]
|
||||
if isinstance(firstDoc, dict):
|
||||
# Extract title from document (preferred over user-provided title)
|
||||
if firstDoc.get("title"):
|
||||
extractedTitle = firstDoc["title"]
|
||||
# Extract filename from document
|
||||
if firstDoc.get("filename"):
|
||||
extractedFilename = firstDoc["filename"]
|
||||
|
||||
# Ensure metadata contains the extracted title for renderers
|
||||
if "metadata" not in generated_data:
|
||||
generated_data["metadata"] = {}
|
||||
if extractedTitle:
|
||||
generated_data["metadata"]["title"] = extractedTitle
|
||||
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.8, f"Rendering to {outputFormat} format")
|
||||
# Render to final format using the existing renderer
|
||||
try:
|
||||
from modules.services.serviceGeneration.mainServiceGeneration import GenerationService
|
||||
generationService = GenerationService(self.services)
|
||||
# Pass extracted title to renderer (will use metadata.title if available)
|
||||
rendered_content, mime_type = await generationService.renderReport(
|
||||
generated_data, outputFormat, extractedTitle or "Generated Document", prompt, self
|
||||
)
|
||||
|
||||
# Use extracted filename if available, otherwise generate from title or use generic
|
||||
if extractedFilename:
|
||||
documentName = extractedFilename
|
||||
elif extractedTitle and extractedTitle != "Generated Document":
|
||||
# Sanitize title for filename
|
||||
sanitized = re.sub(r"[^a-zA-Z0-9._-]", "_", extractedTitle)
|
||||
sanitized = re.sub(r"_+", "_", sanitized).strip("_")
|
||||
if sanitized:
|
||||
# Ensure correct extension
|
||||
if not sanitized.lower().endswith(f".{outputFormat}"):
|
||||
documentName = f"{sanitized}.{outputFormat}"
|
||||
else:
|
||||
documentName = sanitized
|
||||
else:
|
||||
documentName = f"generated.{outputFormat}"
|
||||
else:
|
||||
documentName = f"generated.{outputFormat}"
|
||||
|
||||
# Build result in the expected format
|
||||
result = {
|
||||
"success": True,
|
||||
"content": generated_data,
|
||||
"documents": [{
|
||||
"documentName": documentName,
|
||||
"documentData": rendered_content,
|
||||
"mimeType": mime_type,
|
||||
"title": extractedTitle or "Generated Document"
|
||||
}],
|
||||
"is_multi_file": False,
|
||||
"format": outputFormat,
|
||||
"title": extractedTitle or title,
|
||||
"split_strategy": "single",
|
||||
"total_documents": 1,
|
||||
"processed_documents": 1
|
||||
}
|
||||
|
||||
# Log AI response for debugging
|
||||
self.services.utils.writeDebugFile(str(result), "document_generation_response", documents)
|
||||
|
||||
self.services.chat.progressLogFinish(aiOperationId, True)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error rendering document: {str(e)}")
|
||||
self.services.chat.progressLogFinish(aiOperationId, False)
|
||||
return {"success": False, "error": f"Rendering failed: {str(e)}"}
|
||||
|
||||
# Handle text calls (no output format specified)
|
||||
self.services.chat.progressLogUpdate(aiOperationId, 0.5, "Processing text call")
|
||||
if documents:
|
||||
# Use document processing for text calls with documents
|
||||
result = await self.callAiText(prompt, documents, options, aiOperationId)
|
||||
else:
|
||||
# Use shared core function for direct text calls
|
||||
result = await self._callAiWithLooping(prompt, options, "text", None, None, aiOperationId)
|
||||
|
||||
self.services.chat.progressLogFinish(aiOperationId, True)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in callAiDocuments: {str(e)}")
|
||||
self.services.chat.progressLogFinish(aiOperationId, False)
|
||||
raise
|
||||
|
||||
async def callAiText(
|
||||
self,
|
||||
prompt: str,
|
||||
documents: Optional[List[ChatDocument]],
|
||||
options: AiCallOptions,
|
||||
operationId: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Handle text calls with document processing through ExtractionService.
|
||||
UNIFIED PROCESSING: Always use per-chunk processing for consistency.
|
||||
"""
|
||||
await self._ensureAiObjectsInitialized()
|
||||
return await self.extractionService.processDocumentsPerChunk(documents, prompt, self.aiObjects, options, operationId)
|
||||
|
||||
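A minimal sketch of the title-to-filename sanitization used above, assuming the same re-based rules; the helper name and the sample title are invented for illustration only:

    import re

    def sanitizeTitleForFilename(title: str, outputFormat: str) -> str:
        # Replace disallowed characters, collapse repeated underscores, trim, then append the extension
        sanitized = re.sub(r"[^a-zA-Z0-9._-]", "_", title)
        sanitized = re.sub(r"_+", "_", sanitized).strip("_")
        if not sanitized:
            return f"generated.{outputFormat}"
        if not sanitized.lower().endswith(f".{outputFormat}"):
            return f"{sanitized}.{outputFormat}"
        return sanitized

    # sanitizeTitleForFilename("Quarterly Report: Q3/2024", "pdf") -> "Quarterly_Report_Q3_2024.pdf"
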
977
modules/services/serviceChat/mainServiceChat.py
Normal file
@ -0,0 +1,977 @@
import logging
from typing import Dict, Any, List, Optional
from modules.datamodels.datamodelUam import User, UserConnection
from modules.datamodels.datamodelChat import ChatDocument, ChatMessage, ChatStat, ChatLog
from modules.datamodels.datamodelAi import AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum
from modules.security.tokenManager import TokenManager
from modules.shared.progressLogger import ProgressLogger

logger = logging.getLogger(__name__)

class ChatService:
    """Service class containing methods for document processing, chat operations, and workflow management"""

    def __init__(self, serviceCenter):
        self.services = serviceCenter
        self.user = serviceCenter.user
        # self.services.workflow is now the ChatWorkflow object (stable during workflow execution)
        self.interfaceDbChat = serviceCenter.interfaceDbChat
        self.interfaceDbComponent = serviceCenter.interfaceDbComponent
        self.interfaceDbApp = serviceCenter.interfaceDbApp
        self._progressLogger = None

    def getChatDocumentsFromDocumentList(self, documentList: List[str]) -> List[ChatDocument]:
        """Get ChatDocuments from a list of document references using all three formats."""
        try:
            # Use self.services.workflow, which is the ChatWorkflow object (stable during workflow execution)
            workflow = self.services.workflow
            if not workflow:
                logger.error("getChatDocumentsFromDocumentList: No workflow available (self.services.workflow is not set)")
                return []

            workflowId = workflow.id if hasattr(workflow, 'id') else 'NO_ID'
            workflowObjId = id(workflow)
            logger.debug(f"getChatDocumentsFromDocumentList: input documentList = {documentList}")
            logger.debug(f"getChatDocumentsFromDocumentList: using workflow.id = {workflowId}, workflow object id = {workflowObjId}")

            # Root cause analysis: verify workflow.messages integrity and detect workflow changes
            self._verifyWorkflowMessagesIntegrity(workflow, workflowId)

            # Debug: list available messages with their labels and document names (filtered by workflowId)
            try:
                if workflow and hasattr(workflow, 'messages') and workflow.messages:
                    msgLines = []
                    messagesFromOtherWorkflows = []
                    for message in workflow.messages:
                        msgWorkflowId = getattr(message, 'workflowId', None)
                        # Only include messages that belong to this workflow
                        if msgWorkflowId and msgWorkflowId != workflowId:
                            messagesFromOtherWorkflows.append(f"id={getattr(message, 'id', None)}, label={getattr(message, 'documentsLabel', None)}, workflowId={msgWorkflowId}")
                            continue
                        # Also skip messages without workflowId (shouldn't happen, but be safe)
                        if not msgWorkflowId:
                            messagesFromOtherWorkflows.append(f"id={getattr(message, 'id', None)}, label={getattr(message, 'documentsLabel', None)}, workflowId=Missing")
                            continue

                        label = getattr(message, 'documentsLabel', None)
                        docNames = []
                        if getattr(message, 'documents', None):
                            for doc in message.documents:
                                name = getattr(doc, 'fileName', None) or getattr(doc, 'documentName', None) or 'Unnamed'
                                docNames.append(name)
                        msgLines.append(
                            f"- id={getattr(message, 'id', None)}, label={label}, workflowId={msgWorkflowId}, docs={docNames}"
                        )
                    if msgLines:
                        logger.debug("getChatDocumentsFromDocumentList: available messages (filtered for workflow):\n" + "\n".join(msgLines))
                    if messagesFromOtherWorkflows:
                        logger.warning(f"getChatDocumentsFromDocumentList: Found {len(messagesFromOtherWorkflows)} messages from other workflows in workflow.messages list:\n" + "\n".join(messagesFromOtherWorkflows))
                else:
                    logger.debug("getChatDocumentsFromDocumentList: no messages available on current workflow")
            except Exception as e:
                logger.debug(f"getChatDocumentsFromDocumentList: unable to enumerate messages for debug: {e}")

            allDocuments = []
            for docRef in documentList:
                if docRef.startswith("docItem:"):
                    # docItem:<id>:<filename> - extract ID and find document
                    parts = docRef.split(':')
                    if len(parts) >= 2:
                        docId = parts[1]
                        # Find the document by ID
                        for message in workflow.messages:
                            # Validate message belongs to this workflow
                            msgWorkflowId = getattr(message, 'workflowId', None)
                            if not msgWorkflowId or msgWorkflowId != workflowId:
                                continue

                            if message.documents:
                                for doc in message.documents:
                                    if doc.id == docId:
                                        docName = getattr(doc, 'fileName', 'unknown')
                                        allDocuments.append(doc)
                                        break
                elif docRef.startswith("docList:"):
                    # docList:<messageId>:<label> or docList:<label> - extract message ID and find document list
                    parts = docRef.split(':')
                    if len(parts) >= 3:
                        # Format: docList:<messageId>:<label>
                        messageId = parts[1]
                        label = parts[2]
                        # First try to find the message by ID in the current workflow
                        messageFound = None
                        for message in workflow.messages:
                            # Validate message belongs to this workflow
                            msgWorkflowId = getattr(message, 'workflowId', None)
                            if not msgWorkflowId or msgWorkflowId != workflowId:
                                continue

                            if str(message.id) == messageId:
                                messageFound = message
                                break

                        # If the message ID is not found in the current workflow, this is a stale reference.
                        # Log a warning and skip it (don't fall back to the label - it might match the wrong message).
                        if not messageFound:
                            availableIds = [str(msg.id) for msg in workflow.messages]
                            logger.warning(f"Document reference contains stale message ID {messageId} not found in current workflow {workflow.id}. Label: {label}. Available message IDs: {availableIds}")
                            logger.warning("This indicates the document reference was created in a different workflow state. Skipping this reference.")
                            # Skip - don't fall back to label matching, which could match the wrong message
                            continue

                        # If found, add documents
                        if messageFound and messageFound.documents:
                            allDocuments.extend(messageFound.documents)
                    elif len(parts) >= 2:
                        # Format: docList:<label> - find message by documentsLabel
                        label = parts[1]
                        messageFound = None
                        for message in workflow.messages:
                            # Validate message belongs to this workflow
                            msgWorkflowId = getattr(message, 'workflowId', None)
                            if not msgWorkflowId or msgWorkflowId != workflowId:
                                if msgWorkflowId:
                                    logger.warning(f"Message {message.id} has workflowId {msgWorkflowId} but belongs to workflow {workflowId}. Skipping.")
                                else:
                                    logger.warning(f"Message {message.id} has no workflowId. Skipping.")
                                continue

                            msgLabel = getattr(message, 'documentsLabel', None)
                            if msgLabel == label:
                                messageFound = message
                                break

                        # If found, add documents
                        if messageFound and messageFound.documents:
                            allDocuments.extend(messageFound.documents)
                else:
                    # Direct label reference - can be round1_task2_action3_contextinfo format or a simple label
                    # Search for messages with a matching documentsLabel to find the actual documents
                    matchingMessages = []
                    for message in workflow.messages:
                        # Validate message belongs to this workflow
                        msgWorkflowId = getattr(message, 'workflowId', None)
                        if not msgWorkflowId or msgWorkflowId != workflowId:
                            if msgWorkflowId:
                                logger.debug(f"Skipping message {message.id} with workflowId {msgWorkflowId} (expected {workflowId})")
                            else:
                                logger.debug(f"Skipping message {message.id} with no workflowId (expected {workflowId})")
                            continue

                        msgDocumentsLabel = getattr(message, 'documentsLabel', '')

                        # Check if this message's documentsLabel matches our reference
                        if msgDocumentsLabel == docRef:
                            # Found a matching message, collect it for comparison
                            matchingMessages.append(message)

                    # If we found matching messages, take the newest one (highest publishedAt)
                    if matchingMessages:
                        # Sort by publishedAt descending (newest first)
                        matchingMessages.sort(key=lambda msg: getattr(msg, 'publishedAt', 0), reverse=True)
                        newestMessage = matchingMessages[0]

                        if newestMessage.documents:
                            docNames = [doc.fileName for doc in newestMessage.documents if hasattr(doc, 'fileName')]
                            logger.debug(f"Added {len(newestMessage.documents)} documents from newest message {newestMessage.id}: {docNames}")
                            allDocuments.extend(newestMessage.documents)
                        else:
                            logger.debug(f"No documents found in newest message {newestMessage.id}")
                    else:
                        logger.error(f"No messages found with documentsLabel: {docRef}")
                        raise ValueError(f"Document reference not found: {docRef}")

            logger.debug(f"Resolved {len(allDocuments)} documents from document list: {documentList}")
            return allDocuments
        except Exception as e:
            logger.error(f"Error getting documents from document list: {str(e)}")
            return []

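    # The resolver above accepts three reference formats; a hedged illustration with made-up
    # IDs, labels, and filenames (none of these values exist in any real workflow):
    #
    #   "docItem:6f2a:report.pdf"              -> single document, looked up by document ID
    #   "docList:42:round1_task2_action3_src"  -> all documents of message 42, if that message
    #                                             still belongs to the current workflow
    #   "round1_task2_action3_src"             -> bare documentsLabel; the newest matching
    #                                             message (highest publishedAt) wins
    #
    #   e.g. chatService.getChatDocumentsFromDocumentList(["docItem:6f2a:report.pdf"])
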
    def _verifyWorkflowMessagesIntegrity(self, workflow, expectedWorkflowId: str) -> None:
        """
        Verify that all messages in workflow.messages belong to the expected workflow.
        This helps detect when workflow objects are being mixed up or when messages from
        other workflows are incorrectly included.
        """
        try:
            if not workflow or not hasattr(workflow, 'messages') or not workflow.messages:
                return

            messagesFromOtherWorkflows = []
            messagesWithoutWorkflowId = []
            totalMessages = len(workflow.messages)

            for message in workflow.messages:
                msgWorkflowId = getattr(message, 'workflowId', None)
                if not msgWorkflowId:
                    messagesWithoutWorkflowId.append({
                        'id': getattr(message, 'id', 'unknown'),
                        'label': getattr(message, 'documentsLabel', None)
                    })
                elif msgWorkflowId != expectedWorkflowId:
                    messagesFromOtherWorkflows.append({
                        'id': getattr(message, 'id', 'unknown'),
                        'label': getattr(message, 'documentsLabel', None),
                        'workflowId': msgWorkflowId,
                        'expectedWorkflowId': expectedWorkflowId
                    })

            if messagesFromOtherWorkflows:
                logger.error(
                    f"CRITICAL: Workflow integrity violation detected! "
                    f"Workflow {expectedWorkflowId} contains {len(messagesFromOtherWorkflows)} messages from other workflows. "
                    f"Total messages: {totalMessages}. "
                    f"Foreign messages: {messagesFromOtherWorkflows}"
                )

            if messagesWithoutWorkflowId:
                logger.warning(
                    f"Workflow integrity issue: Workflow {expectedWorkflowId} contains {len(messagesWithoutWorkflowId)} messages without workflowId. "
                    f"Messages: {messagesWithoutWorkflowId}"
                )

            # Also check if self.services.workflow has changed (workflow object ID mismatch)
            currentWorkflow = self.services.workflow
            if currentWorkflow and hasattr(currentWorkflow, 'id'):
                currentWorkflowId = currentWorkflow.id
                if currentWorkflowId != expectedWorkflowId:
                    logger.error(
                        f"CRITICAL: Workflow object changed during execution! "
                        f"Expected workflow {expectedWorkflowId}, but self.services.workflow now points to {currentWorkflowId}. "
                        f"This indicates the workflow object was swapped mid-execution."
                    )

        except Exception as e:
            logger.debug(f"Error during workflow integrity verification: {e}")

    def getConnectionReferenceFromUserConnection(self, connection: UserConnection) -> str:
        """Get connection reference from UserConnection with enhanced state information"""
        # Get token information to check if it's expired
        token = None
        token_status = "unknown"
        try:
            # Get a fresh token via TokenManager convenience method
            logger.debug(f"Getting fresh token for connection {connection.id}")
            token = TokenManager().getFreshToken(connection.id)
            if token:
                if hasattr(token, 'expiresAt') and token.expiresAt:
                    current_time = self.services.utils.timestampGetUtc()
                    if current_time > token.expiresAt:
                        token_status = "expired"
                    else:
                        # Check if this token was recently refreshed (within last 5 minutes)
                        time_since_creation = current_time - token.createdAt if hasattr(token, 'createdAt') else 0
                        if time_since_creation < 300:  # 5 minutes
                            token_status = "valid (refreshed)"
                        else:
                            token_status = "valid"
                else:
                    token_status = "no_expiration"
            else:
                token_status = "no_token"
        except Exception as e:
            token_status = f"error: {str(e)}"

        # Build enhanced reference with state information
        # Format: connection:msft:<username> (without UUID)
        base_ref = f"connection:{connection.authority.value}:{connection.externalUsername}"
        state_info = f" [status:{connection.status.value}, token:{token_status}]"

        logger.debug(f"getConnectionReferenceFromUserConnection: Built reference: {base_ref + state_info}")
        return base_ref + state_info

    def getUserConnectionByExternalUsername(self, authority: str, externalUsername: str) -> Optional[UserConnection]:
        """Fetch the user's connection by authority and external username."""
        try:
            if not authority or not externalUsername:
                return None
            user_connections = self.interfaceDbApp.getUserConnections(self.user.id)
            for connection in user_connections:
                # Normalize authority for comparison (enum vs string)
                connection_authority = connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority)
                if connection_authority.lower() == authority.lower() and connection.externalUsername == externalUsername:
                    return connection
            return None
        except Exception as e:
            logger.error(f"Error getting connection by external username: {str(e)}")
            return None

    def getUserConnectionFromConnectionReference(self, connectionReference: str) -> Optional[UserConnection]:
        """Get UserConnection from reference string (handles new format without UUID)"""
        try:
            # Parse reference format: connection:{authority}:{username} [status:..., token:...]
            # Remove state information if present
            base_reference = connectionReference.split(' [')[0]

            parts = base_reference.split(':')
            if len(parts) != 3 or parts[0] != "connection":
                return None

            authority = parts[1]
            username = parts[2]

            # Get user connections through AppObjects interface
            user_connections = self.interfaceDbApp.getUserConnections(self.user.id)

            # Find matching connection by authority and username (no UUID needed)
            for conn in user_connections:
                if conn.authority.value == authority and conn.externalUsername == username:
                    return conn
            return None

        except Exception as e:
            logger.error(f"Error parsing connection reference: {str(e)}")
            return None

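    # Reference format built and parsed by the two methods above, shown with made-up values
    # (a hedged sketch, not data from any real connection):
    #
    #   built:  "connection:msft:jane.doe@contoso.com [status:active, token:valid]"
    #   parsed: reference.split(' [')[0]  -> "connection:msft:jane.doe@contoso.com"
    #           .split(':')               -> ["connection", "msft", "jane.doe@contoso.com"]
    #
    # The state suffix is informational only; matching is done on authority + externalUsername.
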
    def getFreshConnectionToken(self, connectionId: str):
        """Get a fresh token for a specific connection (moved from UtilsService).

        Args:
            connectionId: ID of the connection to get token for

        Returns:
            Token object or None if not found/expired
        """
        try:
            return TokenManager().getFreshToken(connectionId)
        except Exception as e:
            logger.error(f"Error getting fresh token for connection {connectionId}: {str(e)}")
            return None

    def getFileInfo(self, fileId: str) -> Dict[str, Any]:
        """Get file information"""
        file_item = self.interfaceDbComponent.getFile(fileId)
        if file_item:
            return {
                "id": file_item.id,
                "fileName": file_item.fileName,
                "size": file_item.fileSize,
                "mimeType": file_item.mimeType,
                "fileHash": file_item.fileHash,
                "creationDate": file_item.creationDate
            }
        return None

    def getFileData(self, fileId: str) -> bytes:
        """Get file data by ID"""
        return self.interfaceDbComponent.getFileData(fileId)

    def _diagnoseDocumentAccess(self, document: ChatDocument) -> Dict[str, Any]:
        """
        Diagnose document access issues and provide recovery information.
        This method helps identify why document properties are inaccessible.
        """
        try:
            diagnosis = {
                'document_id': document.id,
                'file_id': document.fileId,
                'has_component_interface': document._componentInterface is not None,
                'component_interface_type': type(document._componentInterface).__name__ if document._componentInterface else None,
                'file_exists': False,
                'file_info': None,
                'error_details': None
            }

            # Check if component interface is set
            if not document._componentInterface:
                diagnosis['error_details'] = "Component interface not set - document cannot access file system"
                return diagnosis

            # Try to access the file directly
            try:
                file_info = self.interfaceDbComponent.getFile(document.fileId)
                if file_info:
                    diagnosis['file_exists'] = True
                    diagnosis['file_info'] = {
                        'fileName': file_info.fileName if hasattr(file_info, 'fileName') else 'N/A',
                        'fileSize': file_info.fileSize if hasattr(file_info, 'fileSize') else 'N/A',
                        'mimeType': file_info.mimeType if hasattr(file_info, 'mimeType') else 'N/A'
                    }
                else:
                    diagnosis['error_details'] = f"File with ID {document.fileId} not found in component interface"
            except Exception as e:
                diagnosis['error_details'] = f"Error accessing file {document.fileId}: {str(e)}"

            return diagnosis

        except Exception as e:
            return {
                'document_id': document.id if hasattr(document, 'id') else 'unknown',
                'file_id': document.fileId if hasattr(document, 'fileId') else 'unknown',
                'error_details': f"Error during diagnosis: {str(e)}"
            }

    def _recoverDocumentAccess(self, document: ChatDocument) -> bool:
        """
        Attempt to recover document access by re-setting the component interface.
        Returns True if recovery was successful.
        """
        try:
            logger.info(f"Attempting to recover document access for document {document.id}")

            # Re-set the component interface
            document.setComponentInterface(self.interfaceDbComponent)

            # Test if we can now access the fileName
            try:
                test_fileName = document.fileName
                logger.info(f"Document access recovered for {document.id} -> {test_fileName}")
                return True
            except Exception as e:
                logger.error(f"Document access recovery failed for {document.id}: {str(e)}")
                return False

        except Exception as e:
            logger.error(f"Error during document access recovery for {document.id}: {str(e)}")
            return False

    def calculateObjectSize(self, obj: Any) -> int:
        """
        Calculate the size of an object in bytes.

        Args:
            obj: Object to calculate size for

        Returns:
            int: Size in bytes
        """
        try:
            import json
            import sys

            if obj is None:
                return 0

            # Convert object to JSON string and calculate size
            json_str = json.dumps(obj, ensure_ascii=False, default=str)
            return len(json_str.encode('utf-8'))

        except Exception as e:
            logger.error(f"Error calculating object size: {str(e)}")
            return 0

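    # A minimal, standalone check of the sizing rule above (UTF-8 byte length of the JSON dump);
    # the sample payload is invented for illustration:
    #
    #   import json
    #   payload = {"title": "Bericht Q3", "pages": 12}
    #   len(json.dumps(payload, ensure_ascii=False, default=str).encode("utf-8"))  # -> 36
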
    def getWorkflowContext(self) -> Dict[str, int]:
        """Get current workflow context for document generation"""
        try:
            workflow = self.services.workflow
            if not workflow:
                return {'currentRound': 0, 'currentTask': 0, 'currentAction': 0}
            return {
                'currentRound': workflow.currentRound if hasattr(workflow, 'currentRound') else 0,
                'currentTask': workflow.currentTask if hasattr(workflow, 'currentTask') else 0,
                'currentAction': workflow.currentAction if hasattr(workflow, 'currentAction') else 0
            }
        except Exception as e:
            logger.error(f"Error getting workflow context: {str(e)}")
            return {'currentRound': 0, 'currentTask': 0, 'currentAction': 0}

    def setWorkflowContext(self, roundNumber: int = None, taskNumber: int = None, actionNumber: int = None):
        """Set current workflow context for document generation and routing"""
        try:
            workflow = self.services.workflow
            if not workflow:
                logger.error("setWorkflowContext: No workflow available")
                return

            # Prepare update data
            update_data = {}

            if roundNumber is not None:
                workflow.currentRound = roundNumber
                update_data["currentRound"] = roundNumber
            if taskNumber is not None:
                workflow.currentTask = taskNumber
                update_data["currentTask"] = taskNumber
            if actionNumber is not None:
                workflow.currentAction = actionNumber
                update_data["currentAction"] = actionNumber

            # Persist changes to database if any updates were made
            if update_data:
                self.interfaceDbChat.updateWorkflow(workflow.id, update_data)

            logger.debug(f"Updated workflow context: Round {workflow.currentRound if hasattr(workflow, 'currentRound') else 'N/A'}, Task {workflow.currentTask if hasattr(workflow, 'currentTask') else 'N/A'}, Action {workflow.currentAction if hasattr(workflow, 'currentAction') else 'N/A'}")
        except Exception as e:
            logger.error(f"Error setting workflow context: {str(e)}")

    def getWorkflowStats(self) -> Dict[str, Any]:
        """Get comprehensive workflow statistics including current context"""
        try:
            workflow = self.services.workflow
            workflow_context = self.getWorkflowContext()
            if not workflow:
                return {
                    'currentRound': workflow_context['currentRound'],
                    'currentTask': workflow_context['currentTask'],
                    'currentAction': workflow_context['currentAction'],
                    'totalTasks': 0,
                    'totalActions': 0,
                    'workflowStatus': 'unknown',
                    'workflowId': 'unknown'
                }
            return {
                'currentRound': workflow_context['currentRound'],
                'currentTask': workflow_context['currentTask'],
                'currentAction': workflow_context['currentAction'],
                'totalTasks': workflow.totalTasks if hasattr(workflow, 'totalTasks') else 0,
                'totalActions': workflow.totalActions if hasattr(workflow, 'totalActions') else 0,
                'workflowStatus': workflow.status if hasattr(workflow, 'status') else 'unknown',
                'workflowId': workflow.id if hasattr(workflow, 'id') else 'unknown'
            }
        except Exception as e:
            logger.error(f"Error getting workflow stats: {str(e)}")
            return {
                'currentRound': 0,
                'currentTask': 0,
                'currentAction': 0,
                'totalTasks': 0,
                'totalActions': 0,
                'workflowStatus': 'unknown',
                'workflowId': 'unknown'
            }

    def createWorkflow(self, workflowData: Dict[str, Any]):
        """Create a new workflow by delegating to the chat interface"""
        try:
            return self.interfaceDbChat.createWorkflow(workflowData)
        except Exception as e:
            logger.error(f"Error creating workflow: {str(e)}")
            raise

    def updateWorkflow(self, workflowId: str, updateData: Dict[str, Any]):
        """Update workflow by delegating to the chat interface"""
        try:
            return self.interfaceDbChat.updateWorkflow(workflowId, updateData)
        except Exception as e:
            logger.error(f"Error updating workflow: {str(e)}")
            raise

    def getWorkflow(self, workflowId: str):
        """Get workflow by ID by delegating to the chat interface"""
        try:
            logger.debug(f"getWorkflow called with workflowId: {workflowId}")
            result = self.interfaceDbChat.getWorkflow(workflowId)
            if result:
                logger.debug(f"getWorkflow returned workflow with ID: {result.id}")
            else:
                logger.warning(f"getWorkflow returned None for workflowId: {workflowId}")
            return result
        except Exception as e:
            logger.error(f"Error getting workflow: {str(e)}")
            raise

    # === Service-level transactions (DB write-through + in-memory sync) ===

    def storeMessageWithDocuments(self, workflow: Any, messageData: Dict[str, Any], documents: List[Any]) -> ChatMessage:
        """Persist message and documents, then bind them into the in-memory workflow (replace-by-id)."""
        # Ensure workflowId on message
        messageData = dict(messageData or {})
        messageData["workflowId"] = workflow.id
        # Attach documents to message creation via interface (it persists message then docs)
        messageDataWithDocs = dict(messageData)
        messageDataWithDocs["documents"] = documents or []
        chatInterface = self.interfaceDbChat
        chatMessage = chatInterface.createMessage(messageDataWithDocs)
        if not chatMessage:
            raise ValueError("Failed to create message with documents")
        # In-memory sync: replace-by-id if it exists, otherwise append
        replaced = False
        for i, m in enumerate(workflow.messages or []):
            if getattr(m, 'id', None) == getattr(chatMessage, 'id', None):
                workflow.messages[i] = chatMessage
                replaced = True
                break
        if not replaced:
            workflow.messages.append(chatMessage)
        return chatMessage

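    # Hedged usage sketch for the write-through helper above; the messageData keys and the
    # workflow variable are illustrative only, not a documented call:
    #
    #   msg = chatService.storeMessageWithDocuments(
    #       workflow,
    #       {"documentsLabel": "round1_task1_action1_upload", "status": "first"},
    #       documents=[],
    #   )
    #   # msg is persisted and either replaces an entry with the same id in workflow.messages
    #   # or is appended to it.
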
    def storeLog(self, workflow: Any, logData: Dict[str, Any]) -> ChatLog:
        """Persist ChatLog and map it into the in-memory workflow logs list."""
        logData = dict(logData or {})
        logData["workflowId"] = workflow.id
        chatInterface = self.interfaceDbChat
        chatLog = chatInterface.createLog(logData)
        if not chatLog:
            raise ValueError("Failed to create log")
        # Replace-by-id if it exists, otherwise append
        replaced = False
        for i, lg in enumerate(workflow.logs):
            if getattr(lg, 'id', None) == getattr(chatLog, 'id', None):
                workflow.logs[i] = chatLog
                replaced = True
                break
        if not replaced:
            workflow.logs.append(chatLog)
        return chatLog

    def storeWorkflowStat(self, workflow: Any, aiResponse: Any, process: str) -> ChatStat:
        """Persist workflow-level ChatStat from AiCallResponse and append to workflow stats list."""
        try:
            # Create ChatStat from AiCallResponse data
            statData = {
                "workflowId": workflow.id,
                "process": process,
                "engine": aiResponse.modelName,
                "priceUsd": aiResponse.priceUsd,
                "processingTime": aiResponse.processingTime,
                "bytesSent": aiResponse.bytesSent,
                "bytesReceived": aiResponse.bytesReceived,
                "errorCount": aiResponse.errorCount
            }

            # Create the stat record in the database
            stat = self.interfaceDbChat.createStat(statData)

            # Append to workflow stats list in memory
            if not hasattr(workflow, 'stats') or workflow.stats is None:
                workflow.stats = []
            workflow.stats.append(stat)

            return stat
        except Exception as e:
            logger.error(f"Failed to store workflow stat: {e}")
            raise

    def updateMessage(self, messageId: str, messageData: Dict[str, Any]):
        """Update message by delegating to the chat interface"""
        try:
            return self.interfaceDbChat.updateMessage(messageId, messageData)
        except Exception as e:
            logger.error(f"Error updating message: {str(e)}")
            raise

    def getDocumentCount(self) -> str:
        """Get document count for task planning (matching old handlingTasks.py logic)"""
        try:
            workflow = self.services.workflow
            if not workflow:
                return "No documents available"

            # Count documents from all messages in the workflow (like the old system)
            total_docs = 0
            for message in workflow.messages:
                if hasattr(message, 'documents') and message.documents:
                    total_docs += len(message.documents)

            if total_docs == 0:
                return "No documents available"

            return f"{total_docs} document(s) available"
        except Exception as e:
            logger.error(f"Error getting document count: {str(e)}")
            return "No documents available"

    def getWorkflowHistoryContext(self) -> str:
        """Get workflow history context for task planning (matching old handlingTasks.py logic)"""
        try:
            workflow = self.services.workflow
            if not workflow:
                return "No previous round context available"

            # Check if there are any previous rounds by looking for "first" messages
            has_previous_rounds = False
            for message in workflow.messages:
                if hasattr(message, 'status') and message.status == "first":
                    has_previous_rounds = True
                    break

            if not has_previous_rounds:
                return "No previous round context available"

            # Get document reference list to show what documents are available from previous rounds
            document_list = self._getDocumentReferenceList(workflow)

            # Build context string showing previous rounds
            context = "Previous workflow rounds contain documents:\n"

            # Show history exchanges (previous rounds)
            if document_list["history"]:
                for exchange in document_list["history"]:
                    # Use label-only format to avoid stale message ID references
                    # Labels are stable identifiers that persist across workflow state changes
                    doc_list_ref = f"docList:{exchange['documentsLabel']}"

                    context += f"- {doc_list_ref} ({len(exchange['documents'])} documents)\n"
            else:
                context = "No previous round context available"

            return context

        except Exception as e:
            logger.error(f"Error getting workflow history context: {str(e)}")
            return "No previous round context available"

    def getAvailableDocuments(self, workflow) -> str:
        """Get available documents formatted for AI prompts (exact copy of old ServiceCenter.getEnhancedDocumentContext)"""
        try:
            if not workflow or not hasattr(workflow, 'messages'):
                return "No documents available"

            workflowId = workflow.id if hasattr(workflow, 'id') else 'NO_ID'
            workflowObjId = id(workflow)
            logger.debug(f"getAvailableDocuments: workflow.id = {workflowId}, workflow object id = {workflowObjId}")

            # Root cause analysis: verify workflow.messages integrity and detect workflow changes
            self._verifyWorkflowMessagesIntegrity(workflow, workflowId)

            # Use the provided workflow object directly to avoid database reload issues
            # that can cause filename truncation. The workflow object should already be up-to-date.

            # Get document reference list using the exact same logic as the old system
            document_list = self._getDocumentReferenceList(workflow)

            # Timestamp-only available documents index dump removed

            # Build index string for AI action planning
            context = ""

            # Process current round exchanges first
            if document_list["chat"]:
                context += "\nCurrent round documents:\n"
                for exchange in document_list["chat"]:
                    # Use label-only format to avoid stale message ID references
                    # Labels are stable identifiers that persist across workflow state changes
                    doc_list_ref = f"docList:{exchange['documentsLabel']}"

                    context += f"- {doc_list_ref} contains:\n"
                    # Generate docItem references for each document in the list
                    for doc_ref in exchange['documents']:
                        if doc_ref.startswith("docItem:"):
                            context += f"  - {doc_ref}\n"
                        else:
                            # Convert to proper docItem format if needed
                            context += f"  - docItem:{doc_ref}\n"
                    context += "\n"

            # Process previous rounds after
            if document_list["history"]:
                context += "\nPast rounds documents:\n"
                for exchange in document_list["history"]:
                    # Use label-only format to avoid stale message ID references
                    # Labels are stable identifiers that persist across workflow state changes
                    doc_list_ref = f"docList:{exchange['documentsLabel']}"

                    context += f"- {doc_list_ref} contains:\n"
                    # Generate docItem references for each document in the list
                    for doc_ref in exchange['documents']:
                        if doc_ref.startswith("docItem:"):
                            context += f"  - {doc_ref}\n"
                        else:
                            # Convert to proper docItem format if needed
                            context += f"  - docItem:{doc_ref}\n"
                    context += "\n"

            if not document_list["chat"] and not document_list["history"]:
                context += "\nNO DOCUMENTS AVAILABLE - This workflow has no documents to process.\n"

            return context

        except Exception as e:
            logger.error(f"Error getting available documents: {str(e)}")
            return "NO DOCUMENTS AVAILABLE - Error generating document context."

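    # The prompt context built above looks roughly like this; labels, IDs, and filenames are
    # invented for illustration:
    #
    #   Current round documents:
    #   - docList:round2_task1_action1_upload contains:
    #     - docItem:6f2a:report.pdf
    #
    #   Past rounds documents:
    #   - docList:round1_task1_action1_upload contains:
    #     - docItem:91bc:minutes.docx
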
    def _getDocumentReferenceList(self, workflow) -> Dict[str, List]:
        """Get list of document exchanges with the new labeling format, sorted by recency (exact copy of old system)"""
        # Collect all documents first and refresh their attributes
        all_documents = []
        for message in workflow.messages:
            if message.documents:
                all_documents.extend(message.documents)

        # Refresh file attributes for all documents
        if all_documents:
            self._refreshDocumentFileAttributes(all_documents)

        def _is_valid_document(doc) -> bool:
            try:
                size_ok = getattr(doc, 'fileSize', 0) and getattr(doc, 'fileSize', 0) > 0
                id_ok = bool(getattr(doc, 'fileId', None))
                mime_ok = bool(getattr(doc, 'mimeType', None))
                return size_ok and id_ok and mime_ok
            except Exception:
                return False

        # Simplified, deterministic logic:
        # - Walk messages newest-first
        # - For each document, assign it exactly once to a bucket based on the message round
        # - Never allow the same doc to appear in both buckets
        chat_exchanges = []
        history_exchanges = []
        seen_doc_ids = set()
        current_round = getattr(workflow, 'currentRound', None)

        for message in reversed(workflow.messages):
            if not getattr(message, 'documents', None):
                continue

            label = getattr(message, 'documentsLabel', None)
            if not label:
                # Skip messages without a label to keep references consistent
                continue

            doc_refs = []
            for doc in message.documents:
                if not _is_valid_document(doc):
                    continue
                # Avoid duplicates across chat/history
                doc_id = getattr(doc, 'id', None)
                if not doc_id or doc_id in seen_doc_ids:
                    continue
                seen_doc_ids.add(doc_id)
                doc_ref = self.getDocumentReferenceFromChatDocument(doc)
                doc_refs.append(doc_ref)

            if not doc_refs:
                continue

            entry = {
                'documentsLabel': label,
                'documents': doc_refs
            }

            msg_round = getattr(message, 'roundNumber', None)
            if current_round is not None and msg_round == current_round:
                chat_exchanges.append(entry)
            else:
                history_exchanges.append(entry)

        return {
            "chat": chat_exchanges,
            "history": history_exchanges
        }

    def _refreshDocumentFileAttributes(self, documents) -> None:
        """Update file attributes (fileName, fileSize, mimeType) for documents"""
        for doc in documents:
            try:
                original_filename = doc.fileName
                # Skip invalid docs early if essential identifiers are missing
                if not getattr(doc, 'fileId', None):
                    logger.debug(f"Skipping document {doc.id} due to missing fileId")
                    setattr(doc, 'fileSize', 0)
                    setattr(doc, 'mimeType', None)
                    continue

                file_info = self.getFileInfo(doc.fileId)
                if file_info:
                    db_filename = file_info.get("fileName", doc.fileName)
                    doc.fileName = file_info.get("fileName", doc.fileName)
                    doc.fileSize = file_info.get("size", doc.fileSize)
                    doc.mimeType = file_info.get("mimeType", doc.mimeType)

                    # Mark invalid if missing mimeType
                    if not doc.mimeType:
                        logger.debug(f"Document {doc.id} has missing mimeType; will be filtered from index")
                        setattr(doc, 'fileSize', 0)

                else:
                    logger.warning(f"File not found for document {doc.id}, fileId: {doc.fileId}")
                    setattr(doc, 'fileSize', 0)
                    setattr(doc, 'mimeType', None)
            except Exception as e:
                logger.error(f"Error refreshing file attributes for document {doc.id}: {e}")

    def _generateWorkflowContextPrefix(self, message) -> str:
        """Generate workflow context prefix: round{num}_task{num}_action{num}"""
        round_num = message.roundNumber if hasattr(message, 'roundNumber') else 1
        task_num = message.taskNumber if hasattr(message, 'taskNumber') else 0
        action_num = message.actionNumber if hasattr(message, 'actionNumber') else 0
        return f"round{round_num}_task{task_num}_action{action_num}"

    def getDocumentReferenceFromChatDocument(self, document) -> str:
        """Get document reference using document ID and filename."""
        try:
            # Use document ID and filename for simple reference
            return f"docItem:{document.id}:{document.fileName}"
        except Exception as e:
            logger.error(f"Critical error creating document reference for document {document.id}: {str(e)}")
            # Re-raise the error to prevent the workflow from continuing with invalid data
            raise

    def _getMessageSequenceForExchange(self, exchange, workflow) -> int:
        """Get message sequence number for sorting exchanges by recency"""
        try:
            # Extract message ID from the first document reference
            if exchange['documents'] and len(exchange['documents']) > 0:
                first_doc_ref = exchange['documents'][0]
                if first_doc_ref.startswith("docItem:"):
                    # docItem:<id>:<label> - extract ID
                    parts = first_doc_ref.split(':')
                    if len(parts) >= 2:
                        doc_id = parts[1]
                        # Find the message containing this document
                        for message in workflow.messages:
                            if message.documents:
                                for doc in message.documents:
                                    if doc.id == doc_id:
                                        return message.sequenceNr if hasattr(message, 'sequenceNr') else 0
                elif first_doc_ref.startswith("docList:"):
                    # docList:<message_id>:<label> - extract message ID
                    parts = first_doc_ref.split(':')
                    if len(parts) >= 2:
                        message_id = parts[1]
                        # Find the message by ID
                        for message in workflow.messages:
                            if str(message.id) == message_id:
                                return message.sequenceNr if hasattr(message, 'sequenceNr') else 0
            return 0
        except Exception as e:
            logger.error(f"Error getting message sequence for exchange: {str(e)}")
            return 0

    def _validateDocumentLabelConsistency(self, message) -> str:
        """Validate that the document label used for references matches the message's actual label"""
        if not hasattr(message, 'documentsLabel') or not message.documentsLabel:
            return None

        # Simply return the message's actual documentsLabel - no correction, just validation
        return message.documentsLabel

    def getConnectionReferenceList(self) -> List[str]:
        """Get connection reference list (matching old handlingTasks.py logic)"""
        try:
            # Get connections from the database using the same logic as the old system
            if hasattr(self.services, 'interfaceDbApp') and hasattr(self.services, 'user'):
                userId = self.services.user.id
                connections = self.services.interfaceDbApp.getUserConnections(userId)
                if connections:
                    # Format connections as reference strings using the same pattern as the old system
                    connectionRefs = []
                    for conn in connections:
                        # Create reference string in format: connection:{authority}:{username} [status:..., token:...]
                        # This matches the format expected by getUserConnectionFromConnectionReference()
                        ref = self.getConnectionReferenceFromUserConnection(conn)
                        connectionRefs.append(ref)
                    return connectionRefs

            return []
        except Exception as e:
            logger.error(f"Error getting connection reference list: {str(e)}")
            return []

    def _getProgressLogger(self):
        """Get or create the progress logger instance"""
        if self._progressLogger is None:
            self._progressLogger = ProgressLogger(self.services)
        return self._progressLogger

    def createProgressLogger(self) -> ProgressLogger:
        return ProgressLogger(self.services)

    def progressLogStart(self, operationId: str, serviceName: str, actionName: str, context: str = ""):
        """Wrapper for ProgressLogger.startOperation"""
        progressLogger = self._getProgressLogger()
        return progressLogger.startOperation(operationId, serviceName, actionName, context)

    def progressLogUpdate(self, operationId: str, progress: float, statusUpdate: str = ""):
        """Wrapper for ProgressLogger.updateOperation"""
        progressLogger = self._getProgressLogger()
        return progressLogger.updateOperation(operationId, progress, statusUpdate)

    def progressLogFinish(self, operationId: str, success: bool = True):
        """Wrapper for ProgressLogger.finishOperation"""
        progressLogger = self._getProgressLogger()
        return progressLogger.finishOperation(operationId, success)

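    # Hedged sketch of the intended lifecycle for the wrappers above, with a made-up operation ID:
    #
    #   chatService.progressLogStart("op-123", "serviceAi", "callAiDocuments", "render pdf")
    #   chatService.progressLogUpdate("op-123", 0.5, "Processing text call")
    #   chatService.progressLogFinish("op-123", success=True)
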
@ -1,554 +0,0 @@
"""
|
||||
Delta Group JIRA-SharePoint Sync Manager
|
||||
|
||||
This module handles the synchronization of JIRA tickets to SharePoint using the new
|
||||
Graph API-based connector architecture.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime, UTC
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.connectors.connectorSharepoint import ConnectorSharepoint
|
||||
from modules.connectors.connectorTicketJira import ConnectorTicketJira
|
||||
from modules.interfaces.interfaceAppObjects import getRootInterface
|
||||
from modules.interfaces.interfaceAppModel import UserInDB
|
||||
from modules.interfaces.interfaceTicketObjects import TicketSharepointSyncInterface
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Get environment type from configuration
|
||||
APP_ENV_TYPE = APP_CONFIG.get("APP_ENV_TYPE", "dev")
|
||||
|
||||
|
||||
def convert_adf_to_text(adf_data):
|
||||
"""Convert Atlassian Document Format (ADF) to plain text.
|
||||
|
||||
Based on Atlassian Document Format specification for JIRA fields.
|
||||
Handles paragraphs, lists, text formatting, and other ADF node types.
|
||||
|
||||
Args:
|
||||
adf_data: ADF object or None
|
||||
|
||||
Returns:
|
||||
str: Plain text content, or empty string if None/invalid
|
||||
"""
|
||||
if not adf_data or not isinstance(adf_data, dict):
|
||||
return ""
|
||||
|
||||
if adf_data.get("type") != "doc":
|
||||
return str(adf_data) if adf_data else ""
|
||||
|
||||
content = adf_data.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
return ""
|
||||
|
||||
def extract_text_from_content(content_list, list_level=0):
|
||||
"""Recursively extract text from ADF content with proper formatting."""
|
||||
text_parts = []
|
||||
list_counter = 1
|
||||
|
||||
for item in content_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
# Extract text content, preserving formatting
|
||||
text = item.get("text", "")
|
||||
marks = item.get("marks", [])
|
||||
|
||||
# Handle text formatting (bold, italic, etc.)
|
||||
if marks:
|
||||
for mark in marks:
|
||||
if mark.get("type") == "strong":
|
||||
text = f"**{text}**"
|
||||
elif mark.get("type") == "em":
|
||||
text = f"*{text}*"
|
||||
elif mark.get("type") == "code":
|
||||
text = f"`{text}`"
|
||||
elif mark.get("type") == "link":
|
||||
attrs = mark.get("attrs", {})
|
||||
href = attrs.get("href", "")
|
||||
if href:
|
||||
text = f"[{text}]({href})"
|
||||
|
||||
text_parts.append(text)
|
||||
|
||||
elif item_type == "hardBreak":
|
||||
text_parts.append("\n")
|
||||
|
||||
elif item_type == "paragraph":
|
||||
paragraph_content = item.get("content", [])
|
||||
if paragraph_content:
|
||||
paragraph_text = extract_text_from_content(paragraph_content, list_level)
|
||||
if paragraph_text.strip():
|
||||
text_parts.append(paragraph_text)
|
||||
|
||||
elif item_type == "bulletList":
|
||||
list_content = item.get("content", [])
|
||||
for list_item in list_content:
|
||||
if list_item.get("type") == "listItem":
|
||||
list_item_content = list_item.get("content", [])
|
||||
for list_paragraph in list_item_content:
|
||||
if list_paragraph.get("type") == "paragraph":
|
||||
list_paragraph_content = list_paragraph.get("content", [])
|
||||
if list_paragraph_content:
|
||||
indent = " " * list_level
|
||||
bullet_text = extract_text_from_content(list_paragraph_content, list_level + 1)
|
||||
if bullet_text.strip():
|
||||
text_parts.append(f"{indent}• {bullet_text}")
|
||||
|
||||
elif item_type == "orderedList":
|
||||
list_content = item.get("content", [])
|
||||
for list_item in list_content:
|
||||
if list_item.get("type") == "listItem":
|
||||
list_item_content = list_item.get("content", [])
|
||||
for list_paragraph in list_item_content:
|
||||
if list_paragraph.get("type") == "paragraph":
|
||||
list_paragraph_content = list_paragraph.get("content", [])
|
||||
if list_paragraph_content:
|
||||
indent = " " * list_level
|
||||
ordered_text = extract_text_from_content(list_paragraph_content, list_level + 1)
|
||||
if ordered_text.strip():
|
||||
text_parts.append(f"{indent}{list_counter}. {ordered_text}")
|
||||
list_counter += 1
|
||||
|
||||
elif item_type == "listItem":
|
||||
# Handle nested list items
|
||||
list_item_content = item.get("content", [])
|
||||
if list_item_content:
|
||||
text_parts.append(extract_text_from_content(list_item_content, list_level))
|
||||
|
||||
elif item_type == "embedCard":
|
||||
# Handle embedded content (videos, etc.)
|
||||
attrs = item.get("attrs", {})
|
||||
url = attrs.get("url", "")
|
||||
if url:
|
||||
text_parts.append(f"[Embedded Content: {url}]")
|
||||
|
||||
elif item_type == "codeBlock":
|
||||
# Handle code blocks
|
||||
code_content = item.get("content", [])
|
||||
if code_content:
|
||||
code_text = extract_text_from_content(code_content, list_level)
|
||||
if code_text.strip():
|
||||
text_parts.append(f"```\n{code_text}\n```")
|
||||
|
||||
elif item_type == "blockquote":
|
||||
# Handle blockquotes
|
||||
quote_content = item.get("content", [])
|
||||
if quote_content:
|
||||
quote_text = extract_text_from_content(quote_content, list_level)
|
||||
if quote_text.strip():
|
||||
text_parts.append(f"> {quote_text}")
|
||||
|
||||
elif item_type == "heading":
|
||||
# Handle headings
|
||||
heading_content = item.get("content", [])
|
||||
if heading_content:
|
||||
heading_text = extract_text_from_content(heading_content, list_level)
|
||||
if heading_text.strip():
|
||||
level = item.get("attrs", {}).get("level", 1)
|
||||
text_parts.append(f"{'#' * level} {heading_text}")
|
||||
|
||||
elif item_type == "rule":
|
||||
# Handle horizontal rules
|
||||
text_parts.append("---")
|
||||
|
||||
else:
|
||||
# Handle unknown types by trying to extract content
|
||||
if "content" in item:
|
||||
content_text = extract_text_from_content(item.get("content", []), list_level)
|
||||
if content_text.strip():
|
||||
text_parts.append(content_text)
|
||||
|
||||
return "\n".join(text_parts)
|
||||
|
||||
result = extract_text_from_content(content)
|
||||
return result.strip()
|
||||
|
||||
|
||||
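# A small, hedged example of the ADF conversion above; the input document is invented:
#
#   sample_adf = {
#       "type": "doc",
#       "content": [
#           {"type": "paragraph", "content": [
#               {"type": "text", "text": "Blocked", "marks": [{"type": "strong"}]},
#           ]},
#           {"type": "bulletList", "content": [
#               {"type": "listItem", "content": [
#                   {"type": "paragraph", "content": [{"type": "text", "text": "waiting on access"}]},
#               ]},
#           ]},
#       ],
#   }
#   convert_adf_to_text(sample_adf)  # -> "**Blocked**\n• waiting on access"
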
class ManagerSyncDelta:
    """Manages JIRA to SharePoint synchronization for Delta Group.

    Supports two sync modes:
    - CSV mode: Uses CSV files for synchronization (default)
    - Excel mode: Uses Excel (.xlsx) files for synchronization

    To change the sync mode, use the set_sync_mode() method or modify the SYNC_MODE class variable.
    """
    SHAREPOINT_SITE_ID = "02830618-4029-4dc8-8d3d-f5168f282249"
    SHAREPOINT_SITE_NAME = "SteeringBPM"
    SHAREPOINT_SITE_PATH = "SteeringBPM"
    SHAREPOINT_HOSTNAME = "deltasecurityag.sharepoint.com"
    SHAREPOINT_MAIN_FOLDER = "/General/50 Docs hosted by SELISE"
    SHAREPOINT_BACKUP_FOLDER = "/General/50 Docs hosted by SELISE/SyncHistory"
    SHAREPOINT_AUDIT_FOLDER = "/General/50 Docs hosted by SELISE/SyncHistory"
    SHAREPOINT_USER_ID = "patrick.motsch@delta.ch"

    # Sync mode: "csv" or "xlsx"
    SYNC_MODE = "xlsx"  # Can be "csv" or "xlsx"

    # File names for different sync modes
    SYNC_FILE_CSV = "DELTAgroup x SELISE Ticket Exchange List.csv"
    SYNC_FILE_XLSX = "DELTAgroup x SELISE Ticket Exchange List.xlsx"

    # JIRA connection parameters (hardcoded for Delta Group)
    JIRA_USERNAME = "p.motsch@valueon.ch"
    JIRA_API_TOKEN = "ATATT3xFfGF0d973nNb3R1wTDI4lesmJfJAmooS-4cYMJTyLfwYv4himrE6yyCxyX3aSMfl34NHcm2fAXeFXrLHUzJx0RQVUBonCFnlgexjLQTgS5BoCbSO7dwAVjlcHZZkArHbooCUaRwJ15n6AHkm-nwdjLQ3Z74TFnKKUZC4uhuh3Aj-MuX8=2D7124FA"
    JIRA_URL = "https://deltasecurity.atlassian.net"
    JIRA_PROJECT_CODE = "DCS"
    JIRA_ISSUE_TYPE = "Task"

    # Task sync definition for field mapping (like the original synchronizer)
    TASK_SYNC_DEFINITION = {
        # key = excel header; value = [direction (get: jira -> excel, put: excel -> jira), jira field path]
        'ID': ['get', ['key']],
        'Module Category': ['get', ['fields', 'customfield_10058', 'value']],
        'Summary': ['get', ['fields', 'summary']],
        'Description': ['get', ['fields', 'description']],  # ADF format - needs conversion to text
        'References': ['get', ['fields', 'customfield_10066']],  # Field exists, may be None
        'Priority': ['get', ['fields', 'priority', 'name']],
        'Issue Status': ['get', ['fields', 'status', 'name']],
        'Assignee': ['get', ['fields', 'assignee', 'displayName']],
        'Issue Created': ['get', ['fields', 'created']],
        'Due Date': ['get', ['fields', 'duedate']],  # Field exists, may be None
        'DELTA Comments': ['get', ['fields', 'customfield_10167']],  # Field exists, may be None
        'SELISE Ticket References': ['put', ['fields', 'customfield_10067']],
        'SELISE Status Values': ['put', ['fields', 'customfield_10065']],
        'SELISE Comments': ['put', ['fields', 'customfield_10168']],
    }

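    # Hedged illustration of how a 'get' mapping path is meant to be walked over a JIRA issue
    # payload by the sync interface (the issue dict below is invented; the actual traversal
    # lives in TicketSharepointSyncInterface, not here):
    #
    #   issue = {"key": "DCS-101", "fields": {"priority": {"name": "High"}}}
    #   path = TASK_SYNC_DEFINITION['Priority'][1]   # ['fields', 'priority', 'name']
    #   value = issue
    #   for step in path:
    #       value = value.get(step) if isinstance(value, dict) else None
    #   # value -> "High"; 'put' entries map the opposite direction, Excel/CSV -> JIRA.
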
def __init__(self):
|
||||
"""Initialize the sync manager with hardcoded Delta Group credentials."""
|
||||
self.root_interface = getRootInterface()
|
||||
self.jira_connector = None
|
||||
self.sharepoint_connector = None
|
||||
self.target_site = None
|
||||
|
||||
def get_sync_file_name(self) -> str:
|
||||
"""Get the appropriate sync file name based on the sync mode."""
|
||||
if self.SYNC_MODE == "xlsx":
|
||||
return self.SYNC_FILE_XLSX
|
||||
else: # Default to CSV
|
||||
return self.SYNC_FILE_CSV
|
||||
|
||||
def set_sync_mode(self, mode: str) -> bool:
|
||||
"""Set the sync mode to either 'csv' or 'xlsx'.
|
||||
|
||||
Args:
|
||||
mode: Either 'csv' or 'xlsx'
|
||||
|
||||
Returns:
|
||||
bool: True if mode was set successfully, False if invalid mode
|
||||
"""
|
||||
if mode.lower() in ["csv", "xlsx"]:
|
||||
self.SYNC_MODE = mode.lower()
|
||||
logger.info(f"Sync mode changed to: {self.SYNC_MODE}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Invalid sync mode: {mode}. Must be 'csv' or 'xlsx'")
|
||||
return False
|
||||
|
||||
async def initialize_connectors(self) -> bool:
|
||||
"""Initialize JIRA and SharePoint connectors."""
|
||||
try:
|
||||
logger.info("Initializing JIRA connector with hardcoded credentials")
|
||||
|
||||
# Initialize JIRA connector using class constants
|
||||
self.jira_connector = await ConnectorTicketJira.create(
|
||||
jira_username=self.JIRA_USERNAME,
|
||||
jira_api_token=self.JIRA_API_TOKEN,
|
||||
jira_url=self.JIRA_URL,
|
||||
project_code=self.JIRA_PROJECT_CODE,
|
||||
issue_type=self.JIRA_ISSUE_TYPE
|
||||
)
|
||||
|
||||
# Use the admin user for SharePoint connection
|
||||
adminUser = self.root_interface.getUserByUsername("admin")
|
||||
if not adminUser:
|
||||
logger.error("Admin user not found - SharePoint connection required")
|
||||
return False
|
||||
|
||||
logger.info(f"Using admin user for SharePoint: {adminUser.id}")
|
||||
|
||||
# Get SharePoint connection for admin user
|
||||
user_connections = self.root_interface.getUserConnections(adminUser.id)
|
||||
sharepoint_connection = None
|
||||
|
||||
for connection in user_connections:
|
||||
if connection.authority == "msft" and connection.externalUsername == self.SHAREPOINT_USER_ID:
|
||||
sharepoint_connection = connection
|
||||
break
|
||||
|
||||
if not sharepoint_connection:
|
||||
logger.error(f"No SharePoint connection found for user: {self.SHAREPOINT_USER_ID}")
|
||||
return False
|
||||
|
||||
logger.info(f"Found SharePoint connection: {sharepoint_connection.id}")
|
||||
|
||||
# Get SharePoint token for this connection
|
||||
sharepoint_token = self.root_interface.getConnectionToken(sharepoint_connection.id)
|
||||
if not sharepoint_token:
|
||||
logger.error("No SharePoint token found for Delta Group user connection")
|
||||
return False
|
||||
|
||||
logger.info(f"Found SharePoint token: {sharepoint_token.id}")
|
||||
|
||||
# Initialize SharePoint connector with Graph API
|
||||
self.sharepoint_connector = ConnectorSharepoint(access_token=sharepoint_token.tokenAccess)
|
||||
|
||||
# Resolve the site by hostname + site path to get the real site ID
            logger.info(
                f"Resolving site ID via hostname+path: {self.SHAREPOINT_HOSTNAME}:/sites/{self.SHAREPOINT_SITE_PATH}"
            )
            resolved = await self.sharepoint_connector.find_site_by_url(
                hostname=self.SHAREPOINT_HOSTNAME,
                site_path=self.SHAREPOINT_SITE_PATH
            )

            if not resolved:
                logger.error(
                    f"Failed to resolve site. Hostname: {self.SHAREPOINT_HOSTNAME}, Path: {self.SHAREPOINT_SITE_PATH}"
                )
                return False

            self.target_site = {
                "id": resolved.get("id"),
                "displayName": resolved.get("displayName", self.SHAREPOINT_SITE_NAME),
                "name": resolved.get("name", self.SHAREPOINT_SITE_NAME)
            }

            # Test site access by listing the root of the drive
            logger.info("Testing site access using resolved site ID...")
            test_result = await self.sharepoint_connector.list_folder_contents(
                site_id=self.target_site["id"],
                folder_path=""
            )

            if test_result is not None:
                logger.info(
                    f"Site access confirmed: {self.target_site['displayName']} (ID: {self.target_site['id']})"
                )
            else:
                logger.error("Could not access site drive - check permissions")
                return False

            return True

        except Exception as e:
            logger.error(f"Error initializing connectors: {str(e)}")
            return False

    async def sync_jira_to_sharepoint(self) -> bool:
        """Perform the main JIRA to SharePoint synchronization via the sync interface."""
        try:
            logger.info(f"Starting JIRA to SharePoint synchronization (Mode: {self.SYNC_MODE})")

            # Initialize connectors
            if not await self.initialize_connectors():
                logger.error("Failed to initialize connectors")
                return False

            # Dump current JIRA fields to a text file for reference
            try:
                pass  # await dump_jira_fields_to_file()
            except Exception as e:
                logger.warning(f"Failed to dump JIRA fields (non-blocking): {str(e)}")

            # Dump actual JIRA data for debugging
            try:
                pass  # await dump_jira_data_to_file()
            except Exception as e:
                logger.warning(f"Failed to dump JIRA data (non-blocking): {str(e)}")

            # Get the appropriate sync file name based on the mode
            sync_file_name = self.get_sync_file_name()
            logger.info(f"Using sync file: {sync_file_name}")

            # Create the sync interface
            sync_interface = await TicketSharepointSyncInterface.create(
                connector_ticket=self.jira_connector,
                connector_sharepoint=self.sharepoint_connector,
                task_sync_definition=self.TASK_SYNC_DEFINITION,
                sync_folder=self.SHAREPOINT_MAIN_FOLDER,
                sync_file=sync_file_name,
                backup_folder=self.SHAREPOINT_BACKUP_FOLDER,
                audit_folder=self.SHAREPOINT_AUDIT_FOLDER,
                site_id=self.target_site['id']
            )

            # Perform the bidirectional sync based on the configured mode
            if self.SYNC_MODE == "xlsx":
                logger.info("Performing JIRA to Excel sync...")
                await sync_interface.sync_from_jira_to_excel()
                logger.info("Performing Excel to JIRA sync...")
                await sync_interface.sync_from_excel_to_jira()
            else:  # CSV mode (default)
                logger.info("Performing JIRA to CSV sync...")
                await sync_interface.sync_from_jira_to_csv()
                logger.info("Performing CSV to JIRA sync...")
                await sync_interface.sync_from_csv_to_jira()

            logger.info(f"JIRA to SharePoint synchronization completed successfully (Mode: {self.SYNC_MODE})")
            return True

        except Exception as e:
            logger.error(f"Error during JIRA to SharePoint synchronization: {str(e)}")
            return False
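
# Illustrative manual invocation (assumed usage, not part of the scheduled
# flow): the sync can be exercised from a one-off script or async shell via
#
#   import asyncio
#   asyncio.run(ManagerSyncDelta().sync_jira_to_sharepoint())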


# Utility: dump all JIRA fields (name -> field id) to a text file
async def dump_jira_fields_to_file(filepath: str = "delta_sync_fields.txt") -> bool:
    """Write all available JIRA fields for the configured project/issue type to a text file.

    The output format matches the legacy fields.txt, e.g.:
        'Summary': ['get', ['fields', 'summary']]

    Args:
        filepath: Target text file path to write.

    Returns:
        True on success, False otherwise.
    """
    try:
        # Initialize the JIRA connector with the hardcoded credentials/constants
        jira = await ConnectorTicketJira.create(
            jira_username=ManagerSyncDelta.JIRA_USERNAME,
            jira_api_token=ManagerSyncDelta.JIRA_API_TOKEN,
            jira_url=ManagerSyncDelta.JIRA_URL,
            project_code=ManagerSyncDelta.JIRA_PROJECT_CODE,
            issue_type=ManagerSyncDelta.JIRA_ISSUE_TYPE,
        )

        attributes = await jira.read_attributes()
        if not attributes:
            logger.warning("No JIRA attributes returned; nothing to write.")
            return False

        # Ensure the directory exists if a directory part is provided
        dir_name = os.path.dirname(filepath)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        # Write in the expected mapping format
        with open(filepath, "w", encoding="utf-8") as f:
            for attr in attributes:
                # attr.field_name (human name), attr.field (JIRA field id)
                f.write(f"'{attr.field_name}': ['get', ['fields', '{attr.field}']]\n")

        logger.info(f"Wrote {len(attributes)} JIRA fields to {filepath}")
        return True
    except Exception as e:
        logger.error(f"Failed to dump JIRA fields: {str(e)}")
        return False


# Utility: dump actual JIRA data for debugging
async def dump_jira_data_to_file(filepath: str = "delta_sync_data.txt") -> bool:
    """Write actual JIRA ticket data to a text file for debugging field mapping.

    Args:
        filepath: Target text file path to write.

    Returns:
        True on success, False otherwise.
    """
    try:
        # Initialize the JIRA connector with the hardcoded credentials/constants
        jira = await ConnectorTicketJira.create(
            jira_username=ManagerSyncDelta.JIRA_USERNAME,
            jira_api_token=ManagerSyncDelta.JIRA_API_TOKEN,
            jira_url=ManagerSyncDelta.JIRA_URL,
            project_code=ManagerSyncDelta.JIRA_PROJECT_CODE,
            issue_type=ManagerSyncDelta.JIRA_ISSUE_TYPE,
        )

        # Get a few sample tickets to see the actual data structure
        tickets = await jira.read_tasks(limit=5)
        if not tickets:
            logger.warning("No JIRA tickets returned; nothing to write.")
            return False

        # Ensure the directory exists if a directory part is provided
        dir_name = os.path.dirname(filepath)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        # Write the actual ticket data
        with open(filepath, "w", encoding="utf-8") as f:
            f.write("=== JIRA TICKET DATA DEBUG ===\n\n")
            for i, ticket in enumerate(tickets):
                f.write(f"--- TICKET {i + 1} ---\n")
                f.write("Raw ticket data:\n")
                f.write(f"{ticket.data}\n\n")

                # Also show the specific fields we're trying to map
                f.write("Field mapping analysis:\n")
                for field_name, field_path in ManagerSyncDelta.TASK_SYNC_DEFINITION.items():
                    if field_path[0] == 'get':  # Only analyze 'get' fields
                        try:
                            # Navigate through the field path
                            value = ticket.data
                            for key in field_path[1]:
                                if isinstance(value, dict) and key in value:
                                    value = value[key]
                                else:
                                    value = f"KEY_NOT_FOUND: {key}"
                                    break

                            # Convert ADF (Atlassian Document Format) fields to plain text
                            if field_name in ['Description', 'References', 'DELTA Comments', 'SELISE Comments']:
                                if isinstance(value, dict) and value.get("type") == "doc":
                                    value = convert_adf_to_text(value)
                                elif value is None:
                                    value = ""

                            f.write(f"  {field_name}: {value}\n")
                        except Exception as e:
                            f.write(f"  {field_name}: ERROR - {str(e)}\n")
                f.write("\n" + "=" * 50 + "\n\n")

        logger.info(f"Wrote JIRA data for {len(tickets)} tickets to {filepath}")
        return True
    except Exception as e:
        logger.error(f"Failed to dump JIRA data: {str(e)}")
        return False
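

# convert_adf_to_text is imported from elsewhere in this package and is not
# shown in this diff. As a hedged illustration only, a minimal flattener for
# Atlassian Document Format ("type": "doc") payloads could look roughly like
# the sketch below; the real implementation may differ.
def _example_adf_to_text(node) -> str:
    """Recursively collect the text leaves of an ADF node tree (illustrative sketch)."""
    parts = []
    if isinstance(node, dict):
        if node.get("type") == "text":
            parts.append(node.get("text", ""))
        for child in node.get("content") or []:
            parts.append(_example_adf_to_text(child))
    return " ".join(p for p in parts if p)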


# Global sync function for use in app.py
async def perform_sync_jira_delta_group() -> bool:
    """Perform the JIRA to SharePoint synchronization for Delta Group.

    This function is called by the scheduler and can also be used independently.

    Returns:
        bool: True if the synchronization was successful, False otherwise
    """
    try:
        if APP_ENV_TYPE not in ("prod", "tst"):
            logger.info("JIRA to SharePoint synchronization: task runs only in PROD/TST environments")
            return True

        logger.info("Starting Delta Group JIRA sync...")

        sync_manager = ManagerSyncDelta()
        success = await sync_manager.sync_jira_to_sharepoint()

        if success:
            logger.info("Delta Group JIRA sync completed successfully")
        else:
            logger.error("Delta Group JIRA sync failed")

        return success

    except Exception as e:
        logger.error(f"Error in perform_sync_jira_delta_group: {str(e)}")
        return False
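

# Hedged scheduling sketch: this module only exposes the coroutine above; how
# it is scheduled is assumed here. With APScheduler's AsyncIOScheduler, an
# hourly registration could look like:
#
#   from apscheduler.schedulers.asyncio import AsyncIOScheduler
#   from apscheduler.triggers.cron import CronTrigger
#
#   scheduler = AsyncIOScheduler()
#   scheduler.add_job(perform_sync_jira_delta_group, CronTrigger(minute=0))
#   scheduler.start()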