feat: add langgraph first tool; pydantic v2
This commit is contained in:
parent
68d6ab9890
commit
98b258ae53
7 changed files with 718 additions and 432 deletions
|
|
@ -1,7 +1,7 @@
|
|||
"""Security models: Token and AuthEvent."""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from modules.shared.attributeUtils import register_model_labels, ModelMixin
|
||||
from modules.shared.timezoneUtils import get_utc_timestamp
|
||||
from .datamodelUam import AuthAuthority
|
||||
|
|
@ -18,21 +18,36 @@ class Token(BaseModel, ModelMixin):
|
|||
id: Optional[str] = None
|
||||
userId: str
|
||||
authority: AuthAuthority
|
||||
connectionId: Optional[str] = Field(None, description="ID of the connection this token belongs to")
|
||||
connectionId: Optional[str] = Field(
|
||||
None, description="ID of the connection this token belongs to"
|
||||
)
|
||||
tokenAccess: str
|
||||
tokenType: str = "bearer"
|
||||
expiresAt: float = Field(description="When the token expires (UTC timestamp in seconds)")
|
||||
expiresAt: float = Field(
|
||||
description="When the token expires (UTC timestamp in seconds)"
|
||||
)
|
||||
tokenRefresh: Optional[str] = None
|
||||
createdAt: Optional[float] = Field(None, description="When the token was created (UTC timestamp in seconds)")
|
||||
status: TokenStatus = Field(default=TokenStatus.ACTIVE, description="Token status: active/revoked")
|
||||
revokedAt: Optional[float] = Field(None, description="When the token was revoked (UTC timestamp in seconds)")
|
||||
revokedBy: Optional[str] = Field(None, description="User ID who revoked the token (admin/self)")
|
||||
createdAt: Optional[float] = Field(
|
||||
None, description="When the token was created (UTC timestamp in seconds)"
|
||||
)
|
||||
status: TokenStatus = Field(
|
||||
default=TokenStatus.ACTIVE, description="Token status: active/revoked"
|
||||
)
|
||||
revokedAt: Optional[float] = Field(
|
||||
None, description="When the token was revoked (UTC timestamp in seconds)"
|
||||
)
|
||||
revokedBy: Optional[str] = Field(
|
||||
None, description="User ID who revoked the token (admin/self)"
|
||||
)
|
||||
reason: Optional[str] = Field(None, description="Optional revocation reason")
|
||||
sessionId: Optional[str] = Field(None, description="Logical session grouping for logout revocation")
|
||||
mandateId: Optional[str] = Field(None, description="Mandate ID for tenant scoping of the token")
|
||||
sessionId: Optional[str] = Field(
|
||||
None, description="Logical session grouping for logout revocation"
|
||||
)
|
||||
mandateId: Optional[str] = Field(
|
||||
None, description="Mandate ID for tenant scoping of the token"
|
||||
)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -59,14 +74,60 @@ register_model_labels(
|
|||
|
||||
|
||||
class AuthEvent(BaseModel, ModelMixin):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the auth event", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
userId: str = Field(description="ID of the user this event belongs to", frontend_type="text", frontend_readonly=True, frontend_required=True)
|
||||
eventType: str = Field(description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')", frontend_type="text", frontend_readonly=True, frontend_required=True)
|
||||
timestamp: float = Field(default_factory=get_utc_timestamp, description="Unix timestamp when the event occurred", frontend_type="datetime", frontend_readonly=True, frontend_required=True)
|
||||
ipAddress: Optional[str] = Field(default=None, description="IP address from which the event originated", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
userAgent: Optional[str] = Field(default=None, description="User agent string from the request", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
success: bool = Field(default=True, description="Whether the authentication event was successful", frontend_type="boolean", frontend_readonly=True, frontend_required=True)
|
||||
details: Optional[str] = Field(default=None, description="Additional details about the event", frontend_type="text", frontend_readonly=True, frontend_required=False)
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Unique ID of the auth event",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
userId: str = Field(
|
||||
description="ID of the user this event belongs to",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True,
|
||||
)
|
||||
eventType: str = Field(
|
||||
description="Type of authentication event (e.g., 'login', 'logout', 'token_refresh')",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True,
|
||||
)
|
||||
timestamp: float = Field(
|
||||
default_factory=get_utc_timestamp,
|
||||
description="Unix timestamp when the event occurred",
|
||||
frontend_type="datetime",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True,
|
||||
)
|
||||
ipAddress: Optional[str] = Field(
|
||||
default=None,
|
||||
description="IP address from which the event originated",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
userAgent: Optional[str] = Field(
|
||||
default=None,
|
||||
description="User agent string from the request",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
success: bool = Field(
|
||||
default=True,
|
||||
description="Whether the authentication event was successful",
|
||||
frontend_type="boolean",
|
||||
frontend_readonly=True,
|
||||
frontend_required=True,
|
||||
)
|
||||
details: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Additional details about the event",
|
||||
frontend_type="text",
|
||||
frontend_readonly=True,
|
||||
frontend_required=False,
|
||||
)
|
||||
|
||||
|
||||
register_model_labels(
|
||||
|
|
@ -83,5 +144,3 @@ register_model_labels(
|
|||
"details": {"en": "Details", "fr": "Détails"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
1
modules/features/chatBot/chatbotTools/__init__.py
Normal file
1
modules/features/chatBot/chatbotTools/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Contains all tools available for the chatbot to use."""
|
||||
|
|
@ -1 +1,7 @@
|
|||
"""Tools that are custom to a specific customer go here."""
|
||||
"""Shared tools available across all chatbot implementations."""
|
||||
|
||||
from modules.features.chatBot.chatbotTools.sharedTools.toolTavilySearch import (
|
||||
tavily_search,
|
||||
)
|
||||
|
||||
__all__ = ["tavily_search"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
"""Tavily Search Tool for LangGraph.
|
||||
|
||||
This tool provides web search capabilities using the Tavily API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from langchain_core.tools import tool
|
||||
from modules.connectors.connectorAiTavily import ConnectorWeb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool
|
||||
async def tavily_search(
|
||||
query: Annotated[str, "The search query to look up on the web"],
|
||||
) -> str:
|
||||
"""Search the web using Tavily API.
|
||||
|
||||
Use this tool to search for current information, news, or any web content.
|
||||
The tool returns relevant search results including titles and URLs.
|
||||
|
||||
Args:
|
||||
query: The search query string
|
||||
|
||||
Returns:
|
||||
A formatted string containing search results with titles and URLs
|
||||
"""
|
||||
try:
|
||||
# Create connector instance
|
||||
connector = await ConnectorWeb.create()
|
||||
|
||||
# Perform search with default parameters
|
||||
results = await connector._search(
|
||||
query=query,
|
||||
max_results=5,
|
||||
search_depth="basic",
|
||||
include_answer=True,
|
||||
include_raw_content=False,
|
||||
)
|
||||
|
||||
# Format results
|
||||
if not results:
|
||||
return f"No results found for query: {query}"
|
||||
|
||||
formatted_results = [f"Search results for '{query}':\n"]
|
||||
for i, result in enumerate(results, 1):
|
||||
formatted_results.append(f"{i}. {result.title}")
|
||||
formatted_results.append(f" URL: {result.url}\n")
|
||||
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in tavily_search tool: {str(e)}")
|
||||
return f"Error performing search: {str(e)}"
|
||||
|
|
@ -18,11 +18,19 @@ from modules.shared.configuration import APP_CONFIG
|
|||
from modules.shared.timezoneUtils import get_utc_now, get_utc_timestamp
|
||||
from modules.interfaces.interfaceDbAppAccess import AppAccess
|
||||
from modules.datamodels.datamodelUam import (
|
||||
User, Mandate, UserInDB, UserConnection,
|
||||
AuthAuthority, UserPrivilege, ConnectionStatus,
|
||||
User,
|
||||
Mandate,
|
||||
UserInDB,
|
||||
UserConnection,
|
||||
AuthAuthority,
|
||||
UserPrivilege,
|
||||
ConnectionStatus,
|
||||
)
|
||||
from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus
|
||||
from modules.datamodels.datamodelNeutralizer import DataNeutraliserConfig, DataNeutralizerAttributes
|
||||
from modules.datamodels.datamodelNeutralizer import (
|
||||
DataNeutraliserConfig,
|
||||
DataNeutralizerAttributes,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,6 +43,7 @@ _rootAppObjects = None
|
|||
# Password-Hashing
|
||||
pwdContext = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
class AppObjects:
|
||||
"""
|
||||
Interface to the Gateway system.
|
||||
|
|
@ -76,14 +85,16 @@ class AppObjects:
|
|||
self.userLanguage = currentUser.language # Default user language
|
||||
|
||||
# Initialize access control with user context
|
||||
self.access = AppAccess(self.currentUser, self.db) # Convert to dict only when needed
|
||||
self.access = AppAccess(
|
||||
self.currentUser, self.db
|
||||
) # Convert to dict only when needed
|
||||
|
||||
# Update database context
|
||||
self.db.updateContext(self.userId)
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup method to close database connection."""
|
||||
if hasattr(self, 'db') and self.db is not None:
|
||||
if hasattr(self, "db") and self.db is not None:
|
||||
try:
|
||||
self.db.close()
|
||||
except Exception as e:
|
||||
|
|
@ -106,7 +117,7 @@ class AppObjects:
|
|||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
dbPort=dbPort,
|
||||
userId=self.userId
|
||||
userId=self.userId,
|
||||
)
|
||||
|
||||
# Initialize database system
|
||||
|
|
@ -129,16 +140,12 @@ class AppObjects:
|
|||
mandates = self.db.getRecordset(Mandate)
|
||||
if existingMandateId is None or not mandates:
|
||||
logger.info("Creating Root mandate")
|
||||
rootMandate = Mandate(
|
||||
name="Root",
|
||||
language="en",
|
||||
enabled=True
|
||||
)
|
||||
rootMandate = Mandate(name="Root", language="en", enabled=True)
|
||||
createdMandate = self.db.recordCreate(Mandate, rootMandate)
|
||||
logger.info(f"Root mandate created with ID {createdMandate['id']}")
|
||||
|
||||
# Update mandate context
|
||||
self.mandateId = createdMandate['id']
|
||||
self.mandateId = createdMandate["id"]
|
||||
|
||||
def _initAdminUser(self):
|
||||
"""Creates the Admin user if it doesn't exist."""
|
||||
|
|
@ -155,8 +162,10 @@ class AppObjects:
|
|||
language="en",
|
||||
privilege=UserPrivilege.SYSADMIN,
|
||||
authenticationAuthority="local", # Using lowercase value directly
|
||||
hashedPassword=self._getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")),
|
||||
connections=[]
|
||||
hashedPassword=self._getPasswordHash(
|
||||
APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")
|
||||
),
|
||||
connections=[],
|
||||
)
|
||||
createdUser = self.db.recordCreate(UserInDB, adminUser)
|
||||
logger.info(f"Admin user created with ID {createdUser['id']}")
|
||||
|
|
@ -168,7 +177,9 @@ class AppObjects:
|
|||
def _initEventUser(self):
|
||||
"""Creates the Event user if it doesn't exist."""
|
||||
# Check if event user already exists
|
||||
existingUsers = self.db.getRecordset(UserInDB, recordFilter={"username": "event"})
|
||||
existingUsers = self.db.getRecordset(
|
||||
UserInDB, recordFilter={"username": "event"}
|
||||
)
|
||||
if not existingUsers:
|
||||
logger.info("Creating Event user")
|
||||
eventUser = UserInDB(
|
||||
|
|
@ -180,13 +191,17 @@ class AppObjects:
|
|||
language="en",
|
||||
privilege=UserPrivilege.SYSADMIN,
|
||||
authenticationAuthority="local", # Using lowercase value directly
|
||||
hashedPassword=self._getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")),
|
||||
connections=[]
|
||||
hashedPassword=self._getPasswordHash(
|
||||
APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")
|
||||
),
|
||||
connections=[],
|
||||
)
|
||||
createdUser = self.db.recordCreate(UserInDB, eventUser)
|
||||
logger.info(f"Event user created with ID {createdUser['id']}")
|
||||
|
||||
def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
def _uam(
|
||||
self, model_class: type, recordset: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Unified user access management function that filters data based on user privileges
|
||||
and adds access control attributes.
|
||||
|
|
@ -205,7 +220,7 @@ class AppObjects:
|
|||
cleanedRecords = []
|
||||
for record in filteredRecords:
|
||||
# Create a new dict with only non-database fields
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith('_')}
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
cleanedRecords.append(cleanedRecord)
|
||||
|
||||
return cleanedRecords
|
||||
|
|
@ -317,12 +332,20 @@ class AppObjects:
|
|||
|
||||
return user
|
||||
|
||||
def createUser(self, username: str, password: str = None, email: str = None,
|
||||
fullName: str = None, language: str = "en", enabled: bool = True,
|
||||
privilege: UserPrivilege = UserPrivilege.USER,
|
||||
authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL,
|
||||
externalId: str = None, externalUsername: str = None,
|
||||
externalEmail: str = None) -> User:
|
||||
def createUser(
|
||||
self,
|
||||
username: str,
|
||||
password: str = None,
|
||||
email: str = None,
|
||||
fullName: str = None,
|
||||
language: str = "en",
|
||||
enabled: bool = True,
|
||||
privilege: UserPrivilege = UserPrivilege.USER,
|
||||
authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL,
|
||||
externalId: str = None,
|
||||
externalUsername: str = None,
|
||||
externalEmail: str = None,
|
||||
) -> User:
|
||||
"""Create a new user with optional external connection"""
|
||||
try:
|
||||
# Ensure username is a string
|
||||
|
|
@ -348,7 +371,7 @@ class AppObjects:
|
|||
privilege=privilege,
|
||||
authenticationAuthority=authenticationAuthority,
|
||||
hashedPassword=self._getPasswordHash(password) if password else None,
|
||||
connections=[]
|
||||
connections=[],
|
||||
)
|
||||
|
||||
# Create user record
|
||||
|
|
@ -356,7 +379,6 @@ class AppObjects:
|
|||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create user record")
|
||||
|
||||
|
||||
# Add external connection if provided
|
||||
if externalId and externalUsername:
|
||||
self.addUserConnection(
|
||||
|
|
@ -364,11 +386,13 @@ class AppObjects:
|
|||
authenticationAuthority,
|
||||
externalId,
|
||||
externalUsername,
|
||||
externalEmail
|
||||
externalEmail,
|
||||
)
|
||||
|
||||
# Get created user using the returned ID
|
||||
createdUser = self.db.getRecordset(UserInDB, recordFilter={"id": createdRecord["id"]})
|
||||
createdUser = self.db.getRecordset(
|
||||
UserInDB, recordFilter={"id": createdRecord["id"]}
|
||||
)
|
||||
if not createdUser or len(createdUser) == 0:
|
||||
raise ValueError("Failed to retrieve created user")
|
||||
|
||||
|
|
@ -399,7 +423,6 @@ class AppObjects:
|
|||
# Update user record
|
||||
self.db.recordModify(UserInDB, userId, updatedUser)
|
||||
|
||||
|
||||
# Get updated user
|
||||
updatedUser = self.getUser(userId)
|
||||
if not updatedUser:
|
||||
|
|
@ -422,8 +445,6 @@ class AppObjects:
|
|||
def _deleteUserReferencedData(self, userId: str) -> None:
|
||||
"""Deletes all data associated with a user."""
|
||||
try:
|
||||
|
||||
|
||||
# Delete user auth events
|
||||
events = self.db.getRecordset(AuthEvent, recordFilter={"userId": userId})
|
||||
for event in events:
|
||||
|
|
@ -434,9 +455,10 @@ class AppObjects:
|
|||
for token in tokens:
|
||||
self.db.recordDelete(Token, token["id"])
|
||||
|
||||
|
||||
# Delete user connections
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
|
||||
connections = self.db.getRecordset(
|
||||
UserConnection, recordFilter={"userId": userId}
|
||||
)
|
||||
for conn in connections:
|
||||
self.db.recordDelete(UserConnection, conn["id"])
|
||||
|
||||
|
|
@ -465,7 +487,6 @@ class AppObjects:
|
|||
if not success:
|
||||
raise ValueError(f"Failed to delete user {userId}")
|
||||
|
||||
|
||||
logger.info(f"User {userId} successfully deleted")
|
||||
return True
|
||||
|
||||
|
|
@ -493,31 +514,22 @@ class AppObjects:
|
|||
authenticationAuthority = checkData.get("authenticationAuthority", "local")
|
||||
|
||||
if not username:
|
||||
return {
|
||||
"available": False,
|
||||
"message": "Username is required"
|
||||
}
|
||||
return {"available": False, "message": "Username is required"}
|
||||
|
||||
# Get user by username
|
||||
user = self.getUserByUsername(username)
|
||||
|
||||
# Check if user exists (User model instance)
|
||||
if user is not None:
|
||||
return {
|
||||
"available": False,
|
||||
"message": "Username is already taken"
|
||||
}
|
||||
return {"available": False, "message": "Username is already taken"}
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"message": "Username is available"
|
||||
}
|
||||
return {"available": True, "message": "Username is available"}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking username availability: {str(e)}")
|
||||
return {
|
||||
"available": False,
|
||||
"message": f"Error checking username availability: {str(e)}"
|
||||
"message": f"Error checking username availability: {str(e)}",
|
||||
}
|
||||
|
||||
# Connection methods
|
||||
|
|
@ -526,7 +538,9 @@ class AppObjects:
|
|||
"""Returns all connections for a user."""
|
||||
try:
|
||||
# Get connections for this user
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
|
||||
connections = self.db.getRecordset(
|
||||
UserConnection, recordFilter={"userId": userId}
|
||||
)
|
||||
|
||||
# Convert to UserConnection objects
|
||||
result = []
|
||||
|
|
@ -543,11 +557,13 @@ class AppObjects:
|
|||
status=conn_dict.get("status", "pending"),
|
||||
connectedAt=conn_dict.get("connectedAt"),
|
||||
lastChecked=conn_dict.get("lastChecked"),
|
||||
expiresAt=conn_dict.get("expiresAt")
|
||||
expiresAt=conn_dict.get("expiresAt"),
|
||||
)
|
||||
result.append(connection)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting connection dict to object: {str(e)}")
|
||||
logger.error(
|
||||
f"Error converting connection dict to object: {str(e)}"
|
||||
)
|
||||
continue
|
||||
return result
|
||||
|
||||
|
|
@ -555,9 +571,15 @@ class AppObjects:
|
|||
logger.error(f"Error getting user connections: {str(e)}")
|
||||
return []
|
||||
|
||||
def addUserConnection(self, userId: str, authority: AuthAuthority, externalId: str,
|
||||
externalUsername: str, externalEmail: Optional[str] = None,
|
||||
status: ConnectionStatus = ConnectionStatus.PENDING) -> UserConnection:
|
||||
def addUserConnection(
|
||||
self,
|
||||
userId: str,
|
||||
authority: AuthAuthority,
|
||||
externalId: str,
|
||||
externalUsername: str,
|
||||
externalEmail: Optional[str] = None,
|
||||
status: ConnectionStatus = ConnectionStatus.PENDING,
|
||||
) -> UserConnection:
|
||||
"""
|
||||
Adds a new connection for a user.
|
||||
|
||||
|
|
@ -589,13 +611,12 @@ class AppObjects:
|
|||
status=status,
|
||||
connectedAt=get_utc_timestamp(),
|
||||
lastChecked=get_utc_timestamp(),
|
||||
expiresAt=None # Optional field, set to None by default
|
||||
expiresAt=None, # Optional field, set to None by default
|
||||
)
|
||||
|
||||
# Save to connections table
|
||||
self.db.recordCreate(UserConnection, connection)
|
||||
|
||||
|
||||
return connection
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -606,9 +627,9 @@ class AppObjects:
|
|||
"""Remove a connection to an external service"""
|
||||
try:
|
||||
# Get connection
|
||||
connections = self.db.getRecordset(UserConnection, recordFilter={
|
||||
"id": connectionId
|
||||
})
|
||||
connections = self.db.getRecordset(
|
||||
UserConnection, recordFilter={"id": connectionId}
|
||||
)
|
||||
|
||||
if not connections:
|
||||
raise ValueError(f"Connection {connectionId} not found")
|
||||
|
|
@ -616,7 +637,6 @@ class AppObjects:
|
|||
# Delete connection
|
||||
self.db.recordDelete(UserConnection, connectionId)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing user connection: {str(e)}")
|
||||
raise ValueError(f"Failed to remove user connection: {str(e)}")
|
||||
|
|
@ -647,17 +667,13 @@ class AppObjects:
|
|||
raise PermissionError("No permission to create mandates")
|
||||
|
||||
# Create mandate data using model
|
||||
mandateData = Mandate(
|
||||
name=name,
|
||||
language=language
|
||||
)
|
||||
mandateData = Mandate(name=name, language=language)
|
||||
|
||||
# Create mandate record
|
||||
createdRecord = self.db.recordCreate(Mandate, mandateData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create mandate record")
|
||||
|
||||
|
||||
return Mandate.from_dict(createdRecord)
|
||||
|
||||
def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate:
|
||||
|
|
@ -707,7 +723,9 @@ class AppObjects:
|
|||
# Check if mandate has users
|
||||
users = self.getUsersByMandate(mandateId)
|
||||
if users:
|
||||
raise ValueError(f"Cannot delete mandate {mandateId} with existing users")
|
||||
raise ValueError(
|
||||
f"Cannot delete mandate {mandateId} with existing users"
|
||||
)
|
||||
|
||||
# Delete mandate
|
||||
success = self.db.recordDelete(Mandate, mandateId)
|
||||
|
|
@ -727,7 +745,9 @@ class AppObjects:
|
|||
try:
|
||||
# Validate that this is NOT a connection token
|
||||
if token.connectionId:
|
||||
raise ValueError("Access tokens cannot have connectionId - use saveConnectionToken instead")
|
||||
raise ValueError(
|
||||
"Access tokens cannot have connectionId - use saveConnectionToken instead"
|
||||
)
|
||||
|
||||
# Validate user context
|
||||
if not self.currentUser or not self.currentUser.id:
|
||||
|
|
@ -745,33 +765,44 @@ class AppObjects:
|
|||
# If replace_existing is True, delete old access tokens for this user and authority first
|
||||
if replace_existing:
|
||||
try:
|
||||
old_tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"userId": self.currentUser.id,
|
||||
"authority": token.authority,
|
||||
"connectionId": None # Ensure we only delete access tokens
|
||||
})
|
||||
old_tokens = self.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={
|
||||
"userId": self.currentUser.id,
|
||||
"authority": token.authority,
|
||||
"connectionId": None, # Ensure we only delete access tokens
|
||||
},
|
||||
)
|
||||
deleted_count = 0
|
||||
for old_token in old_tokens:
|
||||
if old_token["id"] != token.id: # Don't delete the new token if it already exists
|
||||
if (
|
||||
old_token["id"] != token.id
|
||||
): # Don't delete the new token if it already exists
|
||||
self.db.recordDelete(Token, old_token["id"])
|
||||
deleted_count += 1
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}")
|
||||
logger.info(
|
||||
f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}")
|
||||
logger.warning(
|
||||
f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}"
|
||||
)
|
||||
# Continue with saving the new token even if deletion fails
|
||||
|
||||
# Convert to dict and ensure all fields are properly set
|
||||
token_dict = token.dict()
|
||||
token_dict = token.model_dump()
|
||||
# Ensure userId is set to current user
|
||||
# Convert to dict and ensure all fields are properly set
|
||||
token_dict = token.model_dump()
|
||||
# Ensure userId is set to current user
|
||||
token_dict["userId"] = self.currentUser.id
|
||||
|
||||
# Save to database
|
||||
self.db.recordCreate(Token, token_dict)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving access token: {str(e)}")
|
||||
raise
|
||||
|
|
@ -781,7 +812,9 @@ class AppObjects:
|
|||
try:
|
||||
# Validate that this IS a connection token
|
||||
if not token.connectionId:
|
||||
raise ValueError("Connection tokens must have connectionId - use saveAccessToken instead")
|
||||
raise ValueError(
|
||||
"Connection tokens must have connectionId - use saveAccessToken instead"
|
||||
)
|
||||
|
||||
# Validate user context
|
||||
if not self.currentUser or not self.currentUser.id:
|
||||
|
|
@ -799,31 +832,36 @@ class AppObjects:
|
|||
# If replace_existing is True, delete old tokens for this connectionId first
|
||||
if replace_existing:
|
||||
try:
|
||||
old_tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": token.connectionId
|
||||
})
|
||||
old_tokens = self.db.getRecordset(
|
||||
Token, recordFilter={"connectionId": token.connectionId}
|
||||
)
|
||||
deleted_count = 0
|
||||
for old_token in old_tokens:
|
||||
if old_token["id"] != token.id: # Don't delete the new token if it already exists
|
||||
if (
|
||||
old_token["id"] != token.id
|
||||
): # Don't delete the new token if it already exists
|
||||
self.db.recordDelete(Token, old_token["id"])
|
||||
deleted_count += 1
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.info(f"Replaced {deleted_count} old tokens for connectionId {token.connectionId}")
|
||||
logger.info(
|
||||
f"Replaced {deleted_count} old tokens for connectionId {token.connectionId}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}")
|
||||
logger.warning(
|
||||
f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}"
|
||||
)
|
||||
# Continue with saving the new token even if deletion fails
|
||||
|
||||
# Convert to dict and ensure all fields are properly set
|
||||
token_dict = token.dict()
|
||||
token_dict = token.model_dump()
|
||||
# Ensure userId is set to current user
|
||||
token_dict["userId"] = self.currentUser.id
|
||||
|
||||
# Save to database
|
||||
self.db.recordCreate(Token, token_dict)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving connection token: {str(e)}")
|
||||
raise
|
||||
|
|
@ -837,13 +875,14 @@ class AppObjects:
|
|||
|
||||
# Get token for this specific connection
|
||||
# Query for specific connection
|
||||
tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"connectionId": connectionId
|
||||
})
|
||||
|
||||
tokens = self.db.getRecordset(
|
||||
Token, recordFilter={"connectionId": connectionId}
|
||||
)
|
||||
|
||||
if not tokens:
|
||||
logger.warning(f"No connection token found for connectionId: {connectionId}")
|
||||
logger.warning(
|
||||
f"No connection token found for connectionId: {connectionId}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Sort by expiration date and get the latest (most recent expiration)
|
||||
|
|
@ -855,16 +894,27 @@ class AppObjects:
|
|||
return latest_token
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting connection token for connectionId {connectionId}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error getting connection token for connectionId {connectionId}: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
def findActiveTokenById(self, tokenId: str, userId: str, authority: AuthAuthority, sessionId: str = None, mandateId: str = None) -> Optional[Token]:
|
||||
def findActiveTokenById(
|
||||
self,
|
||||
tokenId: str,
|
||||
userId: str,
|
||||
authority: AuthAuthority,
|
||||
sessionId: str = None,
|
||||
mandateId: str = None,
|
||||
) -> Optional[Token]:
|
||||
"""Find an active access token by its id (jti) with optional session/tenant scoping."""
|
||||
try:
|
||||
recordFilter = {
|
||||
"id": tokenId,
|
||||
"userId": userId,
|
||||
"authority": authority.value if hasattr(authority, 'value') else str(authority),
|
||||
"authority": authority.value
|
||||
if hasattr(authority, "value")
|
||||
else str(authority),
|
||||
"status": TokenStatus.ACTIVE,
|
||||
}
|
||||
if sessionId is not None:
|
||||
|
|
@ -892,7 +942,7 @@ class AppObjects:
|
|||
"status": TokenStatus.REVOKED,
|
||||
"revokedAt": get_utc_timestamp(),
|
||||
"revokedBy": revokedBy,
|
||||
"reason": reason or "revoked"
|
||||
"reason": reason or "revoked",
|
||||
}
|
||||
self.db.recordModify(Token, tokenId, tokenUpdate)
|
||||
return True
|
||||
|
|
@ -900,30 +950,53 @@ class AppObjects:
|
|||
logger.error(f"Error revoking token {tokenId}: {str(e)}")
|
||||
return False
|
||||
|
||||
def revokeTokensBySessionId(self, sessionId: str, userId: str, authority: AuthAuthority, revokedBy: str, reason: str = None) -> int:
|
||||
def revokeTokensBySessionId(
|
||||
self,
|
||||
sessionId: str,
|
||||
userId: str,
|
||||
authority: AuthAuthority,
|
||||
revokedBy: str,
|
||||
reason: str = None,
|
||||
) -> int:
|
||||
"""Revoke all tokens of a session for a user/authority."""
|
||||
try:
|
||||
tokens = self.db.getRecordset(Token, recordFilter={
|
||||
"userId": userId,
|
||||
"authority": authority.value if hasattr(authority, 'value') else str(authority),
|
||||
"sessionId": sessionId,
|
||||
"status": TokenStatus.ACTIVE
|
||||
})
|
||||
tokens = self.db.getRecordset(
|
||||
Token,
|
||||
recordFilter={
|
||||
"userId": userId,
|
||||
"authority": authority.value
|
||||
if hasattr(authority, "value")
|
||||
else str(authority),
|
||||
"sessionId": sessionId,
|
||||
"status": TokenStatus.ACTIVE,
|
||||
},
|
||||
)
|
||||
count = 0
|
||||
for t in tokens:
|
||||
self.db.recordModify(Token, t["id"], {
|
||||
"status": TokenStatus.REVOKED,
|
||||
"revokedAt": get_utc_timestamp(),
|
||||
"revokedBy": revokedBy,
|
||||
"reason": reason or "session logout"
|
||||
})
|
||||
self.db.recordModify(
|
||||
Token,
|
||||
t["id"],
|
||||
{
|
||||
"status": TokenStatus.REVOKED,
|
||||
"revokedAt": get_utc_timestamp(),
|
||||
"revokedBy": revokedBy,
|
||||
"reason": reason or "session logout",
|
||||
},
|
||||
)
|
||||
count += 1
|
||||
return count
|
||||
except Exception as e:
|
||||
logger.error(f"Error revoking tokens for session {sessionId}: {str(e)}")
|
||||
return 0
|
||||
|
||||
def revokeTokensByUser(self, userId: str, authority: AuthAuthority = None, mandateId: str = None, revokedBy: str = None, reason: str = None) -> int:
|
||||
def revokeTokensByUser(
|
||||
self,
|
||||
userId: str,
|
||||
authority: AuthAuthority = None,
|
||||
mandateId: str = None,
|
||||
revokedBy: str = None,
|
||||
reason: str = None,
|
||||
) -> int:
|
||||
"""Revoke all active tokens for a user, optionally filtered by authority/mandate."""
|
||||
try:
|
||||
# Fetch all active tokens for user (optionally filtered by authority)
|
||||
|
|
@ -932,16 +1005,22 @@ class AppObjects:
|
|||
"status": TokenStatus.ACTIVE,
|
||||
}
|
||||
if authority is not None:
|
||||
recordFilter["authority"] = authority.value if hasattr(authority, 'value') else str(authority)
|
||||
recordFilter["authority"] = (
|
||||
authority.value if hasattr(authority, "value") else str(authority)
|
||||
)
|
||||
tokens = self.db.getRecordset(Token, recordFilter=recordFilter)
|
||||
count = 0
|
||||
for t in tokens:
|
||||
self.db.recordModify(Token, t["id"], {
|
||||
"status": TokenStatus.REVOKED,
|
||||
"revokedAt": get_utc_timestamp(),
|
||||
"revokedBy": revokedBy,
|
||||
"reason": reason or "admin revoke"
|
||||
})
|
||||
self.db.recordModify(
|
||||
Token,
|
||||
t["id"],
|
||||
{
|
||||
"status": TokenStatus.REVOKED,
|
||||
"revokedAt": get_utc_timestamp(),
|
||||
"revokedBy": revokedBy,
|
||||
"reason": reason or "admin revoke",
|
||||
},
|
||||
)
|
||||
count += 1
|
||||
return count
|
||||
except Exception as e:
|
||||
|
|
@ -958,7 +1037,10 @@ class AppObjects:
|
|||
all_tokens = self.db.getRecordset(Token, recordFilter={})
|
||||
|
||||
for token_data in all_tokens:
|
||||
if token_data.get("expiresAt") and token_data.get("expiresAt") < current_time:
|
||||
if (
|
||||
token_data.get("expiresAt")
|
||||
and token_data.get("expiresAt") < current_time
|
||||
):
|
||||
# Token is expired, delete it
|
||||
self.db.recordDelete(Token, token_data["id"])
|
||||
cleaned_count += 1
|
||||
|
|
@ -983,7 +1065,7 @@ class AppObjects:
|
|||
self.access = None
|
||||
|
||||
# Clear database context
|
||||
if hasattr(self, 'db'):
|
||||
if hasattr(self, "db"):
|
||||
self.db.updateContext("")
|
||||
|
||||
logger.info("User logged out successfully")
|
||||
|
|
@ -997,7 +1079,9 @@ class AppObjects:
|
|||
def getNeutralizationConfig(self) -> Optional[DataNeutraliserConfig]:
|
||||
"""Get the data neutralization configuration for the current user's mandate"""
|
||||
try:
|
||||
configs = self.db.getRecordset(DataNeutraliserConfig, recordFilter={"mandateId": self.mandateId})
|
||||
configs = self.db.getRecordset(
|
||||
DataNeutraliserConfig, recordFilter={"mandateId": self.mandateId}
|
||||
)
|
||||
if not configs:
|
||||
return None
|
||||
|
||||
|
|
@ -1012,7 +1096,9 @@ class AppObjects:
|
|||
logger.error(f"Error getting neutralization config: {str(e)}")
|
||||
return None
|
||||
|
||||
def createOrUpdateNeutralizationConfig(self, config_data: Dict[str, Any]) -> DataNeutraliserConfig:
|
||||
def createOrUpdateNeutralizationConfig(
|
||||
self, config_data: Dict[str, Any]
|
||||
) -> DataNeutraliserConfig:
|
||||
"""Create or update the data neutralization configuration"""
|
||||
try:
|
||||
# Check if config already exists
|
||||
|
|
@ -1025,7 +1111,9 @@ class AppObjects:
|
|||
update_data["updatedAt"] = get_utc_timestamp()
|
||||
|
||||
updated_config = DataNeutraliserConfig.from_dict(update_data)
|
||||
self.db.recordModify(DataNeutraliserConfig, existing_config.id, updated_config)
|
||||
self.db.recordModify(
|
||||
DataNeutraliserConfig, existing_config.id, updated_config
|
||||
)
|
||||
|
||||
return updated_config
|
||||
else:
|
||||
|
|
@ -1042,17 +1130,24 @@ class AppObjects:
|
|||
logger.error(f"Error creating/updating neutralization config: {str(e)}")
|
||||
raise ValueError(f"Failed to create/update neutralization config: {str(e)}")
|
||||
|
||||
def getNeutralizationAttributes(self, file_id: Optional[str] = None) -> List[DataNeutralizerAttributes]:
|
||||
def getNeutralizationAttributes(
|
||||
self, file_id: Optional[str] = None
|
||||
) -> List[DataNeutralizerAttributes]:
|
||||
"""Get neutralization attributes, optionally filtered by file ID"""
|
||||
try:
|
||||
filter_dict = {"mandateId": self.mandateId}
|
||||
if file_id:
|
||||
filter_dict["fileId"] = file_id
|
||||
|
||||
attributes = self.db.getRecordset(DataNeutralizerAttributes, recordFilter=filter_dict)
|
||||
attributes = self.db.getRecordset(
|
||||
DataNeutralizerAttributes, recordFilter=filter_dict
|
||||
)
|
||||
filtered_attributes = self._uam(DataNeutralizerAttributes, attributes)
|
||||
|
||||
return [DataNeutralizerAttributes.from_dict(attr) for attr in filtered_attributes]
|
||||
return [
|
||||
DataNeutralizerAttributes.from_dict(attr)
|
||||
for attr in filtered_attributes
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting neutralization attributes: {str(e)}")
|
||||
|
|
@ -1061,23 +1156,27 @@ class AppObjects:
|
|||
def deleteNeutralizationAttributes(self, file_id: str) -> bool:
|
||||
"""Delete all neutralization attributes for a specific file"""
|
||||
try:
|
||||
attributes = self.db.getRecordset(DataNeutralizerAttributes, recordFilter={
|
||||
"mandateId": self.mandateId,
|
||||
"fileId": file_id
|
||||
})
|
||||
attributes = self.db.getRecordset(
|
||||
DataNeutralizerAttributes,
|
||||
recordFilter={"mandateId": self.mandateId, "fileId": file_id},
|
||||
)
|
||||
|
||||
for attribute in attributes:
|
||||
self.db.recordDelete(DataNeutralizerAttributes, attribute["id"])
|
||||
|
||||
logger.info(f"Deleted {len(attributes)} neutralization attributes for file {file_id}")
|
||||
logger.info(
|
||||
f"Deleted {len(attributes)} neutralization attributes for file {file_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting neutralization attributes: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
# Public Methods
|
||||
|
||||
|
||||
def getInterface(currentUser: User) -> AppObjects:
|
||||
"""
|
||||
Returns a AppObjects instance for the current user.
|
||||
|
|
@ -1095,6 +1194,7 @@ def getInterface(currentUser: User) -> AppObjects:
|
|||
|
||||
return _gatewayInterfaces[contextKey]
|
||||
|
||||
|
||||
def getRootInterface() -> AppObjects:
|
||||
"""
|
||||
Returns a AppObjects instance with root privileges.
|
||||
|
|
@ -1112,13 +1212,15 @@ def getRootInterface() -> AppObjects:
|
|||
if not initialUserId:
|
||||
raise ValueError("No initial user ID found in database")
|
||||
|
||||
users = tempInterface.db.getRecordset(UserInDB, recordFilter={"id": initialUserId})
|
||||
users = tempInterface.db.getRecordset(
|
||||
UserInDB, recordFilter={"id": initialUserId}
|
||||
)
|
||||
if not users:
|
||||
raise ValueError("Initial user not found in database")
|
||||
|
||||
# Convert to User model
|
||||
user_data = users[0]
|
||||
rootUser = User.parse_obj(user_data)
|
||||
rootUser = User.model_validate(user_data)
|
||||
|
||||
# Create root interface with the root user
|
||||
_rootAppObjects = AppObjects(rootUser)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@
|
|||
Shared utilities for model attributes and labels.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Dict, Any, List, Type, Optional, Union
|
||||
import inspect
|
||||
import importlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ModelMixin:
|
||||
"""Mixin class that provides serialization methods for Pydantic models."""
|
||||
|
||||
|
|
@ -22,7 +23,7 @@ class ModelMixin:
|
|||
Dict[str, Any]: Dictionary representation of the model
|
||||
"""
|
||||
# Get the raw dictionary
|
||||
if hasattr(self, 'model_dump'):
|
||||
if hasattr(self, "model_dump"):
|
||||
data: Dict[str, Any] = self.model_dump() # Pydantic v2
|
||||
else:
|
||||
data: Dict[str, Any] = self.dict() # Pydantic v1
|
||||
|
|
@ -33,7 +34,7 @@ class ModelMixin:
|
|||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ModelMixin':
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "ModelMixin":
|
||||
"""
|
||||
Create a Pydantic model instance from a dictionary.
|
||||
|
||||
|
|
@ -45,9 +46,11 @@ class ModelMixin:
|
|||
"""
|
||||
return cls(**data)
|
||||
|
||||
|
||||
# Define the AttributeDefinition class here instead of importing it
|
||||
class AttributeDefinition(BaseModel, ModelMixin):
|
||||
"""Definition of a model attribute with its metadata."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
label: str
|
||||
|
|
@ -64,9 +67,11 @@ class AttributeDefinition(BaseModel, ModelMixin):
|
|||
order: int = 0
|
||||
placeholder: Optional[str] = None
|
||||
|
||||
|
||||
# Global registry for model labels
|
||||
MODEL_LABELS: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
|
||||
|
||||
def to_dict(model: BaseModel) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a Pydantic model to a dictionary.
|
||||
|
|
@ -78,10 +83,11 @@ def to_dict(model: BaseModel) -> Dict[str, Any]:
|
|||
Returns:
|
||||
Dict[str, Any]: Dictionary representation of the model
|
||||
"""
|
||||
if hasattr(model, 'model_dump'):
|
||||
if hasattr(model, "model_dump"):
|
||||
return model.model_dump() # Pydantic v2
|
||||
return model.dict() # Pydantic v1
|
||||
|
||||
|
||||
def from_dict(model_class: Type[BaseModel], data: Dict[str, Any]) -> BaseModel:
|
||||
"""
|
||||
Create a Pydantic model instance from a dictionary.
|
||||
|
|
@ -95,7 +101,10 @@ def from_dict(model_class: Type[BaseModel], data: Dict[str, Any]) -> BaseModel:
|
|||
"""
|
||||
return model_class(**data)
|
||||
|
||||
def register_model_labels(model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]):
|
||||
|
||||
def register_model_labels(
|
||||
model_name: str, model_label: Dict[str, str], labels: Dict[str, Dict[str, str]]
|
||||
):
|
||||
"""
|
||||
Register labels for a model's attributes and the model itself.
|
||||
|
||||
|
|
@ -106,10 +115,8 @@ def register_model_labels(model_name: str, model_label: Dict[str, str], labels:
|
|||
labels: Dictionary mapping attribute names to their translations
|
||||
e.g. {"name": {"en": "Name", "fr": "Nom"}}
|
||||
"""
|
||||
MODEL_LABELS[model_name] = {
|
||||
"model": model_label,
|
||||
"attributes": labels
|
||||
}
|
||||
MODEL_LABELS[model_name] = {"model": model_label, "attributes": labels}
|
||||
|
||||
|
||||
def get_model_labels(model_name: str, language: str = "en") -> Dict[str, str]:
|
||||
"""
|
||||
|
|
@ -130,6 +137,7 @@ def get_model_labels(model_name: str, language: str = "en") -> Dict[str, str]:
|
|||
for attr, translations in attribute_labels.items()
|
||||
}
|
||||
|
||||
|
||||
def get_model_label(model_name: str, language: str = "en") -> str:
|
||||
"""
|
||||
Get the label for a model in the specified language.
|
||||
|
|
@ -145,7 +153,10 @@ def get_model_label(model_name: str, language: str = "en") -> str:
|
|||
model_label = model_data.get("model", {})
|
||||
return model_label.get(language, model_label.get("en", model_name))
|
||||
|
||||
def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguage: str = "en") -> Dict[str, Any]:
|
||||
|
||||
def getModelAttributeDefinitions(
|
||||
modelClass: Type[BaseModel] = None, userLanguage: str = "en"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get attribute definitions for a model class.
|
||||
|
||||
|
|
@ -165,11 +176,11 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
|
|||
model_label = get_model_label(model_name, userLanguage)
|
||||
|
||||
# Handle both Pydantic v1 and v2
|
||||
if hasattr(modelClass, 'model_fields'): # Pydantic v2
|
||||
if hasattr(modelClass, "model_fields"): # Pydantic v2
|
||||
fields = modelClass.model_fields
|
||||
for name, field in fields.items():
|
||||
# Extract frontend metadata from field info
|
||||
field_info = field.field_info if hasattr(field, 'field_info') else None
|
||||
field_info = field.field_info if hasattr(field, "field_info") else None
|
||||
# Check both direct attributes and extra field for frontend metadata
|
||||
frontend_type = None
|
||||
frontend_readonly = False
|
||||
|
|
@ -178,43 +189,63 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
|
|||
|
||||
if field_info:
|
||||
# Try direct attributes first
|
||||
frontend_type = getattr(field_info, 'frontend_type', None)
|
||||
frontend_readonly = getattr(field_info, 'frontend_readonly', False)
|
||||
frontend_required = getattr(field_info, 'frontend_required', frontend_required)
|
||||
frontend_options = getattr(field_info, 'frontend_options', None)
|
||||
frontend_type = getattr(field_info, "frontend_type", None)
|
||||
frontend_readonly = getattr(field_info, "frontend_readonly", False)
|
||||
frontend_required = getattr(
|
||||
field_info, "frontend_required", frontend_required
|
||||
)
|
||||
frontend_options = getattr(field_info, "frontend_options", None)
|
||||
|
||||
# If not found, check extra field
|
||||
if hasattr(field_info, 'extra') and field_info.extra:
|
||||
if hasattr(field_info, "extra") and field_info.extra:
|
||||
if frontend_type is None:
|
||||
frontend_type = field_info.extra.get('frontend_type')
|
||||
frontend_type = field_info.extra.get("frontend_type")
|
||||
if not frontend_readonly:
|
||||
frontend_readonly = field_info.extra.get('frontend_readonly', False)
|
||||
if frontend_required == field.is_required(): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get('frontend_required', frontend_required)
|
||||
frontend_readonly = field_info.extra.get(
|
||||
"frontend_readonly", False
|
||||
)
|
||||
if (
|
||||
frontend_required == field.is_required()
|
||||
): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get(
|
||||
"frontend_required", frontend_required
|
||||
)
|
||||
if frontend_options is None:
|
||||
frontend_options = field_info.extra.get('frontend_options')
|
||||
frontend_options = field_info.extra.get("frontend_options")
|
||||
|
||||
# Use frontend type if available, otherwise fall back to Python type
|
||||
field_type = frontend_type if frontend_type else (field.annotation.__name__ if hasattr(field.annotation, "__name__") else str(field.annotation))
|
||||
field_type = (
|
||||
frontend_type
|
||||
if frontend_type
|
||||
else (
|
||||
field.annotation.__name__
|
||||
if hasattr(field.annotation, "__name__")
|
||||
else str(field.annotation)
|
||||
)
|
||||
)
|
||||
|
||||
attributes.append({
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.description if hasattr(field, "description") else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options
|
||||
})
|
||||
attributes.append(
|
||||
{
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.description
|
||||
if hasattr(field, "description")
|
||||
else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options,
|
||||
}
|
||||
)
|
||||
else: # Pydantic v1
|
||||
fields = modelClass.__fields__
|
||||
for name, field in fields.items():
|
||||
# Extract frontend metadata from field info
|
||||
field_info = field.field_info if hasattr(field, 'field_info') else None
|
||||
field_info = field.field_info if hasattr(field, "field_info") else None
|
||||
# Check both direct attributes and extra field for frontend metadata
|
||||
frontend_type = None
|
||||
frontend_readonly = False
|
||||
|
|
@ -223,43 +254,61 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
|
|||
|
||||
if field_info:
|
||||
# Try direct attributes first
|
||||
frontend_type = getattr(field_info, 'frontend_type', None)
|
||||
frontend_readonly = getattr(field_info, 'frontend_readonly', False)
|
||||
frontend_required = getattr(field_info, 'frontend_required', frontend_required)
|
||||
frontend_options = getattr(field_info, 'frontend_options', None)
|
||||
frontend_type = getattr(field_info, "frontend_type", None)
|
||||
frontend_readonly = getattr(field_info, "frontend_readonly", False)
|
||||
frontend_required = getattr(
|
||||
field_info, "frontend_required", frontend_required
|
||||
)
|
||||
frontend_options = getattr(field_info, "frontend_options", None)
|
||||
|
||||
# If not found, check extra field
|
||||
if hasattr(field_info, 'extra') and field_info.extra:
|
||||
if hasattr(field_info, "extra") and field_info.extra:
|
||||
if frontend_type is None:
|
||||
frontend_type = field_info.extra.get('frontend_type')
|
||||
frontend_type = field_info.extra.get("frontend_type")
|
||||
if not frontend_readonly:
|
||||
frontend_readonly = field_info.extra.get('frontend_readonly', False)
|
||||
if frontend_required == field.required: # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get('frontend_required', frontend_required)
|
||||
frontend_readonly = field_info.extra.get(
|
||||
"frontend_readonly", False
|
||||
)
|
||||
if (
|
||||
frontend_required == field.required
|
||||
): # Only override if we didn't get it from direct attribute
|
||||
frontend_required = field_info.extra.get(
|
||||
"frontend_required", frontend_required
|
||||
)
|
||||
if frontend_options is None:
|
||||
frontend_options = field_info.extra.get('frontend_options')
|
||||
frontend_options = field_info.extra.get("frontend_options")
|
||||
|
||||
# Use frontend type if available, otherwise fall back to Python type
|
||||
field_type = frontend_type if frontend_type else (field.type_.__name__ if hasattr(field.type_, "__name__") else str(field.type_))
|
||||
field_type = (
|
||||
frontend_type
|
||||
if frontend_type
|
||||
else (
|
||||
field.type_.__name__
|
||||
if hasattr(field.type_, "__name__")
|
||||
else str(field.type_)
|
||||
)
|
||||
)
|
||||
|
||||
attributes.append({
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.field_info.description if hasattr(field.field_info, "description") else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options
|
||||
})
|
||||
attributes.append(
|
||||
{
|
||||
"name": name,
|
||||
"type": field_type,
|
||||
"required": frontend_required,
|
||||
"description": field.field_info.description
|
||||
if hasattr(field.field_info, "description")
|
||||
else "",
|
||||
"label": labels.get(name, name),
|
||||
"placeholder": f"Please enter {labels.get(name, name)}",
|
||||
"editable": not frontend_readonly,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": frontend_readonly,
|
||||
"options": frontend_options,
|
||||
}
|
||||
)
|
||||
|
||||
return {"model": model_label, "attributes": attributes}
|
||||
|
||||
return {
|
||||
"model": model_label,
|
||||
"attributes": attributes
|
||||
}
|
||||
|
||||
def getModelClasses() -> Dict[str, Type[BaseModel]]:
|
||||
"""
|
||||
|
|
@ -271,30 +320,38 @@ def getModelClasses() -> Dict[str, Type[BaseModel]]:
|
|||
modelClasses = {}
|
||||
|
||||
# Get the interfaces directory path
|
||||
interfaces_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'interfaces')
|
||||
interfaces_dir = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "interfaces"
|
||||
)
|
||||
|
||||
# Find all model files
|
||||
for fileName in os.listdir(interfaces_dir):
|
||||
if fileName.endswith('Model.py'):
|
||||
if fileName.endswith("Model.py"):
|
||||
# Convert fileName to module name (e.g., gatewayModel.py -> gatewayModel)
|
||||
module_name = fileName[:-3]
|
||||
|
||||
# Import the module dynamically
|
||||
module = importlib.import_module(f'modules.interfaces.{module_name}')
|
||||
module = importlib.import_module(f"modules.interfaces.{module_name}")
|
||||
|
||||
# Get all classes from the module
|
||||
for name, obj in inspect.getmembers(module):
|
||||
if inspect.isclass(obj) and issubclass(obj, BaseModel) and obj != BaseModel:
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, BaseModel)
|
||||
and obj != BaseModel
|
||||
):
|
||||
modelClasses[name] = obj
|
||||
|
||||
return modelClasses
|
||||
|
||||
|
||||
class AttributeResponse(BaseModel):
|
||||
"""Response model for entity attributes"""
|
||||
|
||||
attributes: List[AttributeDefinition]
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"example": {
|
||||
"attributes": [
|
||||
{
|
||||
|
|
@ -305,8 +362,9 @@ class AttributeResponse(BaseModel):
|
|||
"placeholder": "Please enter username",
|
||||
"editable": True,
|
||||
"visible": True,
|
||||
"order": 0
|
||||
"order": 0,
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ websockets==12.0
|
|||
uvicorn==0.23.2
|
||||
python-multipart==0.0.6
|
||||
httpx==0.25.0
|
||||
pydantic==1.10.13 # Ältere Version ohne Rust-Abhängigkeit
|
||||
pydantic>=2.0.0 # Upgraded to v2 for LangChain compatibility
|
||||
email-validator==2.0.0 # Required by Pydantic for email validation
|
||||
slowapi==0.1.8 # For rate limiting
|
||||
|
||||
|
|
@ -108,3 +108,8 @@ xyzservices>=2021.09.1
|
|||
|
||||
# PostgreSQL connector dependencies
|
||||
psycopg2-binary==2.9.9
|
||||
|
||||
## LangChain & LangGraph
|
||||
langchain==0.3.27
|
||||
langgraph==0.6.8
|
||||
langchain-core==0.3.77
|
||||
|
|
|
|||
Loading…
Reference in a new issue